#!/usr/bin/python3
"""check_wmi_simple — Nagios-compatible WMI plugin via impacket.

Replaces the samba `wmic` binary, which is missing on RHEL 10 / Rocky 10. Uses
impacket's DCOM/WMI implementation (no extra deps beyond impacket, which is
already installed for /usr/local/bin/wmiquery.py).

Modes:
  cpu        — Win32_PerfFormattedData_PerfOS_Processor _Total
  memory     — Win32_OperatingSystem FreePhysicalMemory / TotalVisibleMemorySize
  disk       — Win32_LogicalDisk by DeviceID (e.g. C:)
  service    — Win32_Service by Name, OK only if state == Running
  uptime     — seconds since LastBootUpTime
  falcon     — CrowdStrike Falcon sensor health (CSFalconService / CSAgent)
  wql        — run an arbitrary WQL query and print the first row's columns

Exit codes follow Nagios conventions: 0 OK, 1 WARNING, 2 CRITICAL, 3 UNKNOWN.

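Threshold semantics:
  cpu / memory / disk — percent in use; -w 80 -c 90 → warn at >=80%, crit at >=90%
  service             — thresholds ignored; CRIT if not Running
  falcon / wql        — thresholds ignored
  uptime              — seconds; warn if uptime <= -w, crit if uptime <= -c.
                        CRIT is checked first, so pass -c smaller than -w;
                        a value of 0 disables that threshold. The 80/90
                        defaults also apply here, so set both explicitly.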
"""
from __future__ import annotations

import argparse
import re
import sys
from datetime import datetime, timezone

# Nagios plugin exit codes; also the `state` argument accepted by out().
OK = 0
WARN = 1
CRIT = 2
UNKNOWN = 3


def out(state: int, msg: str, perf: str = "") -> "None":
    """Emit one Nagios status line and terminate with exit code *state*.

    The line is ``WMI <LABEL> - <msg>``, with ``|<perf>`` appended when
    *perf* is non-empty. *state* must be one of OK / WARN / CRIT / UNKNOWN;
    any other value raises KeyError. This function never returns normally.
    """
    labels = {
        OK: "OK",
        WARN: "WARNING",
        CRIT: "CRITICAL",
        UNKNOWN: "UNKNOWN",
    }
    head = f"WMI {labels[state]} - {msg}"
    print(f"{head}|{perf}" if perf else head)
    sys.exit(state)


def _connect(host: str, user: str, password: str, domain: str):
    """Open a DCOM session to *host* and log in to WMI namespace root/cimv2.

    Returns ``(dcom, services)``: the impacket DCOMConnection and the
    IWbemServices interface returned by NTLMLogin. The caller must release
    both (``services.RemRelease()`` and ``dcom.disconnect()``).

    Does not return on failure. It exits UNKNOWN via out() if impacket
    cannot be imported, and CRITICAL if connecting or logging in raises.
    """
    try:
        from impacket.dcerpc.v5.dcomrt import DCOMConnection
        from impacket.dcerpc.v5.dcom.wmi import (
            CLSID_WbemLevel1Login, IID_IWbemLevel1Login, IWbemLevel1Login,
        )
        from impacket.dcerpc.v5.dtypes import NULL
    except ImportError as e:
        out(UNKNOWN, f"impacket missing: {e}")

    # Bound before the try so the cleanup path can tell whether a
    # connection object was ever created (the constructor itself may raise).
    dcom = None
    try:
        dcom = DCOMConnection(
            host, user, password, domain or "",
            lmhash="", nthash="", aesKey="",
            oxidResolver=True, doKerberos=False,
        )
        iface = dcom.CoCreateInstanceEx(CLSID_WbemLevel1Login, IID_IWbemLevel1Login)
        login = IWbemLevel1Login(iface)
        services = login.NTLMLogin("//./root/cimv2", NULL, NULL)
        # The level-1 login interface is only needed to obtain `services`.
        login.RemRelease()
        return dcom, services
    except Exception as e:
        if dcom is not None:
            try:
                dcom.disconnect()
            except Exception:
                pass  # best-effort cleanup; the connect error is what we report
        out(CRIT, f"WMI connect failed: {type(e).__name__}: {e}")


def _query(services, wql: str) -> "list[dict]":
    rows: list[dict] = []
    try:
        enumr = services.ExecQuery(wql.strip())
        while True:
            try:
                obj = enumr.Next(0xffffffff, 1)[0]
            except Exception:
                break
            try:
                rows.append(obj.getProperties())
            except Exception:
                pass
        try:
            enumr.RemRelease()
        except Exception:
            pass
    except Exception as e:
        out(CRIT, f"WQL failed: {type(e).__name__}: {e} -- query={wql!r}")
    return rows


def _prop(row: dict, name: str):
    """impacket returns {name: {value: x, type: ...}} — flatten."""
    if name not in row:
        return None
    v = row[name]
    if isinstance(v, dict) and "value" in v:
        return v["value"]
    return v


def _bytes_h(n: float) -> str:
    for u in ("B", "KB", "MB", "GB", "TB"):
        if n < 1024:
            return f"{n:.1f}{u}"
        n /= 1024
    return f"{n:.1f}PB"


# ---------------------------------------------------------------------------
# Mode handlers
# ---------------------------------------------------------------------------

def mode_cpu(services, warn: float, crit: float):
    """Check total CPU load (percent) against the warn/crit thresholds.

    Reads PercentProcessorTime of the ``_Total`` instance of
    Win32_PerfFormattedData_PerfOS_Processor; a missing value counts as 0.
    Exits UNKNOWN if the query returns no rows.
    """
    wql = ("SELECT PercentProcessorTime FROM Win32_PerfFormattedData_PerfOS_Processor "
           "WHERE Name='_Total'")
    rows = _query(services, wql)
    if not rows:
        out(UNKNOWN, "no CPU data returned")
    usage = float(_prop(rows[0], "PercentProcessorTime") or 0)
    if usage >= crit:
        state = CRIT
    elif usage >= warn:
        state = WARN
    else:
        state = OK
    perf = f"cpu={usage:.1f}%;{warn};{crit};0;100"
    out(state, f"CPU usage {usage:.1f}%", perf)


def mode_memory(services, warn: float, crit: float):
    """Check physical memory usage (percent of visible memory in use).

    FreePhysicalMemory and TotalVisibleMemorySize are handled as KiB and
    multiplied by 1024 for the human-readable sizes and the byte perf data.
    Exits UNKNOWN if no row comes back or the total is not positive.
    """
    rows = _query(services,
        "SELECT FreePhysicalMemory, TotalVisibleMemorySize FROM Win32_OperatingSystem")
    if not rows:
        out(UNKNOWN, "no OS data returned")
    os_row = rows[0]
    free_kib = float(_prop(os_row, "FreePhysicalMemory") or 0)
    total_kib = float(_prop(os_row, "TotalVisibleMemorySize") or 0)
    if total_kib <= 0:
        out(UNKNOWN, "total memory is zero")
    used_pct = (1 - free_kib / total_kib) * 100
    free_bytes = free_kib * 1024
    total_bytes = total_kib * 1024
    if used_pct >= crit:
        state = CRIT
    elif used_pct >= warn:
        state = WARN
    else:
        state = OK
    perf_items = [
        f"used_pct={used_pct:.1f}%;{warn};{crit};0;100",
        f"free_bytes={int(free_bytes)}B",
        f"total_bytes={int(total_bytes)}B",
    ]
    out(state,
        f"Memory {used_pct:.1f}% used "
        f"({_bytes_h(free_bytes)} free of {_bytes_h(total_bytes)})",
        " ".join(perf_items))


def mode_disk(services, drive: str, warn: float, crit: float):
    """Check percent-used space on one logical disk.

    Trailing backslashes and then trailing slashes are stripped from *drive*,
    and a colon is appended if missing (so ``C``, ``C:`` and ``C:\\`` all
    become ``C:``). Exits UNKNOWN if the drive is not found or reports a
    non-positive size.
    """
    device = drive.rstrip("\\").rstrip("/")
    if not device.endswith(":"):
        device = f"{device}:"
    rows = _query(services,
        f"SELECT DeviceID, FreeSpace, Size FROM Win32_LogicalDisk WHERE DeviceID='{device}'")
    if not rows:
        out(UNKNOWN, f"drive {device} not found")
    disk = rows[0]
    free = float(_prop(disk, "FreeSpace") or 0)
    size = float(_prop(disk, "Size") or 0)
    if size <= 0:
        out(UNKNOWN, f"drive {device} reports zero size")
    used_pct = (1 - free / size) * 100
    if used_pct >= crit:
        state = CRIT
    elif used_pct >= warn:
        state = WARN
    else:
        state = OK
    summary = (f"Drive {device} {used_pct:.1f}% used "
               f"({_bytes_h(free)} free of {_bytes_h(size)})")
    perf = (f"used_pct={used_pct:.1f}%;{warn};{crit};0;100 "
            f"free_bytes={int(free)}B size_bytes={int(size)}B")
    out(state, summary, perf)


def mode_service(services, name: str):
    """Check that the Windows service *name* is Running.

    UNKNOWN if no name was given; CRITICAL if the service is not found or
    its State is anything other than "Running"; OK otherwise.
    """
    if not name:
        out(UNKNOWN, "-a SERVICE_NAME is required for mode=service")
    wql = f"SELECT Name, State, DisplayName FROM Win32_Service WHERE Name='{name}'"
    matches = _query(services, wql)
    if not matches:
        out(CRIT, f"service '{name}' not found")
    svc = matches[0]
    current = str(_prop(svc, "State") or "Unknown")
    label = str(_prop(svc, "DisplayName") or name)
    verdict = OK if current == "Running" else CRIT
    out(verdict, f"{label} ({name}) is {current}")


def mode_uptime(services, warn: float, crit: float):
    """Report time since the last boot, flagging recent reboots.

    LastBootUpTime is a CIM_DATETIME string, ``yyyymmddHHMMSS.mmmmmmsUUU``.
    Per that format, the timestamp is local to the host and ``sUUU`` is its
    signed offset from UTC in minutes. The offset is applied when present;
    a value without one is treated as UTC.

    Thresholds are in seconds. The result is CRITICAL when uptime <= *crit*,
    else WARNING when uptime <= *warn*. A threshold <= 0 is disabled.
    Exits UNKNOWN if no row is returned or the timestamp cannot be parsed.
    """
    from datetime import timedelta  # module level imports only datetime/timezone

    rows = _query(services, "SELECT LastBootUpTime FROM Win32_OperatingSystem")
    if not rows:
        out(UNKNOWN, "no OS data returned")
    raw = str(_prop(rows[0], "LastBootUpTime") or "")
    # 14 date/time digits, optional fractional seconds ('*' = unspecified),
    # optional sign + 3-character minute offset from UTC.
    m = re.fullmatch(r"([0-9]{14})(?:\.[0-9*]*)?(?:([+-])([0-9*]{3}))?", raw.strip())
    try:
        if m is None:
            raise ValueError("unrecognised CIM_DATETIME")
        local = datetime.strptime(m.group(1), "%Y%m%d%H%M%S")
        offset_digits = m.group(3) or ""
        offset_min = int(offset_digits) if offset_digits.isdigit() else 0
        if m.group(2) == "-":
            offset_min = -offset_min
        boot = local.replace(tzinfo=timezone(timedelta(minutes=offset_min)))
    except ValueError:
        out(UNKNOWN, f"could not parse LastBootUpTime={raw!r}")
    uptime_s = (datetime.now(timezone.utc) - boot).total_seconds()
    state = OK
    if crit > 0 and uptime_s <= crit:
        state = CRIT
    elif warn > 0 and uptime_s <= warn:
        state = WARN
    days = uptime_s / 86400
    out(state, f"uptime {days:.1f} days ({int(uptime_s)} s)",
        f"uptime={int(uptime_s)}s;{int(warn) if warn else ''};{int(crit) if crit else ''};0;")


def mode_falcon(services):
    """CrowdStrike Falcon sensor health on Windows.

    Queries Win32_Service for CSFalconService (modern) / CSAgent (legacy).
    The result is CRITICAL if neither is found or the reported service is
    not Running, and OK otherwise. When a service is found, the perf data
    is ``falcon_running=1`` (Running) or ``falcon_running=0``.

    If both services are returned, CSFalconService is reported. WQL does not
    guarantee result order, so taking the first row would be arbitrary.

    The sensor version is a best-effort guess: the first backslash-separated
    component of the service's PathName that starts with a digit and
    contains a dot. No CID lookup is performed. The CID lives in the
    registry, and reading that remotely would need RRP, which this plugin
    does not use.
    """
    svc_rows = _query(services,
        "SELECT Name, State, DisplayName, PathName FROM Win32_Service "
        "WHERE Name='CSFalconService' OR Name='CSAgent'")
    if not svc_rows:
        out(CRIT, "CrowdStrike Falcon sensor not installed "
                  "(neither CSFalconService nor CSAgent found)")
    # Deterministic choice: prefer the modern service when both are present.
    row = next(
        (r for r in svc_rows
         if str(_prop(r, "Name") or "").lower() == "csfalconservice"),
        svc_rows[0],
    )
    name = str(_prop(row, "Name") or "")
    state_str = str(_prop(row, "State") or "Unknown")
    display = str(_prop(row, "DisplayName") or name)
    path = str(_prop(row, "PathName") or "")

    # Best-effort version guess from a versioned directory in the binary path.
    version = ""
    for part in path.replace('"', "").split("\\"):
        if part and part[0].isdigit() and "." in part:
            version = part
            break

    if state_str != "Running":
        out(CRIT, f"{display} ({name}) is {state_str}"
                  + (f" — version {version}" if version else ""),
            "falcon_running=0")

    summary = f"{display} ({name}) running"
    if version:
        summary += f", version {version}"
    out(OK, summary, "falcon_running=1")


_WQL_ARG_RE = re.compile(r"^[A-Za-z0-9_:./\\s\-=\\'\"<>%(),*]+$")


def mode_wql(services, wql: str):
    """Run the ad-hoc --wql query and summarise the first result row.

    UNKNOWN if no query was given or it fails the _WQL_ARG_RE allow-list;
    WARNING if it returns no rows; otherwise OK with the row count and every
    ``column=value`` pair of the first row.
    """
    if not wql:
        out(UNKNOWN, "--wql QUERY is required for mode=wql")
    if not _WQL_ARG_RE.match(wql):
        out(UNKNOWN, "invalid characters in --wql query")
    rows = _query(services, wql)
    if not rows:
        out(WARN, "query returned no rows")
    first = rows[0]
    columns = ", ".join(f"{col}={_prop(first, col)}" for col in first)
    out(OK, f"{len(rows)} row(s): {columns}")


def main():
    """Command-line entry point: parse options, connect, and run one check mode.

    Each mode handler reports through out(), which prints the status line
    and exits. The finally clause then releases the IWbemServices interface
    and disconnects DCOM, ignoring any cleanup errors.
    """
    p = argparse.ArgumentParser(prog="check_wmi_simple",
                                description=__doc__.split("\n\n")[0])
    p.add_argument("-H", "--host", required=True)
    p.add_argument("-u", "--user", required=True)
    p.add_argument("-p", "--password", default="")
    p.add_argument("-P", "--password-file", default=None,
                   help="Read password from file (first line)")
    p.add_argument("-d", "--domain", default="")
    p.add_argument("-m", "--mode", required=True,
                   choices=("cpu", "memory", "disk", "service", "uptime", "wql", "falcon"))
    p.add_argument("-w", "--warning", type=float, default=80.0,
                   help="percent for cpu/memory/disk, seconds for uptime")
    p.add_argument("-c", "--critical", type=float, default=90.0,
                   help="percent for cpu/memory/disk, seconds for uptime")
    p.add_argument("-a", "--argument", default="",
                   help="disk: drive letter (default C:); service: service name")
    p.add_argument("--wql", default="", help="WQL query for mode=wql")
    args = p.parse_args()

    if args.password_file:
        try:
            with open(args.password_file) as f:
                # Strip "\r" as well as "\n": a file saved with CRLF line
                # endings would otherwise leave a trailing carriage return
                # in the password.
                args.password = f.readline().rstrip("\r\n")
        except Exception as e:
            out(UNKNOWN, f"could not read password file: {e}")

    dcom, services = _connect(args.host, args.user, args.password, args.domain)
    # One handler per --mode value; argparse `choices` guarantees the key exists.
    handlers = {
        "cpu": lambda: mode_cpu(services, args.warning, args.critical),
        "memory": lambda: mode_memory(services, args.warning, args.critical),
        "disk": lambda: mode_disk(services, args.argument or "C:",
                                  args.warning, args.critical),
        "service": lambda: mode_service(services, args.argument),
        "uptime": lambda: mode_uptime(services, args.warning, args.critical),
        "wql": lambda: mode_wql(services, args.wql),
        "falcon": lambda: mode_falcon(services),
    }
    try:
        handlers[args.mode]()
    finally:
        try:
            services.RemRelease()
        except Exception:
            pass
        try:
            dcom.disconnect()
        except Exception:
            pass


if __name__ == "__main__":
    # Script entry point. Each check mode finishes by calling out(), which
    # prints the Nagios status line and calls sys.exit(), so main() is not
    # expected to return normally.
    main()
