diff --git a/promptshell/ai_terminal_assistant.py b/promptshell/ai_terminal_assistant.py index ee381be..7ece319 100644 --- a/promptshell/ai_terminal_assistant.py +++ b/promptshell/ai_terminal_assistant.py @@ -13,6 +13,13 @@ class AITerminalAssistant: def __init__(self, model_name: str, max_tokens: int = 8000, config: dict = None): + """Initializes the AITerminalAssistant with the given model name and configuration. + + Args: + model_name (str): The name of the model to use for command execution. + max_tokens (int, optional): The maximum number of tokens for the model. Defaults to 8000. + config (dict, optional): Additional configuration settings. Defaults to None. + """ self.username = getpass.getuser() self.home_folder = os.path.expanduser("~") self.current_directory = os.getcwd() @@ -28,6 +35,7 @@ def __init__(self, model_name: str, max_tokens: int = 8000, config: dict = None) self.initialize_system_context() def initialize_system_context(self): + """Initializes the system context by gathering installed commands and system information.""" path_dirs = os.environ.get('PATH', '').split(os.pathsep) installed_commands = [] for dir in path_dirs: @@ -105,7 +113,7 @@ def initialize_system_context(self): Solution: Confirm file existence with 'ls' Alternative: Use trash-cli instead of rm """ - + self.debugger.definition = f""" [ROLE] Shell Environment Debugger [TASK] Diagnose complex system issues @@ -154,6 +162,14 @@ def initialize_system_context(self): """ def execute_command_with_live_output(self, command: str) -> Tuple[str, str, int]: + """Executes a shell command and captures its live output. + + Args: + command (str): The shell command to execute. + + Returns: + Tuple[str, str, int]: A tuple containing the standard output, standard error, and exit code. + """ interactive_commands = [ 'vim', 'vi', 'nano', 'emacs', 'ssh', 'telnet', 'top', 'htop', 'man', 'less', 'more', 'mysql', 'psql', 'nmtui', 'crontab', @@ -183,6 +199,14 @@ def execute_command_with_live_output(self, command: str) -> Tuple[str, str, int] return "", str(e), 1 def execute_interactive_command(self, command: str) -> Tuple[str, str, int]: + """Executes an interactive shell command. + + Args: + command (str): The interactive shell command to execute. + + Returns: + Tuple[str, str, int]: A tuple containing the standard output, standard error, and exit code. + """ print(format_text('yellow') + "Executing interactive command..." + reset_format()) try: proc = subprocess.Popen( @@ -201,6 +225,14 @@ def execute_interactive_command(self, command: str) -> Tuple[str, str, int]: return "", str(e), 1 def execute_command(self, user_input: str) -> str: + """Executes a command based on user input. + + Args: + user_input (str): The command or question input by the user. + + Returns: + str: The result of the command execution or an error message. + """ try: self.current_directory = os.getcwd() if user_input.strip() == "": @@ -260,6 +292,14 @@ def execute_command(self, user_input: str) -> str: return self.handle_error(str(e), user_input, command) def run_direct_command(self, command: str) -> str: + """Executes a direct shell command provided by the user. + + Args: + command (str): The shell command to execute. + + Returns: + str: The result of the command execution or an error message. + """ try: formatted_command = format_text('cyan') + f"Direct Command: {command}" + reset_format() print(formatted_command) @@ -287,6 +327,14 @@ def run_direct_command(self, command: str) -> str: return self.handle_error(str(e), command, command) def answer_question(self, question: str) -> str: + """Answers a user-provided question based on the current context. + + Args: + question (str): The question to answer. + + Returns: + str: The answer to the question. + """ context = f""" Command History (last 10 commands): {', '.join(self.command_history)} @@ -301,6 +349,14 @@ def answer_question(self, question: str) -> str: return format_text('cyan') + "Answer:\n" + answer + reset_format() def gather_additional_data(self, user_input: str) -> dict: + """Gathers additional data based on the user's input, such as clipboard content or file data. + + Args: + user_input (str): The user's input to analyze for additional data needs. + + Returns: + dict: A dictionary containing additional data, such as clipboard content or file content. + """ additional_data = {} if "clipboard" in user_input.lower(): clipboard_content = self.data_gatherer.get_clipboard_content() @@ -318,6 +374,16 @@ def gather_additional_data(self, user_input: str) -> dict: return additional_data def debug_error(self, command: str, error_output: str, exit_code: int) -> str: + """Analyzes a failed command and provides debugging suggestions. + + Args: + command (str): The command that failed. + error_output (str): The error output from the failed command. + exit_code (int): The exit code of the failed command. + + Returns: + str: A debugging suggestion or alternative command. + """ context = f""" Command History (last 10 commands): {', '.join(self.command_history)} @@ -335,6 +401,16 @@ def debug_error(self, command: str, error_output: str, exit_code: int) -> str: return self.debugger(debug_input) def handle_error(self, error: str, user_input: str, command: str) -> str: + """Handles errors by analyzing the issue and suggesting a corrected command. + + Args: + error (str): The error message. + user_input (str): The original user input. + command (str): The interpreted command that caused the error. + + Returns: + str: The result of executing the suggested command or an error message. + """ error_analysis = self.error_handler(f""" Error: {error} User Input: {user_input} diff --git a/promptshell/data_gatherer.py b/promptshell/data_gatherer.py index 52b63a5..63d920e 100644 --- a/promptshell/data_gatherer.py +++ b/promptshell/data_gatherer.py @@ -4,6 +4,11 @@ class DataGatherer: @staticmethod def get_clipboard_content(): + """Retrieves the current content of the clipboard. + + Returns: + str: The content of the clipboard or an error message if unable to access it. + """ try: return pyperclip.paste() except: @@ -11,6 +16,14 @@ def get_clipboard_content(): @staticmethod def get_file_content(file_path): + """Reads the content of a specified file. + + Args: + file_path (str): The path to the file to read. + + Returns: + str: The content of the file or an error message if reading fails. + """ try: with open(file_path, 'r') as file: return file.read() @@ -19,6 +32,14 @@ def get_file_content(file_path): @staticmethod def execute_command(command): + """Executes a shell command and captures its output. + + Args: + command (str): The shell command to execute. + + Returns: + str: The standard output of the command or an error message if execution fails. + """ try: result = subprocess.run(command, capture_output=True, text=True, shell=True) return result.stdout if result.returncode == 0 else f"Error: {result.stderr}" diff --git a/promptshell/format_utils.py b/promptshell/format_utils.py index 8929ff6..8e5bc1b 100644 --- a/promptshell/format_utils.py +++ b/promptshell/format_utils.py @@ -2,6 +2,17 @@ import platform def format_text(fg, bg=None, inverted=False, bold=False): + """Formats text for terminal output with specified foreground and background colors. + + Args: + fg (str): The foreground color (e.g., 'red', 'green'). + bg (str, optional): The background color (e.g., 'black', 'white'). Defaults to None. + inverted (bool, optional): If True, inverts the foreground and background colors. Defaults to False. + bold (bool, optional): If True, makes the text bold. Defaults to False. + + Returns: + str: The formatted text string with ANSI escape codes. + """ reset = "\033[0m" result = reset if bold: @@ -18,9 +29,19 @@ def format_text(fg, bg=None, inverted=False, bold=False): return result def reset_format(): + """Resets the text formatting to default. + + Returns: + str: The ANSI escape code to reset formatting. + """ return "\033[0m" def get_terminal_size(): + """Retrieves the current size of the terminal window. + + Returns: + tuple: A tuple containing the number of columns and rows in the terminal. + """ try: columns, rows = os.get_terminal_size(0) except OSError: @@ -28,12 +49,20 @@ def get_terminal_size(): return columns, rows def get_current_os(): - """Detect and normalize current operating system""" + """Detects and normalizes the current operating system. + + Returns: + str: The name of the current operating system ('windows', 'macos', or 'linux'). + """ system = platform.system().lower() return 'windows' if system == 'windows' else 'macos' if system == 'darwin' else 'linux' def get_os_specific_examples(): - """Return OS-appropriate command examples""" + """Returns OS-appropriate command examples based on the current operating system. + + Returns: + list: A list of command examples specific to the current OS. + """ current_os = get_current_os() examples = { 'windows': [ diff --git a/promptshell/main.py b/promptshell/main.py index a8dfb90..3c83f38 100644 --- a/promptshell/main.py +++ b/promptshell/main.py @@ -7,6 +7,20 @@ from .setup import setup_wizard, load_config, get_active_model def main(): + """The main entry point for the AI-Powered Terminal Assistant. + + This function initializes the assistant, loads the configuration, and starts the interactive loop + for processing user input. It handles commands, questions, and configuration updates. + + Behavior: + - Loads configuration or runs a setup wizard if no configuration exists. + - Enables ANSI support and sets up readline for command-line enhancements. + - Processes user input for commands, questions, or special options like '--help' or '--config'. + - Provides a clean exit on 'quit' or 'Ctrl + c'. + + Returns: + None + """ config = load_config() if not config: print("First-time setup required!") diff --git a/promptshell/node.py b/promptshell/node.py index 2b07c70..2af0677 100644 --- a/promptshell/node.py +++ b/promptshell/node.py @@ -9,6 +9,14 @@ class Node: def __init__(self, model_name: str, name: str, max_tokens: int = 8192, config: dict = None): + """Initializes a Node instance for interacting with AI models. + + Args: + model_name (str): The name of the AI model to use. + name (str): The name of the node. + max_tokens (int, optional): The maximum number of tokens for the model's response. Defaults to 8192. + config (dict, optional): Configuration settings for the node. Defaults to None. + """ self.model_name = model_name self.name = name self.definition = "" @@ -18,6 +26,15 @@ def __init__(self, model_name: str, name: str, max_tokens: int = 8192, config: d self.provider = get_provider() def __call__(self, input_text: str, additional_data: dict = None): + """Processes user input and generates a response using the configured AI provider. + + Args: + input_text (str): The input text or query from the user. + additional_data (dict, optional): Additional context or data to include in the prompt. Defaults to None. + + Returns: + str: The AI-generated response or an error message. + """ try: context_str = "\n".join([f"{msg['role']} {msg['content']}" for msg in self.context]) prompt = f""" system {self.definition} @@ -58,6 +75,14 @@ def __call__(self, input_text: str, additional_data: dict = None): return f"Error in processing: {str(e)}" def _call_ollama(self, prompt: str) -> str: + """Handles API calls for the Ollama provider. + + Args: + prompt (str): The prompt to send to the Ollama API. + + Returns: + str: The response from the Ollama API or an error message. + """ response = requests.post( 'http://localhost:11434/api/generate', json={ @@ -77,6 +102,14 @@ def _call_ollama(self, prompt: str) -> str: def _call_openai(self, prompt: str) -> str: + """Handles API calls for the OpenAI provider. + + Args: + prompt (str): The prompt to send to the OpenAI API. + + Returns: + str: The response from the OpenAI API or an error message. + """ api_key = self.config["OPENAI_API_KEY"] client = OpenAI(api_key=api_key) response = client.chat.completions.create( @@ -86,6 +119,14 @@ def _call_openai(self, prompt: str) -> str: return response.choices[0].message.content.strip() def _call_anthropic(self, prompt: str) -> str: + """Handles API calls for the Anthropic provider. + + Args: + prompt (str): The prompt to send to the Anthropic API. + + Returns: + str: The response from the Anthropic API or an error message. + """ api_key = self.config["ANTHROPIC_API_KEY"] client = anthropic.Anthropic(api_key=api_key) response = client.messages.create( @@ -96,6 +137,14 @@ def _call_anthropic(self, prompt: str) -> str: return response.content[0].text.strip() def _call_google(self, prompt: str) -> str: + """Handles API calls for the Google Generative AI provider. + + Args: + prompt (str): The prompt to send to the Google Generative AI API. + + Returns: + str: The response from the Google Generative AI API or an error message. + """ api_key = self.config["GOOGLE_API_KEY"] genai.configure(api_key=api_key) model = genai.GenerativeModel(self.model_name) @@ -103,6 +152,14 @@ def _call_google(self, prompt: str) -> str: return response.text.strip() def _call_groq(self, prompt: str) -> str: + """Handles API calls for the Groq provider. + + Args: + prompt (str): The prompt to send to the Groq API. + + Returns: + str: The response from the Groq API or an error message. + """ api_key = self.config["GROQ_API_KEY"] client = Groq(api_key=api_key) @@ -129,7 +186,14 @@ def _call_groq(self, prompt: str) -> str: return response_json["command"].strip() def _call_fireworks(self, prompt: str) -> str: - """Handle API calls for Fireworks AI provider""" + """Handles API calls for the Fireworks AI provider. + + Args: + prompt (str): The prompt to send to the Fireworks API. + + Returns: + str: The response from the Fireworks API or an error message. + """ api_key = self.config["FIREWORKS_API_KEY"] client = OpenAI( api_key=api_key, @@ -143,7 +207,14 @@ def _call_fireworks(self, prompt: str) -> str: return response.choices[0].message.content.strip() def _call_openrouter(self, prompt: str) -> str: - """Handle API calls for OpenRouter provider""" + """Handles API calls for the OpenRouter provider. + + Args: + prompt (str): The prompt to send to the OpenRouter API. + + Returns: + str: The response from the OpenRouter API or an error message. + """ api_key = self.config["OPENROUTER_API_KEY"] client = OpenAI( base_url="https://openrouter.ai/api/v1", @@ -161,7 +232,14 @@ def _call_openrouter(self, prompt: str) -> str: return response.choices[0].message.content.strip() def _call_deepseek(self, prompt: str) -> str: - """Handle API calls for DeepSeek provider""" + """Handles API calls for the DeepSeek provider. + + Args: + prompt (str): The prompt to send to the DeepSeek API. + + Returns: + str: The response from the DeepSeek API or an error message. + """ api_key = self.config["DEEPSEEK_API_KEY"] client = OpenAI( api_key=api_key, diff --git a/promptshell/readline_setup.py b/promptshell/readline_setup.py index d11c3e1..9f88ef3 100644 --- a/promptshell/readline_setup.py +++ b/promptshell/readline_setup.py @@ -3,6 +3,15 @@ import os def setup_readline(): + """Sets up readline for command-line input, enabling tab completion. + + This function attempts to import the readline module for Unix-like systems. + If the import fails, it tries to use pyreadline3 for Windows. If that also fails, + it attempts to use prompt_toolkit as an alternative for tab completion. + + Returns: + callable: A function for completing paths if using prompt_toolkit, otherwise None. + """ try: import readline # Works on Unix-like systems except ImportError: diff --git a/promptshell/setup.py b/promptshell/setup.py index f1bf39c..2d7a459 100644 --- a/promptshell/setup.py +++ b/promptshell/setup.py @@ -16,6 +16,15 @@ warning_printed = False # Global variable to track if the warning has been printed def setup_wizard(): + """Runs the setup wizard to configure the PromptShell environment. + + This function guides the user through selecting an operation mode (local or API), + choosing models, and entering API keys if necessary. It ensures the configuration + is saved to a file for future use. + + Returns: + None + """ # Load existing configuration config = load_config() @@ -193,9 +202,12 @@ def get_installed_models(): print(format_text("blue") + f"Active model: {get_active_model()}" + reset_format()) def load_config(): - """ - Loads the configuration file into a dictionary. - Returns default values if the file is missing or incomplete. + """Loads the configuration file into a dictionary. + + If the configuration file is missing or incomplete, default values are returned. + + Returns: + dict: A dictionary containing the configuration settings. """ global warning_printed @@ -248,6 +260,13 @@ def get_active_model(): return config["API_MODEL"] def get_provider(): + """Determines the active provider based on the current configuration. + + If the mode is 'api', it returns the active API provider. Otherwise, it defaults to 'ollama'. + + Returns: + str: The name of the active provider. + """ config = load_config() if config["MODE"] == "api": return config["ACTIVE_API_PROVIDER"]