"""Example: run a streaming sequential PDF fit over a directory of data files.

The runner watches ``data/input_files`` for files whose names match
``filename_order_pattern``, fits them in the order given by the captured
number, and writes one result file per input into ``data/results``.
"""

from pdfbl.sequential.sequential_runner import SequentialPDFFitRunner

sts = SequentialPDFFitRunner()
sts.load_inputs(
    input_data_dir="data/input_files",
    structure_path="data/Ni.cif",
    output_result_dir="data/results",
    filename_order_pattern=r"(\d+)K\.gr",
    refine_variable_names=[
        "a_1",
        "s0",
        "Uiso_0_1",
        "delta2_1",
        "qdamp",
        "qbroad",
    ],
    initial_variable_values={
        "s0": 0.4,
        "qdamp": 0.04,
        "qbroad": 0.02,
        "a_1": 3.52,
        "Uiso_0_1": 0.005,
        "delta2_1": 2,
    },
    xmin=1.5,
    xmax=25.0,
    dx=0.01,
    qmax=25,
    qmin=0.1,
    # Optional: pass whether_plot_y=True and/or whether_plot_ycalc=True to
    # also open live plots of the measured and calculated profiles.
    plot_variable_names=["a_1"],
    plot_result_entry_names=["residual"],
)
# The runner's entry point is ``run``; it has no ``start`` method.
sts.run(mode="stream")
class SequentialPDFFitRunner:
    """Fit an ordered series of PDF data files one after another.

    Each fit starts from the refined variable values of the previous fit.
    Files are discovered in an input directory and ordered by the integer
    captured from their names with ``filename_order_pattern``.

    Attributes
    ----------
    input_files_known : list of pathlib.Path
        Every matching input file seen so far, in fitting order.
    input_files_completed : list of pathlib.Path
        Files that have already been fitted (or were skipped via
        ``set_start_input_file``).
    input_files_running : list of pathlib.Path
        Files queued for the next call to ``run_one_round``.
    adapter : PDFAdapter
        Object that builds and refines the fit for a single file.
    plot_data : dict
        Live-plot bookkeeping filled in by ``load_inputs``.
    """

    def __init__(self):
        self.input_files_known = []
        self.input_files_completed = []
        self.input_files_running = []
        self.adapter = PDFAdapter()
        self.plot_data = {}

    @staticmethod
    def _new_curve(ax, label, color=None):
        """Create an empty line on ``ax`` plus the queues that feed it."""
        plot_kwargs = {"label": label}
        if color is not None:
            plot_kwargs["color"] = color
        (line,) = ax.plot([], [], **plot_kwargs)
        return {"line": line, "xdata": Queue(), "ydata": Queue()}

    @staticmethod
    def _new_series(name, title_prefix):
        """Create a figure tracking one scalar value per completed fit."""
        fig, ax = plt.subplots()
        (line,) = ax.plot([], [], label=name, marker="o")
        fig.suptitle(f"{title_prefix}: {name}")
        # The pack stays nested under its own name to preserve the layout of
        # ``plot_data`` that existing readers of the attribute rely on.
        return {name: {"line": line, "buffer": [], "ydata": Queue()}}

    def load_inputs(
        self,
        input_data_dir,
        structure_path,
        output_result_dir="results",
        filename_order_pattern=r"(\d+)K\.gr",
        whether_plot_y=False,
        whether_plot_ycalc=False,
        plot_variable_names=None,
        plot_result_entry_names=None,
        refine_variable_names=None,
        initial_variable_values=None,
        xmin=None,
        xmax=None,
        dx=None,
        qmin=None,
        qmax=None,
    ):
        """Store the fit configuration and create the requested figures.

        Parameters
        ----------
        input_data_dir
            Directory scanned for input data files.
        structure_path
            Path of the structure file handed to the adapter.
        output_result_dir
            Directory that receives one ``<stem>_result.json`` per input.
        filename_order_pattern
            Regular expression whose first match in a file name is converted
            to ``int`` and used as the sort key.
        whether_plot_y, whether_plot_ycalc
            Whether to plot the profile's ``y`` / ``ycalc`` after each fit.
        plot_variable_names
            Names of recipe parameters to plot after each fit.
        plot_result_entry_names
            Names of ``FitResults`` attributes to plot after each fit.
        refine_variable_names
            Variables to refine. When empty or omitted, the keys of
            ``initial_variable_values`` are used.
        initial_variable_values
            Starting values used for the first fit.
        xmin, xmax, dx, qmin, qmax
            Forwarded unchanged to ``PDFAdapter.init_profile``.
        """
        self.inputs = {
            "input_data_dir": input_data_dir,
            "structure_path": structure_path,
            "output_result_dir": output_result_dir,
            "filename_order_pattern": filename_order_pattern,
            "xmin": xmin,
            "xmax": xmax,
            "dx": dx,
            "qmin": qmin,
            "qmax": qmax,
            "refine_variable_names": refine_variable_names or [],
            "initial_variable_values": initial_variable_values or {},
            "whether_plot_y": whether_plot_y,
            "whether_plot_ycalc": whether_plot_ycalc,
            "plot_variable_names": plot_variable_names or [],
            "plot_result_entry_names": plot_result_entry_names or [],
        }
        if whether_plot_y and whether_plot_ycalc:
            # Both curves share one figure; pin distinct cycle colors so the
            # two axes are not drawn with the same first color.
            _, axes = plt.subplots(2, 1)
            colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
            self.plot_data["ycalc"] = self._new_curve(
                axes[0], "ycalc", color=colors[0]
            )
            self.plot_data["y"] = self._new_curve(
                axes[1], "y", color=colors[1]
            )
        elif whether_plot_ycalc:
            _, ax = plt.subplots()
            self.plot_data["ycalc"] = self._new_curve(ax, "ycalc")
        elif whether_plot_y:
            _, ax = plt.subplots()
            self.plot_data["y"] = self._new_curve(ax, "y")
        if plot_variable_names:
            self.plot_data["variables"] = {
                var_name: self._new_series(var_name, "Variable")
                for var_name in plot_variable_names
            }
        if plot_result_entry_names:
            self.plot_data["result_entries"] = {
                entry_name: self._new_series(entry_name, "Result Entry")
                for entry_name in plot_result_entry_names
            }

    def check_for_new_data(self):
        """Rescan the input directory and queue files not yet fitted.

        Files whose names do not match ``filename_order_pattern`` are
        ignored.

        Raises
        ------
        RuntimeError
            If the files already known are no longer a prefix of the sorted
            directory listing, i.e. a file appeared out of order.
        """
        input_data_dir = self.inputs["input_data_dir"]
        pattern = re.compile(self.inputs["filename_order_pattern"])
        keyed_files = []
        for path in Path(input_data_dir).glob("*"):
            found = pattern.findall(path.name)
            if not found:
                # Unrelated file (notes, hidden files, ...): not part of the
                # series, so it must not break the ordering.
                continue
            keyed_files.append((int(found[0]), path))
        keyed_files.sort(key=lambda pair: pair[0])
        sorted_file = [path for _, path in keyed_files]
        if (
            self.input_files_known
            != sorted_file[: len(self.input_files_known)]
        ):
            raise RuntimeError(
                "Wrong order to run sequential toolset is detected. "
                "This is likely due to files appearing in the input directory "
                "in the wrong order. Please restart the sequential toolset."
            )
        if self.input_files_known == sorted_file:
            return
        self.input_files_known = sorted_file
        completed = set(self.input_files_completed)
        self.input_files_running = [
            f for f in self.input_files_known if f not in completed
        ]
        print(f"{[str(f) for f in self.input_files_running]} detected.")

    def set_start_input_file(
        self, input_filename, input_filename_to_result_filename
    ):
        """Resume the sequence at ``input_filename``.

        Every known file before it is marked completed, and the variable
        values saved for the last of those become the starting values of the
        next fit. When ``input_filename`` is the first known file there is no
        previous result, so the next fit falls back to the initial values.

        Parameters
        ----------
        input_filename
            Path of the file to start from; must be a known input file.
        input_filename_to_result_filename
            Callable mapping an input file path to its result JSON path.

        Raises
        ------
        ValueError
            If ``input_filename`` is not among the known input files.
        """
        input_file_path = Path(input_filename)
        if input_file_path not in self.input_files_known:
            raise ValueError(
                f"Input file {input_filename} not found in known input files."
            )
        start_index = self.input_files_known.index(input_file_path)
        self.input_files_completed = self.input_files_known[:start_index]
        self.input_files_running = self.input_files_known[start_index:]
        if not self.input_files_completed:
            self.last_result_variables_values = None
            return
        last_result_file = input_filename_to_result_filename(
            self.input_files_completed[-1]
        )
        with open(last_result_file, "r") as result_handle:
            saved_variables = json.load(result_handle)["variables"]
        self.last_result_variables_values = {
            name: pack["value"] for name, pack in saved_variables.items()
        }

    def run_one_round(self):
        """Fit every queued input file once, in order.

        Each fit is seeded with the refined values of the previous one (or
        with the initial values for the very first fit). Results are saved
        into the output directory and pushed onto the plot queues.
        """
        self.check_for_new_data()
        if not self.input_files_running:
            return None
        inputs = self.inputs
        initial_variable_values = inputs["initial_variable_values"]
        # ``load_inputs`` stores an empty list when nothing was given, so
        # test for emptiness rather than ``None``.
        refine_variable_names = inputs["refine_variable_names"] or list(
            initial_variable_values
        )
        output_dir = Path(inputs["output_result_dir"])
        output_dir.mkdir(parents=True, exist_ok=True)
        for input_file in self.input_files_running:
            self.adapter.init_profile(
                str(input_file),
                xmin=inputs["xmin"],
                xmax=inputs["xmax"],
                dx=inputs["dx"],
                qmin=inputs["qmin"],
                qmax=inputs["qmax"],
            )
            self.adapter.init_structures(inputs["structure_path"])
            self.adapter.init_contribution()
            self.adapter.init_recipe()
            if getattr(self, "last_result_variables_values", None) is None:
                self.last_result_variables_values = initial_variable_values
            self.adapter.set_initial_variable_values(
                self.last_result_variables_values
            )
            self.adapter.refine_variables(refine_variable_names)
            results = self.adapter.save_results(
                filename=str(output_dir / f"{input_file.stem}_result.json"),
                mode="dict",
            )
            self.last_result_variables_values = {
                name: pack["value"]
                for name, pack in results["variables"].items()
            }
            self.input_files_completed.append(input_file)
            self._publish_plot_data()
            print(f"Completed processing {input_file.name}.")
        self.input_files_running = []

    def _publish_plot_data(self):
        """Push the latest fit's curves and scalars onto the plot queues."""
        if not self.plot_data:
            return
        recipe = self.adapter.recipe
        for key, attribute in (("ycalc", "ycalc"), ("y", "y")):
            if key in self.plot_data:
                profile = recipe.pdfcontribution.profile
                # x first, then y: the consumer polls the x queue and then
                # blocks on the y queue.
                self.plot_data[key]["xdata"].put(profile.x)
                self.plot_data[key]["ydata"].put(getattr(profile, attribute))
        for var_name, pack in self.plot_data.get("variables", {}).items():
            pack[var_name]["ydata"].put(recipe._parameters[var_name].value)
        result_entries = self.plot_data.get("result_entries", {})
        if result_entries:
            fit_results = FitResults(recipe)
            for entry_name, pack in result_entries.items():
                pack[entry_name]["ydata"].put(
                    getattr(fit_results, entry_name)
                )

    @staticmethod
    def _refresh_curve(curve_pack):
        """Show the most recent queued (x, y) curve, if any arrived."""
        xdata = ydata = None
        while not curve_pack["xdata"].empty():
            xdata = curve_pack["xdata"].get()
            ydata = curve_pack["ydata"].get()
        if xdata is None:
            return
        line = curve_pack["line"]
        line.set_xdata(xdata)
        line.set_ydata(ydata)
        line.axes.relim()
        line.axes.autoscale_view()

    @staticmethod
    def _refresh_series(series_pack):
        """Append every queued scalar to the series and redraw its line."""
        pending = series_pack["ydata"]
        if pending.empty():
            return
        buffer = series_pack["buffer"]
        while not pending.empty():
            buffer.append(pending.get())
        line = series_pack["line"]
        line.set_xdata(list(range(1, len(buffer) + 1)))
        line.set_ydata(buffer)
        line.axes.relim()
        line.axes.autoscale_view()

    def _refresh_plots(self):
        """Drain all plot queues into their matplotlib lines."""
        for key, plot_pack in self.plot_data.items():
            if key in ("ycalc", "y"):
                self._refresh_curve(plot_pack)
            elif key in ("variables", "result_entries"):
                for name, pack in plot_pack.items():
                    self._refresh_series(pack[name])

    def _run_stream(self):
        """Fit in a worker thread until the user types STOP or a fit fails.

        Plot updates stay on the calling thread; fitting and the command
        prompt each run in their own thread.
        """
        stop_event = threading.Event()
        session = PromptSession()
        failures = []
        if self.plot_data:
            plt.ion()
            plt.pause(1)  # Update plot every 1s

        def stream_loop():
            try:
                while not stop_event.is_set():
                    self.run_one_round()
                    stop_event.wait(1)  # Check for new data every 1 second
            except Exception as error:
                # Hand the failure to the calling thread and stop the stream
                # instead of leaving the plot loop running with no fitter.
                failures.append(error)
                stop_event.set()

        def input_loop():
            with patch_stdout():
                print("=== COMMANDS ===")
                print("Type STOP to exit")
                print("================")
                while not stop_event.is_set():
                    try:
                        cmd = session.prompt("> ")
                    except (EOFError, KeyboardInterrupt):
                        # Ctrl-D / Ctrl-C at the prompt means "stop".
                        stop_event.set()
                        break
                    if cmd.strip() == "STOP":
                        stop_event.set()
                        print("Stopping the streaming sequential toolset...")
                    else:
                        print(
                            "Unrecognized input. "
                            "Please type 'STOP' to end."
                        )

        # Daemon: after a fit failure this thread may still be blocked in
        # the prompt and must not keep the process alive.
        input_thread = threading.Thread(target=input_loop, daemon=True)
        input_thread.start()
        fit_thread = threading.Thread(target=stream_loop)
        fit_thread.start()
        try:
            while not stop_event.is_set():
                self._refresh_plots()
                plt.pause(1)  # Update plot every 1s
        finally:
            stop_event.set()
            fit_thread.join()
            input_thread.join(timeout=1)
        if failures:
            raise failures[0]

    def run(self, mode: Literal["batch", "stream"]):
        """Run the sequential fit.

        Parameters
        ----------
        mode
            ``"batch"`` fits the files currently present once and returns.
            ``"stream"`` keeps polling for new files and updating the plots
            until the user types ``STOP``.

        Raises
        ------
        ValueError
            If ``mode`` is neither ``"batch"`` nor ``"stream"``.
        """
        if mode == "batch":
            self.run_one_round()
        elif mode == "stream":
            self._run_stream()
        else:
            raise ValueError(f"Unknown mode: {mode}")