#!/usr/bin/env python3
"""snmptrapd handler — POSTs every received trap to the Vexor events API.

snmptrapd invokes this script with the trap on stdin in the default
"traphandle" format:

    <source-host>
    <source-ip-addr-and-port>
    <varbind1 oid> <varbind1 value>
    <varbind2 oid> <varbind2 value>
    ...

The second varbind is conventionally SNMPv2-MIB::snmpTrapOID.0; its value
identifies the trap type and is what the configured matchers are compared against.

Severity and facility come from the first matching enabled row of the
snmp_trap_matchers table (configured in the Vexor UI); when no row matches,
they fall back to "warning" and "snmp-trap".
"""
from __future__ import annotations

import fnmatch
import json
import sys
import urllib.request
import urllib.error
from pathlib import Path

# Events endpoint that every trap payload is POSTed to (used by _post).
API_URL = "http://127.0.0.1:8080/api/v1/events"
# File whose stripped contents are sent as the X-Internal-Token header
# (read by _token; a missing/empty file makes main() exit with status 1).
TOKEN_FILE = "/etc/vexor/notify-token"


def _token() -> str:
    """Return the internal API token stored in TOKEN_FILE.

    Surrounding whitespace is stripped. If the file cannot be read (missing,
    unreadable, ...) an empty string is returned so the caller can decide
    how to react.
    """
    token_path = Path(TOKEN_FILE)
    try:
        contents = token_path.read_text()
    except OSError:
        contents = ""
    return contents.strip()


def _fetch_matchers() -> list[dict]:
    import subprocess
    try:
        out = subprocess.check_output([
            "mysql", "-uvexor", "-pvexor2026", "-N", "-B", "-e",
            "SELECT oid_pattern, severity, facility, enabled "
            "FROM snmp_trap_matchers WHERE enabled = 1", "vexor",
        ], stderr=subprocess.DEVNULL, timeout=2)
        rows = []
        for line in out.decode().splitlines():
            parts = line.split("\t")
            if len(parts) >= 4:
                rows.append({"oid_pattern": parts[0], "severity": parts[1],
                             "facility": parts[2]})
        return rows
    except Exception:
        return []


def _classify(trap_oid: str, matchers: list[dict]) -> tuple[str, str]:
    for m in matchers:
        pat = m["oid_pattern"]
        if pat == trap_oid or fnmatch.fnmatch(trap_oid, pat):
            return m["severity"], m["facility"]
    return "warning", "snmp-trap"


def _parse_stdin() -> dict:
    lines = sys.stdin.read().splitlines()
    host = lines[0].strip() if lines else "unknown"
    source_ip = lines[1].strip() if len(lines) > 1 else ""
    varbinds: list[tuple[str, str]] = []
    trap_oid = ""
    for ln in lines[2:]:
        ln = ln.strip()
        if not ln:
            continue
        parts = ln.split(None, 1)
        if not parts:
            continue
        oid = parts[0]
        val = parts[1] if len(parts) > 1 else ""
        varbinds.append((oid, val))
        if oid.endswith("snmpTrapOID.0") or oid == "1.3.6.1.6.3.1.1.4.1.0":
            trap_oid = val.strip()
    return {"host": host, "source_ip": source_ip,
            "trap_oid": trap_oid, "varbinds": varbinds}


def _post(payload: dict, token: str) -> None:
    """POST *payload* as JSON to API_URL, authenticated with *token*.

    The token is sent in the ``X-Internal-Token`` header. Delivery is
    best-effort: any transport or HTTP failure is written to stderr and
    otherwise ignored. This includes error statuses, which urllib raises as
    HTTPError.
    """
    import http.client

    body = json.dumps(payload).encode("utf-8")
    req = urllib.request.Request(
        API_URL, data=body, method="POST",
        headers={"Content-Type": "application/json",
                 "X-Internal-Token": token},
    )
    try:
        # The context manager closes the response (and its socket) promptly.
        with urllib.request.urlopen(req, timeout=5) as resp:
            resp.read()
    except (OSError, http.client.HTTPException) as exc:
        # OSError covers URLError/HTTPError. It also covers a TimeoutError or
        # connection reset raised while reading the body, which urllib does
        # not wrap in URLError. HTTPException covers e.g. IncompleteRead.
        sys.stderr.write(f"vexor-trap-handler: POST failed: {exc}\n")


def main() -> int:
    """Forward the trap on stdin to the Vexor events API.

    Parses the trap, classifies it against the configured matchers, builds a
    one-line summary and POSTs it.

    Returns:
        The process exit status. It is 1 when no internal token is configured
        (nothing is sent) and 0 otherwise. A failed POST is only reported on
        stderr by _post and still yields 0.
    """
    # Parse stdin unconditionally, before any early return.
    trap = _parse_stdin()

    token = _token()
    if not token:
        sys.stderr.write("vexor-trap-handler: no internal token configured\n")
        return 1

    # Query the matcher table (a mysql subprocess) only once we know the
    # event can actually be sent.
    matchers = _fetch_matchers()
    severity, facility = _classify(trap["trap_oid"], matchers)

    # Summary: the trap OID, then the first five varbinds as name=value, with
    # any "MODULE::" prefix dropped from each name.
    summary_parts = [trap["trap_oid"] or "unknown-trap"]
    summary_parts.extend(
        f"{oid.rsplit('::', 1)[-1]}={val}" for oid, val in trap["varbinds"][:5]
    )
    message = " | ".join(summary_parts)
    payload = {
        "source": "snmp-trap",
        "host": trap["host"] or trap["source_ip"] or "unknown",
        "severity": severity,
        "facility": facility,
        "message": message[:2000],
    }
    _post(payload, token)
    return 0


if __name__ == "__main__":
    # main()'s return value becomes the process exit status.
    raise SystemExit(main())
