Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions acestep/training/path_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

Provides a single ``safe_path`` function that validates user-provided
filesystem paths against a known safe root directory. The validation
uses ``os.path.normpath`` followed by a ``.startswith`` check — the
uses ``os.path.realpath`` followed by a ``.startswith`` check — the
exact pattern that CodeQL recognises as a sanitiser for the
``py/path-injection`` query.

Symlinks are resolved on both the root and user paths so that paths
through symlinks (e.g. ``/root/data`` → ``/vepfs/.../data``) are
compared consistently.

All training modules that accept user-supplied paths should call
``safe_path`` (or ``safe_open``) before performing any filesystem I/O.
"""
Expand All @@ -15,20 +19,30 @@

from loguru import logger


def _resolve(path: str) -> str:
"""Normalise and resolve symlinks in *path*.

Uses ``os.path.realpath`` so that symlinked prefixes are resolved
to their canonical form before comparison.
"""
return os.path.normpath(os.path.realpath(path))


# Root directory that all user-provided paths must resolve under.
# Defaults to the working directory at import time. Override via
# ``set_safe_root`` if needed (e.g. in tests).
_SAFE_ROOT: str = os.path.normpath(os.path.abspath(os.getcwd()))
_SAFE_ROOT: str = _resolve(os.getcwd())


def set_safe_root(root: str) -> None:
"""Override the safe root directory.

Args:
root: New safe root (will be normalised).
root: New safe root (will be normalised and symlink-resolved).
"""
global _SAFE_ROOT # noqa: PLW0603
_SAFE_ROOT = os.path.normpath(os.path.abspath(root))
_SAFE_ROOT = _resolve(root)


def get_safe_root() -> str:
Expand All @@ -40,30 +54,32 @@ def safe_path(user_path: str, *, base: Optional[str] = None) -> str:
"""Validate and normalise a user-provided path.

The returned path is guaranteed to live under *base* (or the
global ``_SAFE_ROOT`` when *base* is ``None``).
global ``_SAFE_ROOT`` when *base* is ``None``). Symlinks in both
the root and user path are resolved so that paths through symlinks
compare correctly.

Args:
user_path: Untrusted path string from user input.
base: Optional explicit base directory. When provided it is
normalised and used instead of ``_SAFE_ROOT``.
resolved (symlinks included) and used instead of
``_SAFE_ROOT``.

Returns:
Normalised absolute path that is within the safe root.
Normalised, symlink-resolved absolute path within the safe root.

Raises:
ValueError: If the normalised path escapes the safe root.
ValueError: If the resolved path escapes the safe root.
"""
if base is not None:
root = os.path.normpath(os.path.abspath(base))
root = _resolve(base)
else:
root = _SAFE_ROOT

# Normalise the user path. If it is relative, resolve it against
# *root*; if absolute, normalise it directly.
# Resolve the user path. If relative, join against *root* first.
if os.path.isabs(user_path):
normalised = os.path.normpath(user_path)
normalised = _resolve(user_path)
else:
normalised = os.path.normpath(os.path.join(root, user_path))
normalised = _resolve(os.path.join(root, user_path))

# ── CodeQL-recognised sanitiser barrier ──
# ``normpath(…).startswith(safe_prefix)`` is the pattern that
Expand Down