Skip to content

Commit

Permalink
Fix checking relative imports from local package
Browse files Browse the repository at this point in the history
  • Loading branch information
emdoyle committed Feb 8, 2024
1 parent fc1bd1b commit ec9e21f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
5 changes: 4 additions & 1 deletion modguard/parsing/boundary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@ def __init__(self):

def visit_ImportFrom(self, node):
# Check if 'Boundary' is imported specifically from a 'modguard'-rooted module
if (node.module == "modguard" or node.module and node.module.startswith("modguard.")) and any(
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
if is_modguard_module_import and any(
alias.name == "Boundary" for alias in node.names
):
self.is_modguard_boundary_imported = True
Expand Down
19 changes: 13 additions & 6 deletions modguard/parsing/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@ def _get_ignored_modules(self, lineno: int) -> Optional[list[str]]:

def visit_ImportFrom(self, node):
# For relative imports (level > 0), adjust the base module path
if node.level > 0:
if node.module is not None and node.level > 0:
num_paths_to_strip = node.level - 1 if self.is_package else node.level
base_path_parts = self.current_mod_path.split(".")
if num_paths_to_strip:
base_path_parts = base_path_parts[:-num_paths_to_strip]
base_mod_path = ".".join([*base_path_parts, node.module if node.module else ''])
else:
base_mod_path = node.module
base_mod_path = node.module or ""

ignored_modules = self._get_ignored_modules(node.lineno)

Expand All @@ -71,14 +71,21 @@ def visit_ImportFrom(self, node):
return self.generic_visit(node)

for name_node in node.names:
if ignored_modules is not None and (
local_mod_path = (
f"{'.' * node.level}{node.module}.{name_node.asname or name_node.name}"
in ignored_modules
):
if node.module
else f"{'.' * node.level}{name_node.asname or name_node.name}"
)
if ignored_modules is not None and (local_mod_path in ignored_modules):
# This import is ignored by a modguard-ignore directive
continue

self.imports.append(f"{base_mod_path}.{name_node.asname or name_node.name}")
global_mod_path = (
f"{base_mod_path}.{name_node.asname or name_node.name}"
if node.module
else (name_node.asname or name_node.name)
)
self.imports.append(global_mod_path)

self.generic_visit(node)

Expand Down
10 changes: 8 additions & 2 deletions modguard/parsing/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@ def __init__(self, module_name: str):
self.import_found = False

def visit_ImportFrom(self, node):
if (node.module == "modguard" or node.module and node.module.startswith("modguard.")) and any(
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
if is_modguard_module_import and any(
alias.name == self.module_name for alias in node.names
):
self.import_found = True
Expand All @@ -37,7 +40,10 @@ def __init__(self, current_mod_path: str, is_package: bool = False):
self.public_members: list[PublicMember] = []

def visit_ImportFrom(self, node):
if (node.module == "modguard" or node.module and node.module.startswith("modguard.")) and any(
is_modguard_module_import = node.module is not None and (
node.module == "modguard" or node.module.startswith("modguard.")
)
if is_modguard_module_import and any(
alias.name == "public" for alias in node.names
):
self.is_modguard_public_imported = True
Expand Down

0 comments on commit ec9e21f

Please sign in to comment.