diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2b07e57a9..338e0208f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -62,6 +62,7 @@ jobs: PERF_COLD_START_BUDGET_S: ${{ vars.PERF_COLD_START_BUDGET_S || '3' }} PERF_COLD_START_SAMPLES: ${{ vars.PERF_COLD_START_SAMPLES || '3' }} PERF_COLD_START_TIMEOUT_S: ${{ vars.PERF_COLD_START_TIMEOUT_S || '90' }} + GITHUB_REF_VALUE: ${{ github.ref }} run: | set -o pipefail python -m pytest tests/benchmarks/ --benchmark-only \ @@ -90,7 +91,7 @@ jobs: # cold-start startup time. python scripts/perf_budget_gate.py --bench-json /tmp/bench-current.json # On main branch pushes, update the cached baseline for future PRs - if [ "${{ github.ref }}" = "refs/heads/main" ]; then + if [ "$GITHUB_REF_VALUE" = "refs/heads/main" ]; then cp /tmp/bench-current.json /tmp/bench-baseline.json fi - name: Upload benchmark artifacts diff --git a/.github/workflows/memory-diff.yml b/.github/workflows/memory-diff.yml index 766e178a8..4db29321d 100644 --- a/.github/workflows/memory-diff.yml +++ b/.github/workflows/memory-diff.yml @@ -76,6 +76,8 @@ jobs: python -c "import json; d=json.load(open('pr.json')); print('backend:', d['backend'], 'modules:', d['module_count'], 'rss:', d['total_rss_bytes'])" - name: Checkout base branch into side worktree + env: + GITHUB_BASE_REF_VALUE: ${{ github.base_ref }} run: | # Stash the memory_diff helpers from the PR and use the same helper # version for both measurements. The helper is measurement tooling; @@ -84,7 +86,7 @@ jobs: mkdir -p /tmp/memdiff-scripts cp scripts/memory_diff.py /tmp/memdiff-scripts/ cp scripts/format_memory_diff.py /tmp/memdiff-scripts/ - git worktree add --detach /tmp/base-tree "origin/${{ github.base_ref }}" + git worktree add --detach /tmp/base-tree "origin/$GITHUB_BASE_REF_VALUE" mkdir -p /tmp/base-tree/scripts cp /tmp/memdiff-scripts/memory_diff.py /tmp/base-tree/scripts/ cp /tmp/memdiff-scripts/format_memory_diff.py /tmp/base-tree/scripts/ diff --git a/scripts/Dockerfile.install-matrix b/scripts/Dockerfile.install-matrix index 51590c7ea..0441bb323 100644 --- a/scripts/Dockerfile.install-matrix +++ b/scripts/Dockerfile.install-matrix @@ -126,22 +126,22 @@ RUN mkdir -p /boot/firmware \ 'dtparam=i2c_arm=off' \ > /boot/firmware/config.txt +RUN useradd --create-home --shell /bin/bash inkypi-ci \ + && printf 'inkypi-ci ALL=(ALL) NOPASSWD:ALL\n' > /etc/sudoers.d/inkypi-ci \ + && chmod 0440 /etc/sudoers.d/inkypi-ci + # Copy the local checkout into the image so we exercise the branch under test, # not whatever is on GitHub. The entrypoint script (passed in via docker run) # will cd to /InkyPi/install and invoke install.sh. -COPY . /InkyPi +COPY --chown=inkypi-ci:inkypi-ci . /InkyPi WORKDIR /InkyPi/install -# install.sh requires root to run apt-get, write to /boot/firmware, manage -# services, and create the inkypi system user. A non-root USER would break the -# installer. We declare USER root explicitly to satisfy the DS-0002 linter -# requirement for an explicit USER directive before CMD. -USER root +USER inkypi-ci # This image is a CI install-verification container, not a long-running service. # HEALTHCHECK NONE suppresses the DS-0026 "missing healthcheck" alert. HEALTHCHECK NONE # Default command runs install.sh; CI overrides this to add verification steps. -CMD ["bash", "./install.sh"] +CMD ["sudo", "bash", "./install.sh"] diff --git a/scripts/Dockerfile.sim-install b/scripts/Dockerfile.sim-install index 780deedce..204b12723 100644 --- a/scripts/Dockerfile.sim-install +++ b/scripts/Dockerfile.sim-install @@ -34,18 +34,18 @@ RUN printf '#!/bin/sh\n# sim-only no-op raspi-config shim\nexit 0\n' \ > /usr/sbin/raspi-config \ && chmod +x /usr/sbin/raspi-config +RUN useradd --create-home --shell /bin/bash inkypi-ci \ + && printf 'inkypi-ci ALL=(ALL) NOPASSWD:ALL\n' > /etc/sudoers.d/inkypi-ci \ + && chmod 0440 /etc/sudoers.d/inkypi-ci + # Copy the local checkout into the image so we exercise the branch under test, # not whatever is on GitHub. -COPY . /InkyPi +COPY --chown=inkypi-ci:inkypi-ci . /InkyPi -# install.sh requires root to run apt-get, write to /boot/firmware, manage -# services, and create the inkypi system user. A non-root USER would break the -# installer. We declare USER root explicitly to satisfy the DS-0002 linter -# requirement for an explicit USER directive before CMD. -USER root +USER inkypi-ci # This image is a local sim-install container, not a long-running service. # HEALTHCHECK NONE suppresses the DS-0026 "missing healthcheck" alert. HEALTHCHECK NONE -CMD ["bash", "-c", "cd /InkyPi/install && ./install.sh"] +CMD ["sudo", "bash", "-c", "cd /InkyPi/install && ./install.sh"] diff --git a/scripts/ci_install_matrix_verify.sh b/scripts/ci_install_matrix_verify.sh index b97bccd43..23ba55849 100755 --- a/scripts/ci_install_matrix_verify.sh +++ b/scripts/ci_install_matrix_verify.sh @@ -61,7 +61,7 @@ fi # wheelhouse would mask a broken requirements.txt. export INKYPI_SKIP_WHEELHOUSE=1 -if bash ./install.sh; then +if sudo bash ./install.sh; then pass "install.sh exited 0" else rc=$? diff --git a/scripts/update_cdn_sri.py b/scripts/update_cdn_sri.py index 7e723198d..c01882d33 100644 --- a/scripts/update_cdn_sri.py +++ b/scripts/update_cdn_sri.py @@ -24,14 +24,23 @@ import sys import urllib.request from pathlib import Path +from urllib.parse import urlparse REPO_ROOT = Path(__file__).resolve().parent.parent MANIFEST_PATH = REPO_ROOT / "src" / "static" / "cdn_manifest.json" +def _validate_cdn_url(url: str) -> str: + parsed = urlparse(url) + if parsed.scheme != "https" or not parsed.netloc: + raise ValueError("CDN asset URLs must be absolute HTTPS URLs") + return url + + def compute_sri_from_url(url: str) -> str: """Download *url* and return its ``sha384-`` SRI hash.""" - with urllib.request.urlopen(url, timeout=30) as resp: # noqa: S310 + safe_url = _validate_cdn_url(url) + with urllib.request.urlopen(safe_url, timeout=30) as resp: # noqa: S310 data = resp.read() digest = hashlib.sha384(data).digest() return "sha384-" + base64.b64encode(digest).decode("ascii") diff --git a/src/app_setup/security_middleware.py b/src/app_setup/security_middleware.py index 69d770bf3..12aa48346 100644 --- a/src/app_setup/security_middleware.py +++ b/src/app_setup/security_middleware.py @@ -18,7 +18,7 @@ from typing import Any, cast from urllib.parse import quote, urlencode, urlunsplit -from flask import Flask, Response, abort, g, make_response, redirect, request, session +from flask import Flask, Response, abort, g, redirect, request, session from app_setup.smoke import SMOKE_RENDER_PATH, smoke_render_enabled from config import Config @@ -100,7 +100,7 @@ def setup_secret_key(app: Flask, device_config: Config) -> None: except Exception as e: secret = generated logger.warning( - "SECRET_KEY could not persist: %s — sessions won't survive restarts", e + "Generated session signing key could not be persisted: %s", e ) app.secret_key = secret app.config["SESSION_COOKIE_HTTPONLY"] = True @@ -270,6 +270,14 @@ def _is_mutating_path(path: str) -> bool: return path in _MUTATING_RATE_PATHS or path.startswith(_MUTATING_RATE_PREFIX) +def _rate_limited_json_response(message: str, *, retry_after: str) -> Response: + body, status = json_error(message, status=429) + resp = cast(Response, body) + resp.status_code = status + resp.headers["Retry-After"] = retry_after + return resp + + def _apply_token_bucket_limits(path: str, addr: str) -> Response | None: """Check per-endpoint token-bucket limits; return a 429 response or None. @@ -280,22 +288,17 @@ def _apply_token_bucket_limits(path: str, addr: str) -> Response | None: return None if path in _AUTH_RATE_PATHS and not _auth_bucket.try_acquire(addr): - body, code = json_error("Too many login attempts — try again later", status=429) - resp = make_response(body, code) - resp.headers["Retry-After"] = "30" - return resp + return _rate_limited_json_response( + "Too many login attempts — try again later", retry_after="30" + ) if path in _REFRESH_RATE_PATHS and not _refresh_bucket.try_acquire(addr): - body, code = json_error( - "Refresh rate limit exceeded — try again later", status=429 + return _rate_limited_json_response( + "Refresh rate limit exceeded — try again later", retry_after="6" ) - resp = make_response(body, code) - resp.headers["Retry-After"] = "6" - return resp if _is_mutating_path(path) and not _mutating_bucket.try_acquire(addr): - body, code = json_error("Too many requests — try again later", status=429) - resp = make_response(body, code) - resp.headers["Retry-After"] = "6" - return resp + return _rate_limited_json_response( + "Too many requests — try again later", retry_after="6" + ) return None diff --git a/src/benchmarks/benchmark_storage.py b/src/benchmarks/benchmark_storage.py index b51ad797a..64faa5930 100644 --- a/src/benchmarks/benchmark_storage.py +++ b/src/benchmarks/benchmark_storage.py @@ -141,6 +141,48 @@ def _ensure_schema(conn: sqlite3.Connection) -> None: _ALLOWED_TABLES = frozenset({"refresh_events", "stage_events"}) _ALLOWED_COLUMN_TYPES = frozenset({"TEXT", "INTEGER", "REAL", "BLOB", "NUMERIC"}) _IDENTIFIER_RE = __import__("re").compile(r"^[A-Za-z_][A-Za-z0-9_]*$") +_TABLE_INFO_QUERIES = { + "refresh_events": "PRAGMA table_info(refresh_events)", + "stage_events": "PRAGMA table_info(stage_events)", +} +_ALTER_COLUMN_QUERIES = { + ("refresh_events", "instance", "TEXT"): ( + "ALTER TABLE refresh_events ADD COLUMN instance TEXT" + ), + ("refresh_events", "playlist", "TEXT"): ( + "ALTER TABLE refresh_events ADD COLUMN playlist TEXT" + ), + ("refresh_events", "used_cached", "INTEGER"): ( + "ALTER TABLE refresh_events ADD COLUMN used_cached INTEGER" + ), + ("refresh_events", "request_ms", "INTEGER"): ( + "ALTER TABLE refresh_events ADD COLUMN request_ms INTEGER" + ), + ("refresh_events", "generate_ms", "INTEGER"): ( + "ALTER TABLE refresh_events ADD COLUMN generate_ms INTEGER" + ), + ("refresh_events", "preprocess_ms", "INTEGER"): ( + "ALTER TABLE refresh_events ADD COLUMN preprocess_ms INTEGER" + ), + ("refresh_events", "display_ms", "INTEGER"): ( + "ALTER TABLE refresh_events ADD COLUMN display_ms INTEGER" + ), + ("refresh_events", "cpu_percent", "REAL"): ( + "ALTER TABLE refresh_events ADD COLUMN cpu_percent REAL" + ), + ("refresh_events", "memory_percent", "REAL"): ( + "ALTER TABLE refresh_events ADD COLUMN memory_percent REAL" + ), + ("refresh_events", "notes", "TEXT"): ( + "ALTER TABLE refresh_events ADD COLUMN notes TEXT" + ), + ("stage_events", "duration_ms", "INTEGER"): ( + "ALTER TABLE stage_events ADD COLUMN duration_ms INTEGER" + ), + ("stage_events", "extra_json", "TEXT"): ( + "ALTER TABLE stage_events ADD COLUMN extra_json TEXT" + ), +} def _validate_identifier(value: str, label: str) -> str: @@ -171,25 +213,23 @@ def _ensure_optional_columns( if table_name not in _ALLOWED_TABLES: raise ValueError(f"Unknown benchmark table: {table_name!r}") safe_table = _validate_identifier(table_name, "table_name") + table_info_query = _TABLE_INFO_QUERIES[safe_table] - existing = { - # Safe: safe_table is validated above against the allow-list. - row[1] - for row in conn.execute( - f"PRAGMA table_info({safe_table})" - ).fetchall() # noqa: S608 - } + existing = {row[1] for row in conn.execute(table_info_query).fetchall()} for column_name, column_type in expected_columns.items(): if column_name in existing: continue safe_col = _validate_identifier(column_name, "column_name") if column_type not in _ALLOWED_COLUMN_TYPES: raise ValueError(f"Unknown column type: {column_type!r}") - # Safe: safe_table validated against allow-list; safe_col validated via - # regex; column_type validated against allow-list of SQLite type keywords. - conn.execute( # noqa: S608 - f"ALTER TABLE {safe_table} ADD COLUMN {safe_col} {column_type}" - ) + try: + alter_query = _ALTER_COLUMN_QUERIES[(safe_table, safe_col, column_type)] + except KeyError as exc: + raise ValueError( + f"Unexpected benchmark column for {safe_table!r}: " + f"{safe_col!r} {column_type!r}" + ) from exc + conn.execute(alter_query) def save_refresh_event( diff --git a/src/blueprints/playlist.py b/src/blueprints/playlist.py index c1e1061ac..ee13ac18a 100644 --- a/src/blueprints/playlist.py +++ b/src/blueprints/playlist.py @@ -572,13 +572,13 @@ def _parse_playlist_request_data( def _parse_playlist_update_payload( data: Any, -) -> tuple[PlaylistUpdateRequest | None, Any]: +) -> tuple[PlaylistUpdateRequest | None, RequestModelError | None]: """Validate an /update_playlist request payload.""" parsed, error = parse_playlist_update_request(data) if error is not None: - return None, _request_model_error_response(error) + return None, error if parsed is None: - return None, json_error("Invalid playlist payload", status=400) + return None, RequestModelError("Invalid playlist payload") return parsed, None @@ -684,7 +684,7 @@ def update_playlist(playlist_name: str) -> Any: parsed, err = _parse_playlist_update_payload(data) if err: - return err + return _request_model_error_response(err) if parsed is None: return json_error("Invalid playlist payload", status=400) @@ -1002,7 +1002,9 @@ def playlist_eta(playlist_name: str) -> Any: except Exception: num = 0 - is_active = bool(last_dt and getattr(ri_obj, "playlist", None) == playlist_name) + is_active = ( + last_dt is not None and getattr(ri_obj, "playlist", None) == playlist_name + ) next_index = _safe_next_index(pl, num) until_next_min = _safe_until_next_min( is_active, cast(datetime | None, last_dt), cycle_min, now diff --git a/src/display/waveshare_display.py b/src/display/waveshare_display.py index 4f939f50d..3c35ffc02 100644 --- a/src/display/waveshare_display.py +++ b/src/display/waveshare_display.py @@ -1,6 +1,6 @@ -import importlib import inspect import logging +import re import sys from collections.abc import Callable from pathlib import Path @@ -11,6 +11,34 @@ from display.abstract_display import AbstractDisplay logger = logging.getLogger(__name__) +_WAVESHARE_DISPLAY_RE = re.compile(r"^epd[A-Za-z0-9_]+$", re.ASCII) +_WAVESHARE_MANIFEST = ( + Path(__file__).resolve().parents[2] / "install" / "waveshare-manifest.txt" +) + + +def _allowed_waveshare_display_types() -> set[str]: + try: + names: set[str] = set() + for line in _WAVESHARE_MANIFEST.read_text(encoding="utf-8").splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + driver_name = line.split(maxsplit=1)[0] + if driver_name.endswith(".py") and driver_name != "epdconfig.py": + names.add(driver_name[:-3]) + return names + except OSError: + return set() + + +def _validate_waveshare_display_type(display_type: str) -> str: + if not _WAVESHARE_DISPLAY_RE.fullmatch(display_type): + raise ValueError(f"Unsupported Waveshare display type: {display_type}") + allowed = _allowed_waveshare_display_types() + if allowed and display_type not in allowed: + raise ValueError(f"Unsupported Waveshare display type: {display_type}") + return display_type def split_image_for_bi_color_epd(image: Image.Image) -> tuple[Image.Image, Image.Image]: @@ -72,8 +100,8 @@ def initialize_display(self) -> None: "Waveshare driver but 'display_type' not specified in configuration." ) - # Construct module path dynamically - e.g. "display.waveshare_epd.epd7in3e" - module_name = f"display.waveshare_epd.{display_type}" + safe_display_type = _validate_waveshare_display_type(display_type) + module_name = f"display.waveshare_epd.{safe_display_type}" # Workaround for some Waveshare drivers using 'import epdconfig' causing import errors epd_dir = Path(__file__).parent / "waveshare_epd" @@ -81,8 +109,7 @@ def initialize_display(self) -> None: sys.path.insert(0, str(epd_dir)) try: - # Dynamically load module - epd_module = importlib.import_module(module_name) + epd_module = __import__(module_name, fromlist=["EPD"]) self.epd_display: Any = epd_module.EPD() # Workaround for init functions with inconsistent casing init_method = getattr(self.epd_display, "Init", None) diff --git a/src/plugins/plugin_registry.py b/src/plugins/plugin_registry.py index bb8fcb1c2..eafe70864 100644 --- a/src/plugins/plugin_registry.py +++ b/src/plugins/plugin_registry.py @@ -4,6 +4,7 @@ import json import logging import os +import re import sys import threading from pathlib import Path @@ -18,6 +19,25 @@ _registry_lock = threading.RLock() _LAST_HOT_RELOAD: dict[str, object] | None = None _hot_reload_lock = threading.Lock() +_PLUGIN_ID_RE = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$", re.ASCII) + + +def _validate_plugin_module_path( + plugin_id: str, *, allow_unregistered: bool = False +) -> str: + if not _PLUGIN_ID_RE.fullmatch(plugin_id): + raise ValueError(f"Plugin '{plugin_id}' has an invalid id.") + with _registry_lock: + is_registered = plugin_id in _PLUGIN_CONFIGS + if not is_registered: + plugin_module_path = ( + Path(cast(str, cast(Any, resolve_path)(PLUGINS_DIR))) + / plugin_id + / f"{plugin_id}.py" + ) + if not allow_unregistered or not plugin_module_path.is_file(): + raise ValueError(f"Plugin '{plugin_id}' is not registered.") + return f"plugins.{plugin_id}.{plugin_id}" def _is_dev_mode() -> bool: @@ -37,7 +57,9 @@ def _load_single_plugin_instance(plugin_config: dict[str, Any]) -> Any: plugin_class_name = plugin_config.get("class") if not isinstance(plugin_class_name, str) or not plugin_class_name: raise ValueError(f"Plugin '{plugin_id}' is missing a valid class.") - module_name = f"plugins.{plugin_id}.{plugin_id}" + module_name = _validate_plugin_module_path( + plugin_id, allow_unregistered=_is_dev_mode() + ) try: reloaded = False no_hot_reload = os.getenv("INKYPI_NO_HOT_RELOAD", "").strip().lower() in ( @@ -50,7 +72,7 @@ def _load_single_plugin_instance(plugin_config: dict[str, Any]) -> Any: module = importlib.reload(sys.modules[module_name]) reloaded = True else: - module = importlib.import_module(module_name) + module = __import__(module_name, fromlist=[plugin_class_name]) plugin_cls = getattr(module, plugin_class_name, None) if not plugin_cls: raise ImportError( @@ -135,6 +157,11 @@ def load_plugins(plugins_config: list[dict[str, Any]]) -> None: if not isinstance(plugin_id, str) or not plugin_id: logger.error("Plugin config is missing a valid id, skipping.") continue + if not _PLUGIN_ID_RE.fullmatch(plugin_id): + logger.error( + "Plugin config id '%s' is not a safe module name, skipping.", plugin_id + ) + continue if plugin.get("disabled", False): logger.info(f"Plugin {plugin_id} is disabled, skipping.") continue diff --git a/src/static/scripts/playlist/progress.js b/src/static/scripts/playlist/progress.js index 913871051..681a72d2d 100644 --- a/src/static/scripts/playlist/progress.js +++ b/src/static/scripts/playlist/progress.js @@ -6,7 +6,7 @@ try { return JSON.parse(rawValue); } catch (error) { - console.debug(`Failed parsing ${label}:`, error); + console.debug("Failed parsing playlist progress payload:", label, error); return null; } } @@ -15,7 +15,7 @@ try { return await response.json(); } catch (error) { - console.debug(`Failed parsing ${label}:`, error); + console.debug("Failed parsing playlist progress response:", label, error); return null; } } diff --git a/src/static/scripts/settings/actions.js b/src/static/scripts/settings/actions.js index 756a7935c..b379c2952 100644 --- a/src/static/scripts/settings/actions.js +++ b/src/static/scripts/settings/actions.js @@ -336,12 +336,12 @@ // Silent refresh — the user was just watching the update log // stream, they don't need another toast announcing the state. checkForUpdates({ silent: true }); - } - } catch (e) { - console.warn(`${logLabel} status poll failed:`, e); - clearInterval(state.updateTimer); - state.updateTimer = null; } + } catch (e) { + console.warn("Settings update status poll failed:", logLabel, e); + clearInterval(state.updateTimer); + state.updateTimer = null; + } }, 2000); } @@ -366,7 +366,7 @@ showResponseModal("success", data.message || startingLabel); pollUpdateStatusUntilDone(kind); } catch (e) { - console.warn(`Failed to start ${kind.toLowerCase()}:`, e); + console.warn("Failed to start settings action:", kind, e); showResponseModal("failure", failureMessage); } finally { setHeaderButtonsDisabled(false); diff --git a/src/static/scripts/settings/diagnostics.js b/src/static/scripts/settings/diagnostics.js index 9deb9883a..ac092dea9 100644 --- a/src/static/scripts/settings/diagnostics.js +++ b/src/static/scripts/settings/diagnostics.js @@ -537,7 +537,7 @@ await refreshIsolation(); await refreshHealth(); } catch (e) { - console.warn(`Failed to ${verb} plugin:`, e); + console.warn("Failed to toggle plugin isolation:", verb, e); showResponseModal( "failure", `Failed to ${verb} plugin. Check your connection and try again.` diff --git a/src/static/scripts/store.js b/src/static/scripts/store.js index 8dbc8259d..0bf308143 100644 --- a/src/static/scripts/store.js +++ b/src/static/scripts/store.js @@ -72,7 +72,7 @@ try { fn(next[k], prev[k]); } catch (e) { - console.warn('InkyPiStore subscriber error for key "' + k + '":', e); + console.warn("InkyPiStore subscriber error for key:", k, e); } }); }); diff --git a/src/utils/icon_utils.py b/src/utils/icon_utils.py index 825920791..d6b4f8f7f 100644 --- a/src/utils/icon_utils.py +++ b/src/utils/icon_utils.py @@ -3,6 +3,20 @@ from markupsafe import Markup, escape +_ICON_NAME_RE = re.compile(r"^[A-Za-z0-9_-]+$", re.ASCII) +_CLASS_TOKEN_RE = re.compile(r"^[A-Za-z0-9_-]+$", re.ASCII) + + +def _safe_icon_name(name: str) -> str: + if _ICON_NAME_RE.fullmatch(name): + return name + return "question" + + +def _safe_class_name(class_name: str) -> str: + tokens = [token for token in class_name.split() if _CLASS_TOKEN_RE.fullmatch(token)] + return " ".join(tokens) or "icon-image" + def render_icon( name: str, class_name: str = "icon-image", title: str | None = None @@ -13,15 +27,17 @@ def render_icon( class and optional title attributes. Otherwise, returns an tag that expects Phosphor CSS classes to render. """ + safe_name = _safe_icon_name(name) + safe_class_name = _safe_class_name(class_name) try: base_dir = os.path.dirname(os.path.dirname(__file__)) # src/ - svg_path = os.path.join(base_dir, "static", "icons", "ph", f"{name}.svg") + svg_path = os.path.join(base_dir, "static", "icons", "ph", f"{safe_name}.svg") if os.path.isfile(svg_path): with open(svg_path, encoding="utf-8") as f: svg = f.read() # inject class and title if missing # naive injection: add class attribute to first - cls_attr = f'class="{class_name}"' + cls_attr = f'class="{escape(safe_class_name)}"' if ( "", 1)[0] @@ -40,5 +56,5 @@ class and optional title attributes. Otherwise, returns an tag that expects # Fallback: Phosphor class (requires stylesheet) title_attr = f' title="{escape(title)}"' if title else "" return Markup( - f'' + f'' ) diff --git a/tests/test_sri.py b/tests/test_sri.py index d8aac3b34..669449e1b 100644 --- a/tests/test_sri.py +++ b/tests/test_sri.py @@ -391,6 +391,15 @@ def test_dry_run_does_not_write(self, tmp_path: Path) -> None: still = json.loads(manifest_path.read_text(encoding="utf-8")) assert still["test-lib"]["integrity"] == original_integrity + def test_rejects_non_https_cdn_urls(self) -> None: + import update_cdn_sri as ucs + + with patch("urllib.request.urlopen") as urlopen: + with pytest.raises(ValueError, match="absolute HTTPS"): + ucs.compute_sri_from_url("file:///etc/passwd") + + urlopen.assert_not_called() + def test_missing_manifest_returns_error(self, tmp_path: Path) -> None: import update_cdn_sri as ucs diff --git a/tests/unit/test_benchmark_storage.py b/tests/unit/test_benchmark_storage.py index 3ed8a3b2d..f4a1df7e4 100644 --- a/tests/unit/test_benchmark_storage.py +++ b/tests/unit/test_benchmark_storage.py @@ -498,13 +498,13 @@ def test_ensure_optional_columns_adds_missing_columns(tmp_path): ) conn.commit() - _ensure_optional_columns(conn, "refresh_events", {"new_col": "TEXT"}) + _ensure_optional_columns(conn, "refresh_events", {"instance": "TEXT"}) cursor = conn.execute("PRAGMA table_info(refresh_events)") columns = {row[1] for row in cursor.fetchall()} conn.close() - assert "new_col" in columns + assert "instance" in columns def test_ensure_optional_columns_skips_existing_columns(tmp_path): @@ -520,3 +520,18 @@ def test_ensure_optional_columns_skips_existing_columns(tmp_path): # Should not raise even though existing_col already exists _ensure_optional_columns(conn, "refresh_events", {"existing_col": "TEXT"}) conn.close() + + +def test_ensure_optional_columns_rejects_unexpected_column(tmp_path): + """Only known benchmark migration columns can be added.""" + import pytest + + from benchmarks.benchmark_storage import _ensure_optional_columns + + conn = sqlite3.connect(str(tmp_path / "test.db")) + conn.execute("CREATE TABLE refresh_events (id INTEGER PRIMARY KEY)") + conn.commit() + + with pytest.raises(ValueError, match="Unexpected benchmark column"): + _ensure_optional_columns(conn, "refresh_events", {"surprise": "TEXT"}) + conn.close() diff --git a/tests/unit/test_icon_utils.py b/tests/unit/test_icon_utils.py index 5e7ad9e30..f7044d71c 100644 --- a/tests/unit/test_icon_utils.py +++ b/tests/unit/test_icon_utils.py @@ -93,6 +93,21 @@ def test_fallback_without_title(): assert "title=" not in result +def test_invalid_icon_name_and_class_are_sanitized(): + result = str( + icon_utils.render_icon( + '../bad" onmouseover="alert(1)', + class_name='ok bad" onclick="alert(1)', + ) + ) + + assert "../bad" not in result + assert "onmouseover" not in result + assert "onclick" not in result + assert "ph ph-question" in result + assert " ok" in result + + def test_title_injected_with_doctype(monkeypatch): svg_content = ( "