diff --git a/docs/lammps_agent.md b/docs/lammps_agent.md new file mode 100644 index 00000000..538303b1 --- /dev/null +++ b/docs/lammps_agent.md @@ -0,0 +1,191 @@ +# LammpsAgent Documentation + +`LammpsAgent` is a class that helps set up and run a LAMMPS simulation workflow. At the highest level, it can: + +- discover candidate interatomic potentials from the NIST database for a set of elements, +- summarize and choose a potential for the simulation task at hand, +- author a LAMMPS input script using the chosen potential (and an optional template / data file), +- execute LAMMPS via MPI (CPU or Kokkos GPU), +- iteratively “fix” the input script on failures, using the accumulated run history, until the run succeeds or a maximum attempt limit is reached. + +The agent writes its outputs into a local `workspace` directory and uses rich console panels to display progress, choices, diffs, and errors. + +--- + +## Dependencies + +The main dependency is the [LAMMPS](https://www.lammps.org) code, which must be installed separately. LAMMPS is a classical molecular dynamics code developed by Sandia National Laboratories. Installation instructions can be found [here](https://docs.lammps.org/Install.html). On macOS and Linux systems, the simplest way to install LAMMPS is often via [Conda](https://anaconda.org/channels/conda-forge/packages/lammps/overview), in the same conda environment where `ursa` is installed. + +One additional dependency that is not installed along with `ursa` is `atomman`. It can be installed via `pip install atomman`. + +--- + +## Basic Usage + +```python +from ursa.agents import LammpsAgent +from langchain_openai import ChatOpenAI + +agent = LammpsAgent(llm=ChatOpenAI(model='gpt-5')) + +result = agent.invoke({ + "simulation_task": "Carry out a LAMMPS simulation of Cu to determine its equation of state.", + "elements": ["Cu"], + "template": "No template provided." 
# Template for the input file +}) +``` + +For more advanced usage, see the examples in `ursa/examples/two_agent_examples/lammps_execute/`. + +--- + +## High-level flow + +The agent compiles a `StateGraph(LammpsState)` with this logic: + +### Entry routing +Chooses one of three paths: + +1. **User-provided potential**: + - This path is chosen when the user provides a specific potential file, along with the `pair_style`/`pair_coeff` information required to generate the input script + - In this case the autonomous potential search/selection by the agent is skipped + - The provided potential file is copied to `workspace` + +2. **User-chosen potential already in state** (`state["chosen_potential"]` exists): + - This is similar to the above path, but the user selects a potential from the `atomman` database and initializes the state with this entry before invoking the agent + - This path also skips the potential search/selection and goes straight to authoring a LAMMPS input script for the user-chosen potential + +3. **Agent-selected potential**: + - The agent queries NIST (via atomman) for potentials matching the requested elements + - It summarizes NIST's data on each potential (up to `max_potentials`) with regard to the potential's applicability to the given `simulation_task` + - It ultimately picks one potential + +If a `data_file` is provided to the agent, the entry router attempts to copy it into the workspace. 
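The three-way routing above can be sketched as a small predicate. This is a simplified illustration of the graph's conditional entry edge, not the exact internals; `route_entry` and its argument names are illustrative:

```python
def route_entry(state, potential_files=None, pair_style=None, pair_coeff=None):
    """Pick one of the three entry paths described above (illustrative sketch)."""
    # Path 1: the user supplied potential files plus pair_style/pair_coeff
    if (
        potential_files is not None
        and pair_style is not None
        and pair_coeff is not None
    ):
        return "user_potential"
    # Path 2: a potential object (e.g. from the atomman database) is already in state
    if state.get("chosen_potential"):
        return "user_choice"
    # Path 3: let the agent search NIST via atomman and choose one itself
    return "agent_choice"

print(route_entry({}, ["Cu_u3.eam"], "eam", "* * Cu_u3.eam"))  # user_potential
print(route_entry({"chosen_potential": object()}))             # user_choice
print(route_entry({}))                                         # agent_choice
```

Note that providing only some of the three user-potential arguments falls through to the other paths; the agent requires all three together.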
+ +### Potential search & selection (agent-selected path) +- `_find_potentials`: queries `atomman.library.Database(remote=True)` for potentials matching: + - `elements` from state + - the supported `pair_styles` list (see `self.pair_styles`) +- `_summarize_one`: for each candidate potential: + - extracts data on the potential from NIST + - trims the extracted text to a token budget using `tiktoken` + - summarizes its usefulness for the requested `simulation_task` + - writes the summary to `workspace/potential_summaries/potential_.txt` +- `_build_summaries`: builds a combined string of summaries for selection +- `_choose`: the agent selects the final potential to be used and the rationale for choosing it + - writes the rationale to `workspace/potential_summaries/Rationale.txt` + - stores `chosen_potential` in state + +If `find_potential_only=True`, the graph exits after choosing the potential (or finding no matches). + +### Author input +- Downloads potential files into `workspace` (only if not user-provided) +- Gets `pair_info` via `chosen_potential.pair_info()` +- Optionally includes: + - `template` from state for the LAMMPS input script + - `data_file` (usually the atomic structure, referenced from the input script via `read_data`) +- The agent authors the input script: `{ "input_script": "" }` +- Writes `workspace/in.lammps` +- Enforces that logs go to `./log.lammps` + +### Run LAMMPS + +Runs `<mpirun_cmd> -np <mpi_procs> <lammps_cmd> -in in.lammps` in `workspace`. + +Allowed options for `<mpirun_cmd>` are `mpirun` and `mpiexec` (see also the Parameters section below). + +For example, LAMMPS run commands executed by the agent look like: + +- **CPU mode** (default, when `ngpus < 0`): + - `mpirun -np <mpi_procs> lmp_mpi -in in.lammps` + +- **GPU/Kokkos mode** (when `ngpus >= 0`): + - `mpirun -np <mpi_procs> lmp_mpi -in in.lammps -k on g <ngpus> -sf kk -pk kokkos neigh half newton on` + +Note that running in GPU mode is preliminary. + +The agent captures `stdout`, `stderr`, and `returncode`, and appends an entry to `run_history`. 
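The launch commands above are assembled as an argv list for `subprocess.run`. The following sketch mirrors the flag order in the agent's `_run_lammps` method (the helper name `build_lammps_cmd` is illustrative):

```python
def build_lammps_cmd(mpirun_cmd="mpirun", mpi_procs=8, lammps_cmd="lmp_mpi", ngpus=-1):
    """Build the argv list passed to subprocess.run (illustrative sketch)."""
    cmd = [mpirun_cmd, "-np", str(mpi_procs), lammps_cmd, "-in", "in.lammps"]
    if ngpus >= 0:
        # Kokkos GPU mode: run Kokkos on <ngpus> GPUs and apply the kk suffix
        cmd += ["-k", "on", "g", str(ngpus), "-sf", "kk",
                "-pk", "kokkos", "neigh", "half", "newton", "on"]
    return cmd

print(" ".join(build_lammps_cmd()))
# mpirun -np 8 lmp_mpi -in in.lammps
print(" ".join(build_lammps_cmd(ngpus=2)))
# mpirun -np 8 lmp_mpi -in in.lammps -k on g 2 -sf kk -pk kokkos neigh half newton on
```

Passing the command as an argv list (rather than a shell string) avoids quoting issues; the working directory is set to `workspace` via the `cwd` argument of `subprocess.run`.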
+ +### Fix loop +If the run fails: +- formats the entire `run_history` (scripts + stdout/stderr) into an error blob +- the agent produces a new `input_script` +- prints a unified diff between the old and new scripts +- overwrites `workspace/in.lammps` +- increments `fix_attempts` +- reruns LAMMPS + +Stops when: +- the run succeeds (`returncode == 0`), or +- `fix_attempts >= max_fix_attempts` + +--- + +## State model (`LammpsState`) + +The graph state is a `TypedDict` containing (key fields): + +- **Inputs / problem definition** + - `simulation_task: str` — natural language description of what to simulate + - `elements: list[str]` — chemical symbols used to identify candidate potentials + - `template: Optional[str]` — optional LAMMPS input template to adapt + - `chosen_potential: Optional[Any]` — selected potential object (user-chosen) + +- **Potential selection internals** + - `matches: list[Any]` — candidate potentials from atomman + - `idx: int` — index used for the summarization loop + - `summaries: list[str]` — a brief summary of each potential + - `full_texts: list[str]` — the data/metadata on each potential from NIST (capped at `max_tokens`) + - `summaries_combined: str` — a single string with the summaries of all the considered potentials + +- **Run artifacts** + - `input_script: str` — current LAMMPS input text written to `in.lammps` + - `run_returncode: Optional[int]` — generally, `returncode = 0` indicates a successful simulation run + - `run_stdout: str` — the stdout from the LAMMPS execution + - `run_stderr: str` — the stderr from the LAMMPS execution + - `run_history: list[dict[str, Any]]` — attempt-by-attempt record + - `fix_attempts: int` — the number of times the agent has attempted to fix the LAMMPS input script + +--- + +## Parameters + +Key parameters you can tune: + +### Potential selection +- `potential_files`, `pair_style`, `pair_coeff`: if all are provided, the agent uses the user's potential files and skips the search +- `max_potentials` (default `5`): max number of candidate potentials to summarize before choosing one +- `find_potential_only` (default `False`): exit after selecting a potential (no LAMMPS input writing/running) + +### Fix loop +- `max_fix_attempts` (default `10`): maximum number of input rewrite attempts after failures + +### Data file support +- `data_file` (default `None`): path to a LAMMPS data file; the agent copies it to `workspace` +- `data_max_lines` (default `50`): number of lines from the data file included in the agent's prompt + +### Execution +- `workspace` (default `./workspace`): where `in.lammps`, potentials, and summaries are written +- `mpi_procs` (default `8`): number of MPI processes for the LAMMPS run +- `ngpus` (default `-1`): set `>= 0` to enable Kokkos GPU flags +- `lammps_cmd` (default `lmp_mpi`): the name of the LAMMPS executable to launch +- `mpirun_cmd` (default `mpirun`): currently available options are `mpirun` and `mpiexec`. Other options such as `srun` will be added soon +- `summarize_results` (default `True`): after a successful run, hand off to an `ExecutionAgent` that summarizes the contents of `log.lammps` + +### LLM / context trimming +- `tiktoken_model` (default `gpt-5-mini`): tokenizer model name used to trim fetched potential metadata text +- `max_tokens` (default `200000`): token cap for the extracted metadata text + +--- + +## Files and directories created + +Inside `workspace/`: + +- `in.lammps` — generated/updated input script +- `log.lammps` — expected LAMMPS log output (the LLM is instructed to create it) +- `potential_summaries/` + - `potential_.txt` — per-potential LLM summaries + - `Rationale.txt` — rationale for the selected potential +- downloaded potential files (from atomman or copied from user paths) +- copied `data_file` (if provided) + diff --git a/examples/two_agent_examples/lammps_execute/EOS_of_Cu.py b/examples/two_agent_examples/lammps_execute/EOS_of_Cu.py index 104a7376..dd20577f 100644 --- a/examples/two_agent_examples/lammps_execute/EOS_of_Cu.py +++ b/examples/two_agent_examples/lammps_execute/EOS_of_Cu.py @@ -1,8 +1,9 @@ -import asyncio - from langchain_openai import ChatOpenAI +from rich 
import get_console + +from ursa.agents import LammpsAgent -from ursa.agents import ExecutionAgent, LammpsAgent +console = get_console() try: import atomman as am @@ -19,51 +20,40 @@ workspace = "./workspace_eos_cu" - -async def main(): - wf = LammpsAgent( - llm=llm, - max_potentials=2, - max_fix_attempts=5, - find_potential_only=False, - mpi_procs=8, - workspace=workspace, - lammps_cmd="lmp_mpi", - mpirun_cmd="mpirun", - ) - - with open("eos_template.txt", "r") as file: - template = file.read() - - simulation_task = "Carry out a LAMMPS simulation of Cu to determine its equation of state." - - elements = ["Cu"] - - db = am.library.Database(remote=True) - matches = db.get_lammps_potentials(pair_style=["eam"], elements=elements) - chosen_potential = matches[-1] - - final_lammps_state = await wf.ainvoke( - simulation_task=simulation_task, - elements=elements, - template=template, - chosen_potential=chosen_potential, +wf = LammpsAgent( + llm=llm, + max_potentials=2, + max_fix_attempts=5, + find_potential_only=False, + ngpus=-1,  # -1 disables GPUs; the LAMMPS executable must be built with the Kokkos package for GPU use + mpi_procs=8, + workspace=workspace, + lammps_cmd="lmp_mpi", + mpirun_cmd="mpirun", + summarize_results=True, +) + +with open("eos_template.txt", "r") as file: + template = file.read() + +simulation_task = ( + "Carry out a LAMMPS simulation of Cu to determine its equation of state." +) + +elements = ["Cu"] + +db = am.library.Database(remote=True) +matches = db.get_lammps_potentials(pair_style=["eam"], elements=elements) +chosen_potential = matches[-1] + +final_lammps_state = wf.invoke( + simulation_task=simulation_task, + elements=elements, + template=template, + chosen_potential=chosen_potential, +) + +if final_lammps_state.get("run_returncode") == 0: + console.print( + "\n[green]LAMMPS Workflow completed successfully.[/green] Exiting....." 
) - - if final_lammps_state.get("run_returncode") == 0: - print("\nNow handing things off to execution agent.....") - - executor = ExecutionAgent(llm=llm, workspace=workspace) - exe_plan = f""" - You are part of a larger scientific workflow whose purpose is to accomplish this task: {simulation_task} - - A LAMMPS simulation has been done and the output is located in the file 'log.lammps'. - - Summarize the contents of this file in a markdown document. Include a plot, if relevent. - """ - - _ = await executor.ainvoke(exe_plan) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/two_agent_examples/lammps_execute/stiffness_tensor_of_hea.py b/examples/two_agent_examples/lammps_execute/stiffness_tensor_of_hea.py index c85d7f35..1cb70eac 100644 --- a/examples/two_agent_examples/lammps_execute/stiffness_tensor_of_hea.py +++ b/examples/two_agent_examples/lammps_execute/stiffness_tensor_of_hea.py @@ -1,8 +1,9 @@ -import asyncio - from langchain_openai import ChatOpenAI +from rich import get_console + +from ursa.agents import LammpsAgent -from ursa.agents import ExecutionAgent, LammpsAgent +console = get_console() model = "gpt-5" @@ -10,44 +11,33 @@ workspace = "./workspace_stiffness_tensor" - -async def main(): - wf = LammpsAgent( - llm=llm, - max_potentials=5, - max_fix_attempts=15, - find_potential_only=False, - mpi_procs=8, - workspace=workspace, - lammps_cmd="lmp_mpi", - mpirun_cmd="mpirun", - ) - - with open("elastic_template.txt", "r") as file: - template = file.read() - - simulation_task = "Carry out a LAMMPS simulation of the high entropy alloy Co-Cr-Fe-Mn-Ni to determine its stiffness tensor." 
- - elements = ["Co", "Cr", "Fe", "Mn", "Ni"] - - final_lammps_state = await wf.ainvoke( - simulation_task=simulation_task, elements=elements, template=template +wf = LammpsAgent( + llm=llm, + max_potentials=5, + max_fix_attempts=15, + find_potential_only=False, + mpi_procs=8, + workspace=workspace, + lammps_cmd="lmp_mpi", + mpirun_cmd="mpirun", + summarize_results=True, +) + +with open("elastic_template.txt", "r") as file: + template = file.read() + +simulation_task = ( + "Carry out a LAMMPS simulation of the high entropy alloy Co-Cr-Fe-Mn-Ni " + "to determine its stiffness tensor." +) + +elements = ["Co", "Cr", "Fe", "Mn", "Ni"] + +final_lammps_state = wf.invoke( + simulation_task=simulation_task, elements=elements, template=template +) + +if final_lammps_state.get("run_returncode") == 0: + console.print( + "\n[green]LAMMPS Workflow completed successfully.[/green] Exiting....." ) - - if final_lammps_state.get("run_returncode") == 0: - print("\nNow handing things off to execution agent.....") - - executor = ExecutionAgent(llm=llm, workspace=workspace) - exe_plan = f""" - You are part of a larger scientific workflow whose purpose is to accomplish this task: {simulation_task} - - A LAMMPS simulation has been done and the output is located in the file 'log.lammps'. - - Summarize the contents of this file in a markdown document. Include a plot, if relevent. 
- """ - - _ = await executor.ainvoke(exe_plan) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/src/ursa/agents/lammps_agent.py b/src/ursa/agents/lammps_agent.py index 5948778d..7ad4ba9d 100644 --- a/src/ursa/agents/lammps_agent.py +++ b/src/ursa/agents/lammps_agent.py @@ -1,3 +1,4 @@ +import difflib import json import os import subprocess @@ -5,9 +6,16 @@ import tiktoken from langchain.chat_models import BaseChatModel +from langchain_core.messages import HumanMessage from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate from langgraph.graph import END +from rich.console import Console +from rich.panel import Panel +from rich.rule import Rule +from rich.syntax import Syntax + +from ursa.agents.execution_agent import ExecutionAgent from .base import BaseAgent @@ -35,6 +43,7 @@ class LammpsState(TypedDict, total=False): run_returncode: Optional[int] run_stdout: str run_stderr: str + run_history: list[dict[str, Any]] fix_attempts: int @@ -45,30 +54,54 @@ class LammpsAgent(BaseAgent[LammpsState]): def __init__( self, llm: BaseChatModel, + potential_files: Optional[list[str]] = None, + pair_style: Optional[str] = None, + pair_coeff: Optional[str] = None, max_potentials: int = 5, max_fix_attempts: int = 10, find_potential_only: bool = False, + data_file: str = None, + data_max_lines: int = 50, + ngpus: int = -1, mpi_procs: int = 8, workspace: str = "./workspace", lammps_cmd: str = "lmp_mpi", mpirun_cmd: str = "mpirun", tiktoken_model: str = "gpt-5-mini", max_tokens: int = 200000, + summarize_results: bool = True, **kwargs, ): if not working: raise ImportError( "LAMMPS agent requires the atomman and trafilatura dependencies. These can be installed using 'pip install ursa-ai[lammps]' or, if working from a local installation, 'pip install -e .[lammps]' ." 
) + super().__init__(llm, **kwargs) + + self.user_potential_files = potential_files + self.user_pair_style = pair_style + self.user_pair_coeff = pair_coeff + self.use_user_potential = ( + potential_files is not None + and pair_style is not None + and pair_coeff is not None + ) + self.max_potentials = max_potentials self.max_fix_attempts = max_fix_attempts self.find_potential_only = find_potential_only + self.data_file = data_file + self.data_max_lines = data_max_lines + self.ngpus = ngpus self.mpi_procs = mpi_procs self.lammps_cmd = lammps_cmd self.mpirun_cmd = mpirun_cmd self.tiktoken_model = tiktoken_model self.max_tokens = max_tokens + self.summarize_results = summarize_results + + self.console = Console() self.pair_styles = [ "eam", @@ -84,6 +117,9 @@ def __init__( "nep", ] + self.workspace = workspace + os.makedirs(self.workspace, exist_ok=True) + self.str_parser = StrOutputParser() self.summ_chain = ( @@ -118,9 +154,13 @@ def __init__( "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n" "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n" "Template provided (if any): {template}\n" - "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n" - "To create the log, use only the 'log ./log.lammps' command. Do not use any other command like 'echo' or 'screen'.\n" + "If a data file is provided, use it in the input script via the 'read_data' command.\n" + "Name of data file (if any): {data_file}\n" + "First few lines of data file (if any):\n{data_content}\n" + "Ensure that all logs are recorded in a './log.lammps' file.\n" + "To create the log file, you may use the 'log ./log.lammps' command. 
\n" "Return your answer **only** as valid JSON, with no extra text or formatting.\n" + "IMPORTANT: Properly escape all special characters in the input_script string (use \\n for newlines, \\\\ for backslashes, etc.).\n" "Use this exact schema:\n" "{{\n" ' "input_script": ""\n' @@ -133,17 +173,21 @@ def __init__( self.fix_chain = ( ChatPromptTemplate.from_template( "You are part of a larger scientific workflow whose purpose is to accomplish this task: {simulation_task}\n" - "For this purpose, this input file for LAMMPS was written: {input_script}\n" - "However, when running the simulation, an error was raised.\n" - "Here is the full stdout message that includes the error message: {err_message}\n" - "Your task is to write a new input file that resolves the error.\n" + "Multiple attempts at writing and running a LAMMPS input file have been made.\n" + "Here is the run history across attempts (each includes the input script and its stdout/stderr):{err_message}\n" + "Use the history to identify what changed between attempts and avoid repeating failed approaches.\n" + "Your task is to write a new input file that resolves the latest error.\n" "Note that all potential files are in the './' directory.\n" "Here is some information about the pair_style and pair_coeff that might be useful in writing the input file: {pair_info}.\n" "If a template for the input file is provided, you should adapt it appropriately to meet the task requirements.\n" "Template provided (if any): {template}\n" - "Ensure that all output data is written only to the './log.lammps' file. Do not create any other output file.\n" - "To create the log, use only the 'log ./log.lammps' command. 
Do not use any other command like 'echo' or 'screen'.\n" + "If a data file is provided, use it in the input script via the 'read_data' command.\n" + "Name of data file (if any): {data_file}\n" + "First few lines of data file (if any):\n{data_content}\n" + "Ensure that all logs are recorded in a './log.lammps' file.\n" + "To create the log file, you may use the 'log ./log.lammps' command. \n" "Return your answer **only** as valid JSON, with no extra text or formatting.\n" + "IMPORTANT: Properly escape all special characters in the input_script string (use \\n for newlines, \\\\ for backslashes, etc.).\n" "Use this exact schema:\n" "{{\n" ' "input_script": ""\n' @@ -153,6 +197,47 @@ def __init__( | self.str_parser ) + def _section(self, title: str): + self.console.print(Rule(f"[bold cyan]{title}[/bold cyan]")) + + def _panel(self, title: str, body: str, style: str = "cyan"): + self.console.print( + Panel(body, title=f"[bold]{title}[/bold]", border_style=style) + ) + + def _code_panel( + self, + title: str, + code: str, + language: str = "bash", + style: str = "magenta", + ): + syn = Syntax( + code, language, theme="monokai", line_numbers=True, word_wrap=True + ) + self.console.print( + Panel(syn, title=f"[bold]{title}[/bold]", border_style=style) + ) + + def _diff_panel(self, old: str, new: str, title: str = "LAMMPS input diff"): + diff = "\n".join( + difflib.unified_diff( + old.splitlines(), + new.splitlines(), + fromfile="in.lammps (before)", + tofile="in.lammps (after)", + lineterm="", + ) + ) + if not diff.strip(): + diff = "(no changes)" + syn = Syntax( + diff, "diff", theme="monokai", line_numbers=False, word_wrap=True + ) + self.console.print( + Panel(syn, title=f"[bold]{title}[/bold]", border_style="cyan") + ) + @staticmethod def _safe_json_loads(s: str) -> dict[str, Any]: s = s.strip() @@ -163,6 +248,73 @@ def _safe_json_loads(s: str) -> dict[str, Any]: s = s[i + 1 :].strip() return json.loads(s) + def _read_and_trim_data_file(self, data_file_path: str) -> 
str: + """Read a LAMMPS data file and trim it to data_max_lines for LLM context.""" + if os.path.exists(data_file_path): + with open(data_file_path, "r") as f: + content = f.read() + lines = content.splitlines() + if len(lines) > self.data_max_lines: + content = "\n".join(lines[: self.data_max_lines]) + print( + f"Data file trimmed from {len(lines)} to {self.data_max_lines} lines" + ) + return content + else: + return "Could not read data file." + + def _copy_data_file(self, data_file_path: str) -> str: + """Copy data file to workspace and return new path.""" + if not os.path.exists(data_file_path): + raise FileNotFoundError(f"Data file not found: {data_file_path}") + + filename = os.path.basename(data_file_path) + dest_path = os.path.join(self.workspace, filename) + os.system(f"cp {data_file_path} {dest_path}") + print(f"Data file copied to workspace: {dest_path}") + return dest_path + + def _copy_user_potential_files(self): + """Copy user-provided potential files to workspace.""" + print("Copying user-provided potential files to workspace...") + for pot_file in self.user_potential_files: + if not os.path.exists(pot_file): + raise FileNotFoundError(f"Potential file not found: {pot_file}") + + filename = os.path.basename(pot_file) + dest_path = os.path.join(self.workspace, filename) + + try: + os.system(f"cp {pot_file} {dest_path}") + print(f"Potential file copied to workspace: {dest_path}") + except Exception as e: + print(f"Error copying {pot_file}: {e}") + raise + + def _create_user_potential_wrapper(self, state: LammpsState) -> LammpsState: + """Create a wrapper object for user-provided potential to match atomman interface.""" + self._copy_user_potential_files() + + # Create a simple object that mimics the atomman potential interface + class UserPotential: + def __init__(self, pair_style, pair_coeff): + self._pair_style = pair_style + self._pair_coeff = pair_coeff + + def pair_info(self): + return f"pair_style {self._pair_style}\npair_coeff {self._pair_coeff}" + 
user_potential = UserPotential( + self.user_pair_style, self.user_pair_coeff + ) + + return { + **state, + "chosen_potential": user_potential, + "fix_attempts": 0, + "run_history": [], + } + def _fetch_and_trim_text(self, url: str) -> str: downloaded = trafilatura.fetch_url(url) if not downloaded: @@ -188,11 +340,25 @@ def _fetch_and_trim_text(self, url: str) -> str: return text def _entry_router(self, state: LammpsState) -> dict: + # Check if using user-provided potential + if self.use_user_potential: + if self.find_potential_only: + raise Exception( + "Cannot set find_potential_only=True when providing your own potential!" + ) + print("Using user-provided potential files") + if self.find_potential_only and state.get("chosen_potential"): raise Exception( "You cannot set find_potential_only=True and also specify your own potential!" ) + if self.data_file: + try: + self._copy_data_file(self.data_file) + except Exception as e: + print(f"Warning: Could not process data file: {e}") + if not state.get("chosen_potential"): self.potential_summaries_dir = os.path.join( self.workspace, "potential_summaries" @@ -213,6 +379,7 @@ def _find_potentials(self, state: LammpsState) -> LammpsState: "summaries": [], "full_texts": [], "fix_attempts": 0, + "run_history": [], } def _should_summarize(self, state: LammpsState) -> str: @@ -227,7 +394,7 @@ def _should_summarize(self, state: LammpsState) -> str: def _summarize_one(self, state: LammpsState) -> LammpsState: i = state["idx"] - print(f"Summarizing potential #{i}") + self._section(f"Summarizing potential #{i}") match = state["matches"][i] md = match.metadata() @@ -268,7 +435,7 @@ def _build_summaries(self, state: LammpsState) -> LammpsState: return {**state, "summaries_combined": "".join(parts)} def _choose(self, state: LammpsState) -> LammpsState: - print("Choosing one potential for this task...") + self._section("Choosing potential") choice = self.choose_chain.invoke({ "summaries_combined": state["summaries_combined"], 
"simulation_task": state["simulation_task"], @@ -276,12 +443,14 @@ def _choose(self, state: LammpsState) -> LammpsState: choice_dict = self._safe_json_loads(choice) chosen_index = int(choice_dict["Chosen index"]) - print(f"Chosen potential #{chosen_index}") - print("Rationale for choosing this potential:") - print(choice_dict["rationale"]) - chosen_potential = state["matches"][chosen_index] + self._panel( + "Chosen Potential", + f"[bold]Index:[/bold] {chosen_index}\n[bold]ID:[/bold] {chosen_potential.id}\n\n[bold]Rationale:[/bold]\n{choice_dict['rationale']}", + style="green", + ) + out_file = os.path.join(self.potential_summaries_dir, "Rationale.txt") with open(out_file, "w") as f: f.write(f"Chosen potential #{chosen_index}") @@ -298,97 +467,234 @@ def _route_after_summarization(self, state: LammpsState) -> str: return "continue_author" def _author(self, state: LammpsState) -> LammpsState: - print("First attempt at writing LAMMPS input file....") - state["chosen_potential"].download_files(self.workspace) + self._section("First attempt at writing LAMMPS input file") + + if not self.use_user_potential: + state["chosen_potential"].download_files(self.workspace) pair_info = state["chosen_potential"].pair_info() + + data_content = "" + if self.data_file: + data_content = self._read_and_trim_data_file(self.data_file) + authored_json = self.author_chain.invoke({ "simulation_task": state["simulation_task"], "pair_info": pair_info, "template": state["template"], + "data_file": self.data_file, + "data_content": data_content, }) script_dict = self._safe_json_loads(authored_json) input_script = script_dict["input_script"] with open(os.path.join(self.workspace, "in.lammps"), "w") as f: f.write(input_script) + + self._section("Authored LAMMPS input") + self._code_panel( + "in.lammps", input_script, language="bash", style="magenta" + ) + return {**state, "input_script": input_script} def _run_lammps(self, state: LammpsState) -> LammpsState: - print("Running LAMMPS....") - result 
= subprocess.run( - [ - self.mpirun_cmd, - "-np", - str(self.mpi_procs), - self.lammps_cmd, - "-in", - "in.lammps", - ], - cwd=self.workspace, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - text=True, - check=False, + self._section("Running LAMMPS") + + if self.ngpus >= 0: + result = subprocess.run( + [ + self.mpirun_cmd, + "-np", + str(self.mpi_procs), + self.lammps_cmd, + "-in", + "in.lammps", + "-k", + "on", + "g", + str(self.ngpus), + "-sf", + "kk", + "-pk", + "kokkos", + "neigh", + "half", + "newton", + "on", + ], + cwd=self.workspace, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + print(result) + else: + result = subprocess.run( + [ + self.mpirun_cmd, + "-np", + str(self.mpi_procs), + self.lammps_cmd, + "-in", + "in.lammps", + ], + cwd=self.workspace, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False, + ) + + status_style = "green" if result.returncode == 0 else "red" + self._panel( + "Run Result", + f"returncode = {result.returncode}", + style=status_style, ) + + if result.returncode != 0: + err_view = ( + result.stderr.strip() + "\n" + result.stdout.strip() + ).strip() or "(no output captured)" + self._panel("Run error/output", err_view[-6000:], style="red") + + hist = list(state.get("run_history", [])) + hist.append({ + "attempt": state.get("fix_attempts", 0), + "input_script": state.get("input_script", ""), + "returncode": result.returncode, + "stdout": result.stdout, + "stderr": result.stderr, + }) + return { **state, "run_returncode": result.returncode, "run_stdout": result.stdout, "run_stderr": result.stderr, + "run_history": hist, } def _route_run(self, state: LammpsState) -> str: rc = state.get("run_returncode", 0) attempts = state.get("fix_attempts", 0) if rc == 0: - print("LAMMPS run successful! Exiting...") + self._section("LAMMPS run successful! Exiting...") return "done_success" if attempts < self.max_fix_attempts: - print("LAMMPS run Failed. 
Attempting to rewrite input file...") + self._section( + "LAMMPS run Failed. Attempting to rewrite input file..." + ) return "need_fix" - print("LAMMPS run Failed and maximum fix attempts reached. Exiting...") + self._section( + "LAMMPS run Failed and maximum fix attempts reached. Exiting.." + ) return "done_failed" def _fix(self, state: LammpsState) -> LammpsState: pair_info = state["chosen_potential"].pair_info() - err_blob = state.get("run_stdout") + + hist = state.get("run_history", []) + if not hist: + hist = [ + { + "attempt": state.get("fix_attempts", 0), + "input_script": state.get("input_script", ""), + "returncode": state.get("run_returncode"), + "stdout": state.get("run_stdout", ""), + "stderr": state.get("run_stderr", ""), + } + ] + + parts = [] + for h in hist: + parts.append( + "=== Attempt {attempt} | returncode={returncode} ===\n" + "--- input_script ---\n{input_script}\n" + "--- stdout ---\n{stdout}\n" + "--- stderr ---\n{stderr}\n".format(**h) + ) + err_blob = "\n".join(parts) + + data_content = "" + if self.data_file: + data_content = self._read_and_trim_data_file(self.data_file) fixed_json = self.fix_chain.invoke({ "simulation_task": state["simulation_task"], - "input_script": state["input_script"], "err_message": err_blob, "pair_info": pair_info, "template": state["template"], + "data_file": self.data_file, + "data_content": data_content, }) script_dict = self._safe_json_loads(fixed_json) + new_input = script_dict["input_script"] + old_input = state["input_script"] + self._diff_panel(old_input, new_input) + with open(os.path.join(self.workspace, "in.lammps"), "w") as f: f.write(new_input) + return { **state, "input_script": new_input, "fix_attempts": state.get("fix_attempts", 0) + 1, } + def _summarize(self, state: LammpsState) -> LammpsState: + self._section( + "Now handing things off to execution agent for summarization/visualization" + ) + + executor = ExecutionAgent(llm=self.llm) + + exe_plan = f""" + You are part of a larger scientific 
workflow whose purpose is to accomplish this task: {state["simulation_task"]} + A LAMMPS simulation has been done and the output is located in the file 'log.lammps'. + Summarize the contents of this file in a markdown document. Include a plot, if relevant. + """ + + exe_results = executor.invoke({ + "messages": [HumanMessage(content=exe_plan)], + "workspace": self.workspace, + }) + + for x in exe_results["messages"]: + print(x.content) + + return state + + def _post_run(self, state: LammpsState) -> LammpsState: + return state + def _build_graph(self): self.add_node(self._entry_router) self.add_node(self._find_potentials) self.add_node(self._summarize_one) self.add_node(self._build_summaries) self.add_node(self._choose) + self.add_node(self._create_user_potential_wrapper) self.add_node(self._author) self.add_node(self._run_lammps) self.add_node(self._fix) + self.add_node(self._post_run) + self.add_node(self._summarize) self.graph.set_entry_point("_entry_router") self.graph.add_conditional_edges( "_entry_router", - lambda state: ( + lambda state: "user_potential" + if self.use_user_potential + else ( "user_choice" if state.get("chosen_potential") else "agent_choice" ), { + "user_potential": "_create_user_potential_wrapper", "user_choice": "_author", "agent_choice": "_find_potentials", }, @@ -424,6 +730,7 @@ def _build_graph(self): }, ) + self.graph.add_edge("_create_user_potential_wrapper", "_author") self.graph.add_edge("_author", "_run_lammps") self.graph.add_conditional_edges( @@ -431,8 +738,20 @@ def _build_graph(self): self._route_run, { "need_fix": "_fix", - "done_success": END, + "done_success": "_post_run", "done_failed": END, }, ) + self.graph.add_edge("_fix", "_run_lammps") + + self.graph.add_conditional_edges( + "_post_run", + lambda _: "summarize" if self.summarize_results else "skip", + { + "summarize": "_summarize", + "skip": END, + }, + ) + + self.graph.add_edge("_summarize", END) diff --git a/uv.lock b/uv.lock index c3b57af4..2bcab590 100644 --- 
a/uv.lock +++ b/uv.lock @@ -8801,4 +8801,4 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/38/f249a2050ad1eea0bb364046153942e34abba95dd5520af199aed86fbb49/zstandard-0.25.0-cp314-cp314-win32.whl", hash = "sha256:da469dc041701583e34de852d8634703550348d5822e66a0c827d39b05365b12", size = 444513, upload-time = "2025-09-14T22:18:20.61Z" }, { url = "https://files.pythonhosted.org/packages/3a/43/241f9615bcf8ba8903b3f0432da069e857fc4fd1783bd26183db53c4804b/zstandard-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:c19bcdd826e95671065f8692b5a4aa95c52dc7a02a4c5a0cac46deb879a017a2", size = 516118, upload-time = "2025-09-14T22:18:17.849Z" }, { url = "https://files.pythonhosted.org/packages/f0/ef/da163ce2450ed4febf6467d77ccb4cd52c4c30ab45624bad26ca0a27260c/zstandard-0.25.0-cp314-cp314-win_arm64.whl", hash = "sha256:d7541afd73985c630bafcd6338d2518ae96060075f9463d7dc14cfb33514383d", size = 476940, upload-time = "2025-09-14T22:18:19.088Z" }, -] +] \ No newline at end of file