Skip to content

Commit

Permalink
Implement init_project to protect a new project from scratch
Browse files Browse the repository at this point in the history
  • Loading branch information
emdoyle committed Feb 8, 2024
1 parent 7373028 commit d46722d
Show file tree
Hide file tree
Showing 9 changed files with 353 additions and 83 deletions.
120 changes: 69 additions & 51 deletions modguard/check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from dataclasses import dataclass
from typing import Optional

from .core.boundary import BoundaryTrie, BoundaryNode
from .parsing.boundary import build_boundary_trie
from .parsing.imports import get_imports
from .parsing import utils
Expand All @@ -22,6 +24,60 @@ def message(self) -> str:
return f"Import '{self.import_mod_path}' in {self.location} is blocked by boundary '{self.boundary_path}'"


def check_import(
boundary_trie: BoundaryTrie,
import_mod_path: str,
file_nearest_boundary: BoundaryNode,
file_mod_path: str,
) -> Optional[BoundaryNode]:
nearest_boundary = boundary_trie.find_nearest(import_mod_path)
# An imported module is allowed only in the following cases:
# * The module is not contained by a boundary [generally 3rd party]
import_mod_has_boundary = nearest_boundary is not None

# * The imported module's boundary is a child of the file's boundary
import_mod_is_child_of_current = (
import_mod_has_boundary
and file_nearest_boundary.full_path.startswith(nearest_boundary.full_path)
)

# * The module is exported as public by its boundary and is allowed in the current path
import_mod_public_member_definition = (
next(
(
public_member
for public_member_name, public_member in nearest_boundary.public_members.items()
if import_mod_path.startswith(public_member_name)
),
None,
)
if import_mod_has_boundary
else None
)
import_mod_is_public_and_allowed = (
import_mod_public_member_definition is not None
and (
import_mod_public_member_definition.allowlist is None
or any(
(
file_mod_path.startswith(allowed_path)
for allowed_path in import_mod_public_member_definition.allowlist
)
)
)
)

if (
not import_mod_has_boundary
or import_mod_is_child_of_current
or import_mod_is_public_and_allowed
):
return None

# In error case, return path of the violated boundary
return nearest_boundary


def check(root: str, exclude_paths: list[str] = None) -> list[ErrorInfo]:
if not os.path.isdir(root):
return [ErrorInfo(exception_message=f"The path {root} is not a directory.")]
Expand All @@ -33,66 +89,28 @@ def check(root: str, exclude_paths: list[str] = None) -> list[ErrorInfo]:
boundary_trie = build_boundary_trie(root, exclude_paths=exclude_paths)

errors = []
for dirpath, filename in utils.walk_pyfiles(root, exclude_paths=exclude_paths):
file_path = os.path.join(dirpath, filename)
current_mod_path = utils.file_to_module_path(file_path)
current_nearest_boundary = boundary_trie.find_nearest(current_mod_path)
for file_path in utils.walk_pyfiles(root, exclude_paths=exclude_paths):
mod_path = utils.file_to_module_path(file_path)
nearest_boundary = boundary_trie.find_nearest(mod_path)
assert (
current_nearest_boundary is not None
nearest_boundary is not None
), f"Checking file ({file_path}) outside of boundaries!"
import_mod_paths = get_imports(file_path)
for mod_path in import_mod_paths:
nearest_boundary = boundary_trie.find_nearest(mod_path)
# An imported module is allowed only in the following cases:
# * The module is not contained by a boundary [generally 3rd party]
import_mod_has_boundary = nearest_boundary is not None

# * The module's boundary is a child of the current boundary
import_mod_is_child_of_current = (
import_mod_has_boundary
and current_nearest_boundary.full_path.startswith(
nearest_boundary.full_path
)
)

# * The module is exported as public by its boundary and is allowed in the current path
import_mod_public_member_definition = (
next(
(
public_member
for public_member_name, public_member in nearest_boundary.public_members.items()
if mod_path.startswith(public_member_name)
),
None,
)
if import_mod_has_boundary
else None
for import_mod_path in import_mod_paths:
violated_boundary = check_import(
boundary_trie=boundary_trie,
import_mod_path=import_mod_path,
file_nearest_boundary=nearest_boundary,
file_mod_path=mod_path,
)
import_mod_is_public_and_allowed = (
import_mod_public_member_definition is not None
and (
import_mod_public_member_definition.allowlist is None
or any(
(
current_mod_path.startswith(allowed_path)
for allowed_path in import_mod_public_member_definition.allowlist
)
)
)
)

if (
not import_mod_has_boundary
or import_mod_is_child_of_current
or import_mod_is_public_and_allowed
):
if violated_boundary is None:
# This import is OK
continue

errors.append(
ErrorInfo(
import_mod_path=mod_path,
boundary_path=nearest_boundary.full_path,
import_mod_path=import_mod_path,
boundary_path=violated_boundary.full_path,
location=file_path,
)
)
Expand Down
2 changes: 2 additions & 0 deletions modguard/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .boundary import BoundaryNode, BoundaryTrie
from .public import PublicMember
22 changes: 20 additions & 2 deletions modguard/core/boundary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import deque
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Generator

from .public import PublicMember
from modguard.errors import ModguardSetupError
Expand All @@ -13,15 +14,21 @@ class Boundary:
@dataclass
class BoundaryNode:
public_members: dict[str, PublicMember] = field(default_factory=dict)
children: dict = field(default_factory=dict)
children: dict[str, "BoundaryNode"] = field(default_factory=dict)
is_end_of_path: bool = False
full_path: str = None

def add_public_member(self, member: PublicMember):
self.public_members[member.name] = member


@dataclass
class BoundaryTrie:
root: BoundaryNode = field(default_factory=BoundaryNode)

def __iter__(self):
return boundary_trie_iterator(self)

def get(self, path: str) -> Optional[BoundaryNode]:
node = self.root
parts = path.split(".")
Expand Down Expand Up @@ -72,3 +79,14 @@ def find_nearest(self, path: str) -> BoundaryNode:
break

return nearest_parent


def boundary_trie_iterator(trie: BoundaryTrie) -> Generator[BoundaryNode, None, None]:
stack = deque([trie.root])

while stack:
node = stack.popleft()
if node.is_end_of_path:
yield node

stack.extend(node.children.values())
66 changes: 66 additions & 0 deletions modguard/init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import errors
from .check import check_import
from .core import PublicMember
from .parsing import utils
from .parsing.boundary import ensure_boundary, build_boundary_trie
from .parsing.imports import get_imports
from .parsing.public import mark_as_public


def init_project(root: str, exclude_paths: list[str] = None):
# Core functionality:
# * do nothing in any package already having a Boundary
# * import and call Boundary in __init__.py for all other packages
# * import and decorate public on all externally imported functions and classes
if not os.path.isdir(root):
return errors.ModguardSetupError(f"The path {root} is not a directory.")

# This 'canonicalizes' the path arguments, resolving directory traversal
root = utils.canonical(root)
exclude_paths = list(map(utils.canonical, exclude_paths)) if exclude_paths else None

boundary_trie = build_boundary_trie(root, exclude_paths=exclude_paths)
initial_boundary_paths = [boundary.full_path for boundary in boundary_trie]

for dirpath in utils.walk_pypackages(root, exclude_paths=exclude_paths):
added_boundary = ensure_boundary(dirpath + "/__init__.py")
if added_boundary:
dir_mod_path = utils.file_to_module_path(dirpath)
boundary_trie.insert(dir_mod_path)

for file_path in utils.walk_pyfiles(root, exclude_paths=exclude_paths):
mod_path = utils.file_to_module_path(file_path)
# If this file belongs to a Boundary which existed
# before calling init_project, ignore the file and move on
if any(
(
mod_path.startswith(initial_boundary_path)
for initial_boundary_path in initial_boundary_paths
)
):
continue

nearest_boundary = boundary_trie.find_nearest(mod_path)
assert (
nearest_boundary is not None
), f"Checking file ({file_path}) outside of boundaries!"
import_mod_paths = get_imports(file_path)
for import_mod_path in import_mod_paths:
violated_boundary = check_import(
boundary_trie=boundary_trie,
import_mod_path=import_mod_path,
file_nearest_boundary=nearest_boundary,
file_mod_path=mod_path,
)
if violated_boundary is None:
# This import is fine, no need to mark anything as public
continue

file_path, member_name = utils.module_to_file_path(import_mod_path)
mark_as_public(file_path, member_name)
violated_boundary.add_public_member(PublicMember(name=import_mod_path))


if __name__ == "__main__":
init_project(".")
47 changes: 36 additions & 11 deletions modguard/parsing/boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,45 @@ def visit_Call(self, node):
self.generic_visit(node)


def has_boundary(file_path: str) -> bool:
with open(file_path, "r") as file:
file_content = file.read()

def _has_boundary(file_path: str, file_content: str) -> bool:
try:
parsed_ast = ast.parse(file_content)
boundary_finder = BoundaryFinder()
boundary_finder.visit(parsed_ast)
return boundary_finder.found_boundary
except SyntaxError as e:
raise ModguardParseError(f"Syntax error in {file_path}: {e}")

boundary_finder = BoundaryFinder()
boundary_finder.visit(parsed_ast)
return boundary_finder.found_boundary


def has_boundary(file_path: str) -> bool:
with open(file_path, "r") as file:
file_content = file.read()

return _has_boundary(file_path, file_content)


BOUNDARY_PRELUDE = "import modguard\nmodguard.Boundary()\n"


def _add_boundary(file_path: str, file_content: str):
with open(file_path, "w") as file:
file.write(BOUNDARY_PRELUDE + file_content)


@public
def ensure_boundary(file_path: str) -> bool:
with open(file_path, "r") as file:
file_content = file.read()

if _has_boundary(file_path, file_content):
# Boundary already exists, don't need to create one
return False

# Boundary doesn't exist, create one
_add_boundary(file_path, file_content)
return True


@public
def build_boundary_trie(root: str, exclude_paths: list[str] = None) -> BoundaryTrie:
Expand All @@ -63,14 +90,12 @@ def build_boundary_trie(root: str, exclude_paths: list[str] = None) -> BoundaryT
# This means a project will pass 'check' by default
boundary_trie.insert(file_to_module_path(root))

for dirpath, filename in walk_pyfiles(root, exclude_paths=exclude_paths):
file_path = os.path.join(dirpath, filename)
for file_path in walk_pyfiles(root, exclude_paths=exclude_paths):
if has_boundary(file_path):
mod_path = file_to_module_path(file_path)
boundary_trie.insert(mod_path)

for dirpath, filename in walk_pyfiles(root, exclude_paths=exclude_paths):
file_path = os.path.join(dirpath, filename)
for file_path in walk_pyfiles(root, exclude_paths=exclude_paths):
mod_path = file_to_module_path(file_path)
public_members = get_public_members(file_path)
for public_member in public_members:
Expand Down
21 changes: 11 additions & 10 deletions modguard/parsing/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(
self.current_mod_path = current_mod_path
self.is_package = is_package
self.ignored_imports = ignore_directives or {}
self.imports = []
self.imports: list[str] = []

def _get_ignored_modules(self, lineno: int) -> Optional[list[str]]:
# Check for ignore directive at the previous line or on the current line
Expand Down Expand Up @@ -102,14 +102,15 @@ def get_imports(file_path: str) -> list[str]:

try:
parsed_ast = ast.parse(file_content)
ignore_directives = get_ignore_directives(file_content)
mod_path = file_to_module_path(file_path)
import_visitor = ImportVisitor(
is_package=file_path.endswith("__init__.py"),
current_mod_path=mod_path,
ignore_directives=ignore_directives,
)
import_visitor.visit(parsed_ast)
return import_visitor.imports
except SyntaxError as e:
raise ModguardParseError(f"Syntax error in {file_path}: {e}")

ignore_directives = get_ignore_directives(file_content)
mod_path = file_to_module_path(file_path)
import_visitor = ImportVisitor(
is_package=file_path.endswith("__init__.py"),
current_mod_path=mod_path,
ignore_directives=ignore_directives,
)
import_visitor.visit(parsed_ast)
return import_visitor.imports
Loading

0 comments on commit d46722d

Please sign in to comment.