diff --git a/strix/config/config.py b/strix/config/config.py index 7578b61d3..d432a5857 100644 --- a/strix/config/config.py +++ b/strix/config/config.py @@ -48,6 +48,10 @@ class Config: # Telemetry strix_telemetry = "1" + # Webhook + strix_webhook_url: str | None = None + strix_webhook_format = "generic" + # Config file override (set via --config CLI arg) _config_file_override: Path | None = None diff --git a/strix/interface/main.py b/strix/interface/main.py index 33785e678..395b05ad4 100644 --- a/strix/interface/main.py +++ b/strix/interface/main.py @@ -40,6 +40,7 @@ validate_config_file, validate_llm_response, ) +from strix.interface.webhooks import send_completion_webhook # noqa: E402 from strix.runtime.docker_runtime import HOST_GATEWAY_HOSTNAME # noqa: E402 from strix.telemetry import posthog # noqa: E402 from strix.telemetry.tracer import get_global_tracer # noqa: E402 @@ -366,6 +367,23 @@ def parse_arguments() -> argparse.Namespace: help="Path to a custom config file (JSON) to use instead of ~/.strix/cli-config.json", ) + parser.add_argument( + "--webhook-url", + type=str, + default=None, + help="URL to send scan results to on completion. " + "Supports generic JSON endpoints, Slack incoming webhooks, and Discord webhooks.", + ) + + parser.add_argument( + "--webhook-format", + type=str, + choices=["generic", "slack", "discord"], + default="generic", + help="Webhook payload format (default: generic). " + "Auto-detected from the URL when set to 'generic'.", + ) + args = parser.parse_args() if args.instruction and args.instruction_file: @@ -520,7 +538,7 @@ def persist_config() -> None: save_current_config() -def main() -> None: +def main() -> None: # noqa: PLR0912 if sys.platform == "win32": asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) @@ -578,6 +596,16 @@ def main() -> None: results_path = Path("strix_runs") / args.run_name display_completion_message(args, results_path) + webhook_url = args.webhook_url or Config.get("strix_webhook_url") + if webhook_url: + webhook_format = args.webhook_format or Config.get("strix_webhook_format") or "generic" + send_completion_webhook( + webhook_url=webhook_url, + webhook_format=webhook_format, + tracer=tracer, + args=args, + ) + if args.non_interactive: tracer = get_global_tracer() if tracer and tracer.vulnerability_reports: diff --git a/strix/interface/webhooks.py b/strix/interface/webhooks.py new file mode 100644 index 000000000..6b0a7c07a --- /dev/null +++ b/strix/interface/webhooks.py @@ -0,0 +1,291 @@ +"""Webhook dispatcher for scan completion notifications. + +Sends scan results to external services (Slack, Discord, or generic JSON endpoints) +when a penetration test completes. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Any +from urllib.parse import urlparse + +import requests + + +if TYPE_CHECKING: + import argparse + + +logger = logging.getLogger(__name__) + +WEBHOOK_TIMEOUT = 10 + +# Platform limits for field truncation +_SLACK_SECTION_TEXT_LIMIT = 3000 +_DISCORD_FIELD_VALUE_LIMIT = 1024 +_DISCORD_TITLE_LIMIT = 256 + + +def _truncate(text: str, limit: int) -> str: + """Truncate *text* to *limit* characters, appending an ellipsis if trimmed.""" + if len(text) <= limit: + return text + return text[: limit - 1] + "\u2026" + + +def send_completion_webhook( + webhook_url: str, + webhook_format: str, + tracer: Any, + args: argparse.Namespace, +) -> None: + """Send scan completion results to a webhook URL. + + Args: + webhook_url: The destination webhook URL. + webhook_format: One of ``"generic"``, ``"slack"``, or ``"discord"``. + tracer: The global :class:`Tracer` instance containing scan results. + args: Parsed CLI arguments (used to extract target info and run name). + """ + # Validate URL scheme + parsed = urlparse(webhook_url) + if parsed.scheme not in ("http", "https"): + logger.warning("Invalid webhook URL scheme %r — skipping delivery", parsed.scheme) + return + + if not tracer: + logger.warning("No tracer available — skipping webhook delivery") + return + + resolved_format = _resolve_format(webhook_url, webhook_format) + + formatters: dict[str, Any] = { + "generic": _format_generic, + "slack": _format_slack, + "discord": _format_discord, + } + + formatter = formatters.get(resolved_format, _format_generic) + payload = formatter(tracer, args) + + try: + response = requests.post(webhook_url, json=payload, timeout=WEBHOOK_TIMEOUT) + response.raise_for_status() + logger.info( + "Webhook delivered successfully to %s (status %s)", + webhook_url, + response.status_code, + ) + except requests.RequestException as exc: + logger.warning("Failed to deliver webhook to %s: %s", webhook_url, exc) + + +# --------------------------------------------------------------------------- +# Format resolution +# --------------------------------------------------------------------------- + + +def _resolve_format(url: str, explicit_format: str) -> str: + """Auto-detect the webhook format from the URL when the user chose ``"generic"``.""" + if explicit_format != "generic": + return explicit_format + + host = urlparse(url).hostname or "" + if "hooks.slack.com" in host: + return "slack" + if "discord.com" in host or "discordapp.com" in host: + return "discord" + + return "generic" + + +# --------------------------------------------------------------------------- +# Payload helpers +# --------------------------------------------------------------------------- + + +def _targets_summary(args: argparse.Namespace) -> str: + targets_info: list[dict[str, Any]] = getattr(args, "targets_info", []) + if not targets_info: + return "unknown" + return ", ".join(t.get("original", "unknown") for t in targets_info) + + +def _vulnerability_summary(tracer: Any) -> list[dict[str, Any]]: + """Return a lightweight list of vulnerability dicts safe for JSON serialisation.""" + if not tracer: + return [] + return [ + { + "id": report.get("id", ""), + "title": report.get("title", ""), + "severity": report.get("severity", ""), + "cvss": report.get("cvss"), + "target": report.get("target", ""), + "endpoint": report.get("endpoint", ""), + "description": report.get("description", ""), + } + for report in tracer.vulnerability_reports + ] + + +def _severity_counts(tracer: Any) -> dict[str, int]: + counts: dict[str, int] = {"critical": 0, "high": 0, "medium": 0, "low": 0, "info": 0} + if not tracer: + return counts + for report in tracer.vulnerability_reports: + severity = report.get("severity", "").lower() + if severity in counts: + counts[severity] += 1 + return counts + + +def _scan_completed(tracer: Any) -> bool: + if tracer and tracer.scan_results: + return bool(tracer.scan_results.get("scan_completed", False)) + return False + + +# --------------------------------------------------------------------------- +# Formatters +# --------------------------------------------------------------------------- + + +def _format_generic(tracer: Any, args: argparse.Namespace) -> dict[str, Any]: + """Plain JSON payload with full scan data.""" + completed = _scan_completed(tracer) + llm_stats = tracer.get_total_llm_stats()["total"] if tracer else {} + vuln_reports = tracer.vulnerability_reports if tracer else [] + return { + "event": "scan_completed" if completed else "scan_ended", + "run_name": getattr(args, "run_name", ""), + "targets": _targets_summary(args), + "scan_mode": getattr(args, "scan_mode", ""), + "completed": completed, + "vulnerability_count": len(vuln_reports), + "severity_counts": _severity_counts(tracer), + "vulnerabilities": _vulnerability_summary(tracer), + "stats": { + "agents": len(tracer.agents) if tracer else 0, + "tools": tracer.get_real_tool_count() if tracer else 0, + "input_tokens": llm_stats.get("input_tokens", 0), + "output_tokens": llm_stats.get("output_tokens", 0), + "cost": llm_stats.get("cost", 0), + }, + } + + +def _format_slack(tracer: Any, args: argparse.Namespace) -> dict[str, Any]: + """Slack Block Kit payload.""" + completed = _scan_completed(tracer) + vuln_reports = tracer.vulnerability_reports if tracer else [] + vuln_count = len(vuln_reports) + counts = _severity_counts(tracer) + + status_emoji = ":white_check_mark:" if completed else ":warning:" + status_text = "completed" if completed else "ended" + + severity_line = ( + " | ".join(f"*{sev.upper()}*: {cnt}" for sev, cnt in counts.items() if cnt > 0) + or "None found" + ) + + blocks: list[dict[str, Any]] = [ + { + "type": "header", + "text": { + "type": "plain_text", + "text": f"{status_emoji} Strix Scan {status_text.title()}", + "emoji": True, + }, + }, + { + "type": "section", + "fields": [ + {"type": "mrkdwn", "text": f"*Target:*\n{_targets_summary(args)}"}, + {"type": "mrkdwn", "text": f"*Run:*\n{getattr(args, 'run_name', 'N/A')}"}, + {"type": "mrkdwn", "text": f"*Scan Mode:*\n{getattr(args, 'scan_mode', 'N/A')}"}, + {"type": "mrkdwn", "text": f"*Vulnerabilities:*\n{vuln_count}"}, + ], + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": f"*Severity Breakdown:* {severity_line}", + }, + }, + ] + + # Add top vulnerabilities (max 5) + for report in vuln_reports[:5]: + title = _truncate(report.get("title", "Untitled"), 200) + severity = report.get("severity", "unknown").upper() + endpoint = report.get("endpoint", "") + text = f":rotating_light: *[{severity}]* {title}" + if endpoint: + text += f"\n`{_truncate(endpoint, 200)}`" + text = _truncate(text, _SLACK_SECTION_TEXT_LIMIT) + blocks.append( + { + "type": "section", + "text": {"type": "mrkdwn", "text": text}, + } + ) + + return {"blocks": blocks} + + +def _format_discord(tracer: Any, args: argparse.Namespace) -> dict[str, Any]: + """Discord webhook payload with an embed.""" + completed = _scan_completed(tracer) + vuln_reports = tracer.vulnerability_reports if tracer else [] + vuln_count = len(vuln_reports) + counts = _severity_counts(tracer) + + color = 0x22C55E if completed else 0xEAB308 # green / yellow + if counts["critical"] > 0: + color = 0xDC2626 + elif counts["high"] > 0: + color = 0xEA580C + + severity_line = ( + " | ".join(f"**{sev.upper()}**: {cnt}" for sev, cnt in counts.items() if cnt > 0) + or "None found" + ) + + fields: list[dict[str, Any]] = [ + { + "name": "Target", + "value": _truncate(_targets_summary(args), _DISCORD_FIELD_VALUE_LIMIT), + "inline": True, + }, + {"name": "Scan Mode", "value": getattr(args, "scan_mode", "N/A"), "inline": True}, + {"name": "Vulnerabilities", "value": str(vuln_count), "inline": True}, + {"name": "Severity Breakdown", "value": severity_line, "inline": False}, + ] + + # Top vulnerabilities (max 5) + for report in vuln_reports[:5]: + title = _truncate(report.get("title", "Untitled"), 200) + severity = report.get("severity", "unknown").upper() + endpoint = report.get("endpoint", "") + value = f"**[{severity}]** {title}" + if endpoint: + value += f"\n`{_truncate(endpoint, 200)}`" + value = _truncate(value, _DISCORD_FIELD_VALUE_LIMIT) + fields.append({"name": "\u200b", "value": value, "inline": False}) + + status_text = "Scan Completed" if completed else "Scan Ended" + run_name = _truncate(getattr(args, "run_name", "N/A"), _DISCORD_TITLE_LIMIT) + + embed: dict[str, Any] = { + "title": _truncate(f"\ud83d\udd12 Strix \u2014 {status_text}", _DISCORD_TITLE_LIMIT), + "description": f"Run: **{run_name}**", + "color": color, + "fields": fields, + "footer": {"text": "Strix Security Scanner"}, + } + + return {"embeds": [embed]} diff --git a/tests/interface/test_webhooks.py b/tests/interface/test_webhooks.py new file mode 100644 index 000000000..22a578f4c --- /dev/null +++ b/tests/interface/test_webhooks.py @@ -0,0 +1,336 @@ +"""Tests for the webhook dispatcher module.""" + +from __future__ import annotations + +import argparse +from typing import Any +from unittest.mock import MagicMock, patch + +import requests + +from strix.interface.webhooks import ( + _format_discord, + _format_generic, + _format_slack, + _resolve_format, + _severity_counts, + _targets_summary, + _truncate, + _vulnerability_summary, + send_completion_webhook, +) + + +def _make_tracer( + vulnerability_reports: list[dict[str, Any]] | None = None, + scan_completed: bool = True, +) -> MagicMock: + """Create a mock tracer with configurable vulnerability reports.""" + tracer = MagicMock() + tracer.vulnerability_reports = vulnerability_reports or [] + tracer.scan_results = {"scan_completed": scan_completed} + tracer.agents = {"agent-1": {}, "agent-2": {}} + tracer.get_real_tool_count.return_value = 5 + tracer.get_total_llm_stats.return_value = { + "total": { + "input_tokens": 1000, + "output_tokens": 500, + "cost": 0.05, + "requests": 3, + "cached_tokens": 200, + } + } + return tracer + + +def _make_args( + targets_info: list[dict[str, Any]] | None = None, + run_name: str = "test-run_abcd", + scan_mode: str = "deep", +) -> argparse.Namespace: + """Create a mock args namespace.""" + default_targets: list[dict[str, Any]] = [ + {"original": "https://example.com", "type": "web_application"}, + ] + return argparse.Namespace( + targets_info=targets_info if targets_info is not None else default_targets, + run_name=run_name, + scan_mode=scan_mode, + ) + + +SAMPLE_VULNS: list[dict[str, Any]] = [ + { + "id": "VULN-001", + "title": "SQL Injection in login endpoint", + "severity": "critical", + "cvss": 9.8, + "target": "https://example.com", + "endpoint": "/api/login", + "description": "Unsanitised input allows SQL injection.", + }, + { + "id": "VULN-002", + "title": "Reflected XSS", + "severity": "high", + "cvss": 7.1, + "target": "https://example.com", + "endpoint": "/search?q=