diff --git a/playpen/repo_level_awareness/api.py b/playpen/repo_level_awareness/api.py
new file mode 100644
index 00000000..167df416
--- /dev/null
+++ b/playpen/repo_level_awareness/api.py
@@ -0,0 +1,57 @@
+from abc import ABC
+from dataclasses import dataclass
+from pathlib import Path
+
+
+@dataclass
+class RpcClientConfig:
+    repo_directory: Path
+
+
+# FIXME: Oh god oh no oh jeez oh man
+class Task:
+    pass
+
+
+# FIXME: Might not need
+@dataclass
+class TaskResult:
+    encountered_errors: list[str]
+    modified_files: list[Path]
+
+
+@dataclass
+class ValidationError(Task):
+    pass
+
+
+@dataclass
+class ValidationResult:
+    passed: bool
+    errors: list[ValidationError]
+
+
+class ValidationStep(ABC):
+    def __init__(self, config: RpcClientConfig) -> None:
+        self.config = config
+
+    def run(self) -> ValidationResult:
+        raise NotImplementedError
+
+
+class Agent(ABC):
+    def can_handle_task(self, task: Task) -> bool:
+        raise NotImplementedError
+
+    def execute_task(self, task: Task) -> TaskResult:
+        raise NotImplementedError
+
+    def refine_task(self, errors: list[str]) -> None:
+        # Knows that it's the refine step so that it might not spawn as much
+        # stuff.
+        raise NotImplementedError
+
+    def can_handle_error(self, errors: list[str]) -> bool:
+        raise NotImplementedError
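For orientation, here is a minimal sketch of how these base classes are meant to be filled in. The `NoOpAgent` name and its trivial behaviour are hypothetical and not part of this change; a real agent would inspect the task, edit the repository, and report what it touched.

```python
# Hypothetical example only: the smallest possible Agent implementation.
from api import Agent, Task, TaskResult


class NoOpAgent(Agent):
    """Claims every task and reports that it changed nothing."""

    def can_handle_task(self, task: Task) -> bool:
        return True

    def execute_task(self, task: Task) -> TaskResult:
        # A real agent would modify files and list them in modified_files.
        return TaskResult(encountered_errors=[], modified_files=[])

    def refine_task(self, errors: list[str]) -> None:
        pass

    def can_handle_error(self, errors: list[str]) -> bool:
        return False
```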
diff --git a/playpen/repo_level_awareness/codeplan.py b/playpen/repo_level_awareness/codeplan.py
new file mode 100755
index 00000000..dc030541
--- /dev/null
+++ b/playpen/repo_level_awareness/codeplan.py
@@ -0,0 +1,157 @@
+#!/usr/bin/env python
+
+from pathlib import Path
+from typing import Any, Generator, Optional
+
+from kai.service.kai_application.kai_application import UpdatedFileContent
+
+from api import (
+    Agent,
+    RpcClientConfig,
+    Task,
+    TaskResult,
+    ValidationError,
+    ValidationStep,
+)
+from maven_validator import MavenCompileStep
+
+
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Run the CodePlan loop against a project"
+    )
+    parser.add_argument(
+        "source_directory", help="The root directory of the project to be fixed"
+    )
+
+    args = parser.parse_args()
+
+    config = RpcClientConfig(Path(args.source_directory))
+    codeplan(config, None)
+
+
+def codeplan(
+    config: RpcClientConfig,
+    updated_file_content: Optional[UpdatedFileContent],
+):
+    whatever_agent = Agent()
+
+    task_manager = TaskManager(
+        config,
+        updated_file_content,
+        validators=[MavenCompileStep(config)],
+        agents=[whatever_agent],
+    )
+    # The task manager:
+    # - has a list of files affected and unprocessed
+    # - has a list of registered validators
+    # - has a list of current validation errors
+
+    for task in task_manager.get_next_task():
+        task_manager.supply_result(task_manager.execute_task(task))
+        # all current failing validations and all currently affected AND
+        # UNDEALT WITH files
+
+        # Can do revalidation, or use cached results or whatever
+
+
+class TaskManager:
+    def __init__(
+        self,
+        config: RpcClientConfig,
+        updated_file_content: UpdatedFileContent,
+        validators: Optional[list[ValidationStep]] = None,
+        agents: Optional[list[Agent]] = None,
+    ) -> None:
+        # TODO: Files could conceivably become unprocessed again, but that
+        # could easily lead to infinite loops, so we're avoiding it for now.
+        # We could even add a MAX_DEPTH or similar cap.
+        self.processed_files: list[Path] = []
+        self.unprocessed_files: list[Path] = []
+
+        self.validators: list[ValidationStep] = []
+        if validators is not None:
+            self.validators.extend(validators)
+
+        self.agents: list[Agent] = []
+        if agents is not None:
+            self.agents.extend(agents)
+
+        self.config = config
+
+        self._validators_are_stale = True
+
+        # TODO: Modify the inputs to this class accordingly. We want all the
+        # context that went into updated_file_content and the result that
+        # came out of it too.
+
+        # TODO: Actually add the paths to processed and unprocessed files.
+
+    def execute_task(self, task: Task) -> TaskResult:
+        return self.get_agent_for_task(task).execute_task(task)
+
+    def get_agent_for_task(self, task: Task) -> Agent:
+        for agent in self.agents:
+            if agent.can_handle_task(task):
+                return agent
+
+        raise Exception("No agent available for this task")
+
+    def supply_result(self, result: TaskResult) -> None:
+        # One kind of result is a set of filesystem changes:
+        # SUCCESS
+        # - Did something, modified file system -> Recompute
+        # - Did nothing -> Go to next task
+
+        # Another is that the agent failed:
+        # FAILURE
+        # - Did it give us more info to feed back to the repo context?
+        # - It failed and gave us nothing -> >:(
+
+        for file_path in result.modified_files:
+            if file_path not in self.unprocessed_files:
+                self.unprocessed_files.append(file_path)
+                self._validators_are_stale = True
+
+        if len(result.encountered_errors) > 0:
+            raise NotImplementedError("What should we do with errors?")
+
+    def run_validators(self) -> list[tuple[type, ValidationError]]:
+        # NOTE: Do it this way so that in the future we could do something
+        # like get all the errors in an affected file and then send THAT as
+        # a task, versus locking ourselves into one validation error per
+        # task at a time. i.e. Make it so we can combine validation errors
+        # into single tasks, or grab all errors of a type, or all errors
+        # referencing a specific type.
+        #
+        # Basically, we're surfacing this functionality because this whole
+        # process is going to get more complicated in the future.
+        validation_errors: list[tuple[type, ValidationError]] = []
+
+        for validator in self.validators:
+            result = validator.run()
+            if not result.passed:
+                validation_errors.extend((type(validator), e) for e in result.errors)
+
+        self._validators_are_stale = False
+
+        return validation_errors
+
+    def get_next_task(self) -> Generator[Task, Any, None]:
+        validation_errors: list[tuple[type, ValidationError]] = []
+
+        # Check whether the validators are stale. If so, run them.
+        while True:
+            if self._validators_are_stale:
+                validation_errors = self.run_validators()
+
+            # Pop an error off the stack of errors
+            if len(validation_errors) > 0:
+                err = validation_errors.pop(0)
+                yield err  # TODO: This is a placeholder; wrap it in a real Task
+                continue
+
+            if len(self.unprocessed_files) > 0:
+                # FIXME: Task does not accept a path yet (see api.py)
+                yield Task(self.unprocessed_files.pop(0))
+                continue
+
+            break
+
+
+if __name__ == "__main__":
+    with __import__("ipdb").launch_ipdb_on_exception():
+        main()
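The NOTE in `run_validators` gestures at bundling related errors into a single task. As a rough sketch of that idea (the helper below is hypothetical, not part of this change, and it assumes the Maven validator's errors, which carry a `file` attribute):

```python
# Hypothetical sketch of the "combine validation errors" idea from
# run_validators: bundle every error touching the same file into one unit of
# work instead of yielding one task per error.
from collections import defaultdict

from maven_validator import MavenCompilerError


def group_errors_by_file(
    validation_errors: list[tuple[type, MavenCompilerError]],
) -> dict[str, list[MavenCompilerError]]:
    grouped: dict[str, list[MavenCompilerError]] = defaultdict(list)
    for _validator_type, error in validation_errors:
        grouped[error.file].append(error)
    return dict(grouped)
```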
diff --git a/playpen/repo_level_awareness/maven_validator.py b/playpen/repo_level_awareness/maven_validator.py
new file mode 100755
index 00000000..023f0d22
--- /dev/null
+++ b/playpen/repo_level_awareness/maven_validator.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python
+
+import re
+import subprocess
+from dataclasses import dataclass, field
+from typing import List, Optional, Type
+
+from api import ValidationError, ValidationResult, ValidationStep
+
+
+class MavenCompileStep(ValidationStep):
+
+    def run(self) -> ValidationResult:
+        maven_output = run_maven(self.config.repo_directory)
+        errors = parse_maven_output(maven_output)
+        return ValidationResult(passed=not errors, errors=errors)
+
+
+@dataclass
+class MavenCompilerError(ValidationError):
+    file: str
+    line: int
+    column: int
+    message: str
+    details: List[str] = field(default_factory=list)
+    parse_lines: Optional[str] = None
+
+    @classmethod
+    def from_match(cls, match, details):
+        """
+        Factory method to create an instance from a regex match.
+        """
+        file_path = match.group(1).strip()
+        line_number = int(match.group(2))
+        column_number = int(match.group(3))
+        message = match.group(4).strip()
+        return cls(
+            file=file_path,
+            line=line_number,
+            column=column_number,
+            message=message,
+            details=details.copy(),
+        )
+
+
+# Subclasses for specific error categories
+@dataclass
+class SymbolNotFoundError(MavenCompilerError):
+    missing_symbol: Optional[str] = None
+    symbol_location: Optional[str] = None
+
+
+@dataclass
+class PackageDoesNotExistError(MavenCompilerError):
+    missing_package: Optional[str] = None
+
+
+# NOTE: Shadows the builtin SyntaxError within this module; it is only used to
+# classify Java compiler output.
+@dataclass
+class SyntaxError(MavenCompilerError):
+    pass
+
+
+@dataclass
+class TypeMismatchError(MavenCompilerError):
+    expected_type: Optional[str] = None
+    found_type: Optional[str] = None
+
+
+@dataclass
+class AnnotationError(MavenCompilerError):
+    pass
+
+
+@dataclass
+class AccessControlError(MavenCompilerError):
+    inaccessible_class: Optional[str] = None
+
+
+@dataclass
+class OtherError(MavenCompilerError):
+    pass
+
+
+def run_maven(source_directory=".") -> str:
+    """
+    Runs 'mvn compile' and returns the combined stdout and stderr output.
+    """
+    cmd = ["mvn", "compile"]
+    try:
+        process = subprocess.run(
+            cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            text=True,
+            check=False,
+            cwd=source_directory,
+        )
+        return process.stdout
+    except FileNotFoundError:
+        print("Maven is not installed or not found in the system PATH.")
+        return ""
+
+
+def classify_error(message: str) -> Type[MavenCompilerError]:
+    """
+    Classifies an error message and returns the corresponding error class.
+    """
+    message_lower = message.lower()
+    if "cannot find symbol" in message_lower:
+        return SymbolNotFoundError
+    elif message_lower.startswith("package ") and message_lower.endswith(
+        " does not exist"
+    ):
+        return PackageDoesNotExistError
+    elif "class, interface, or enum expected" in message_lower:
+        return SyntaxError
+    elif "incompatible types" in message_lower:
+        return TypeMismatchError
+    elif (
+        "method does not override or implement a method from a supertype"
+        in message_lower
+    ):
+        return AnnotationError
+    elif "cannot access" in message_lower:
+        return AccessControlError
+    else:
+        return OtherError
+
+
+def parse_maven_output(output: str) -> List[MavenCompilerError]:
+    """
+    Parses the Maven output and returns a list of MavenCompilerError instances.
+    """
+    errors: List[MavenCompilerError] = []
+    lines = output.splitlines()
+    in_compilation_error_section = False
+    error_pattern = re.compile(r"\[ERROR\] (.+?):\[(\d+),(\d+)\] (.+)")
+    current_error: Optional[MavenCompilerError] = None
+
+    acc: List[str] = []
+    for i, line in enumerate(lines):
+        if "[ERROR] COMPILATION ERROR :" in line:
+            in_compilation_error_section = True
+            continue
+        if in_compilation_error_section:
+            if line.startswith("[INFO] BUILD FAILURE"):
+                in_compilation_error_section = False
+                continue
+            if (
+                line.startswith("[INFO]")
+                or line.startswith("[WARNING]")
+                or line.strip() == ""
+            ):
+                # TODO what to do with these?
+                continue
+            # Match error lines with file path, line, column, and message
+            match = error_pattern.match(line)
+            if match:
+                acc.append(line)
+                error_class = classify_error(match.group(4))
+                current_error = error_class.from_match(match, [])
+                # Look ahead for details
+                details = []
+                j = i + 1
+                while j < len(lines) and lines[j].startswith(" "):
+                    acc.append(lines[j])
+                    # str.strip() with an argument removes a character set, not
+                    # a prefix, so use removeprefix() to drop "[ERROR]".
+                    detail_line = lines[j].strip().removeprefix("[ERROR]").strip()
+                    details.append(detail_line)
+                    j += 1
+                current_error.details.extend(details)
+                # Extract additional information based on error type
+                if isinstance(current_error, SymbolNotFoundError):
+                    for detail in current_error.details:
+                        if "symbol:" in detail:
+                            current_error.missing_symbol = detail.split("symbol:")[
+                                -1
+                            ].strip()
+                        if "location:" in detail:
+                            current_error.symbol_location = detail.split("location:")[
+                                -1
+                            ].strip()
+                elif isinstance(current_error, PackageDoesNotExistError):
+                    current_error.missing_package = (
+                        current_error.message.split("package")[-1]
+                        .split("does not exist")[0]
+                        .strip()
+                    )
+                elif isinstance(current_error, TypeMismatchError):
+                    for detail in current_error.details:
+                        if "required:" in detail:
+                            current_error.expected_type = detail.split("required:")[
+                                -1
+                            ].strip()
+                        if "found:" in detail:
+                            current_error.found_type = detail.split("found:")[
+                                -1
+                            ].strip()
+                elif isinstance(current_error, AccessControlError):
+                    current_error.inaccessible_class = current_error.message.split(
+                        "cannot access"
+                    )[-1].strip()
+                current_error.parse_lines = "\n".join(acc)
+                errors.append(current_error)
+                acc = []
+            else:
+                continue  # Line does not match error pattern
+    return errors
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Run Maven compile in a specified source directory."
+    )
+    parser.add_argument(
+        "source_directory", help="The directory where 'mvn compile' should be run."
+    )
+    args = parser.parse_args()
+    maven_output = run_maven(args.source_directory)
+
+    results = {}
+    for error in parse_maven_output(maven_output):
+        if not results.get(error.file):
+            results[error.file] = [error]
+        else:
+            results[error.file].append(error)
+    for file, errors in results.items():
+        errs_line = f"| Errors for file: {file} |"
+        print(errs_line)
+        print("-" * len(errs_line))
+        for error in errors:
+            print(f"Line: {error.line}, Column: {error.column}")
+            print(f"Type: {type(error).__name__}")
+            print(f"Message: {error.message}")
+            if isinstance(error, SymbolNotFoundError):
+                print(f"Missing Symbol: {error.missing_symbol}")
+                print(f"Symbol Location: {error.symbol_location}")
+            elif isinstance(error, PackageDoesNotExistError):
+                print(f"Missing Package: {error.missing_package}")
+            elif isinstance(error, TypeMismatchError):
+                print(f"Expected Type: {error.expected_type}")
+                print(f"Found Type: {error.found_type}")
+            elif isinstance(error, AccessControlError):
+                print(f"Inaccessible Class: {error.inaccessible_class}")
+            if error.details:
+                print("Details:")
+                for detail in error.details:
+                    print(f"  {detail}")
+            print("Source lines:")
+            print(error.parse_lines)
+            print("-" * 40)
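As a quick sanity check of the parser, here is a hedged usage sketch. The Maven output below is fabricated (paths and class names are made up) but follows the `[ERROR] <file>:[line,col] <message>` shape, and the detail lines indented with spaces, that the regex and look-ahead above expect:

```python
# Hypothetical usage of parse_maven_output with made-up Maven output.
from maven_validator import SymbolNotFoundError, parse_maven_output

SAMPLE_OUTPUT = """\
[INFO] -------------------------------------------------------------
[ERROR] COMPILATION ERROR :
[INFO] -------------------------------------------------------------
[ERROR] /demo/src/main/java/com/example/App.java:[12,8] cannot find symbol
  symbol:   class MissingService
  location: class com.example.App
[INFO] 1 error
[INFO] BUILD FAILURE
"""

errors = parse_maven_output(SAMPLE_OUTPUT)
assert len(errors) == 1
assert isinstance(errors[0], SymbolNotFoundError)
assert errors[0].line == 12 and errors[0].column == 8
assert errors[0].missing_symbol == "class MissingService"
```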