Description
I just checked how much work it is to recreate the state of the workgraph from the failed state. We have to build the workgraph in a similar way as in the current one. The available data is given by the failed calcjob, from which we can retrieve the inputs. The steps that depend on the failed calcjob will be marked as skipped and have to be copied over from the failed workgraph. One annoyance is that the CalcJob and the workgraph interfaces for aiida-shell are different, so we need to translate one to the other. Copying tasks over from one workgraph to another will also be very cumbersome, because we need to redo the linking from the failed state. Here is the PoC code:
```python
import re

from aiida import load_profile
from aiida.orm import load_node
from aiida_workgraph import WorkGraph, TaskPool
from aiida_workgraph.utils import get_workgraph_data

load_profile()

wg = WorkGraph(name="test_shell_date")
task1 = wg.add_task(TaskPool.workgraph.shelljob, command="echo", arguments=["0 > out"], outputs=["out"], nodes={})
task2 = wg.add_task(TaskPool.workgraph.shelljob, command="cat", arguments=["{out}"], nodes={'out': task1.outputs.out})
wg.run()

### Restart
# the user provides the workgraph pk somehow
wg_node = load_node(wg.process.pk)
wg_data = get_workgraph_data(wg_node)
failed_wg = WorkGraph.from_dict(wg_data)
restart_wg = WorkGraph()
for name, task in wg_node.task_processes.items():
    # the process state is not stored in the calcjob but in the workgraph;
    # tasks that depend on a failed task have the state SKIPPED
    if wg_node.task_states[name] == 'FAILED':
        # `task` is a string representation that contains the calcjob UUID in single quotes
        match = re.search(r"'([a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12})'", task)
        uuid = match.group(1)  # capture group, i.e. the UUID without the quotes
        calcjob = load_node(uuid)
        # does not work, there is too much other stuff in the inputs that the user does not pass
        #collect_inputs = {arg: getattr(calcjob.inputs, arg) for arg in calcjob.inputs.__dir__() if not arg.startswith("_")}
        # because of the interface change in workgraph, we have to use the workgraph-specific identifier
        if str(calcjob.process_class) == "<class 'aiida_shell.calculations.shell.ShellJob'>":
            identifier = TaskPool.workgraph.shelljob
        else:
            identifier = calcjob.process_class
        # TODO double check whether the arguments are already resolved at this point
        restart_wg.add_task(identifier, name=name, command=calcjob.inputs.code, arguments=calcjob.inputs.arguments.value, outputs=calcjob.inputs.outputs.value)
    elif wg_node.task_states[name] == 'SKIPPED':
        skipped_task = failed_wg.tasks[name]
        # not sure if we can just pass the inputs as kwargs (see collect_inputs above); otherwise we need
        # to dispatch here depending on calcjob.process_class, but even that is not horrible for 2 process types
        # this is a bug in workgraph: the identifier has not been changed correctly everywhere
        if skipped_task.identifier == "ShellJob":
            identifier = TaskPool.workgraph.shelljob
        else:
            identifier = skipped_task.identifier
        # This is the actual bummer: we cannot just reuse the nodes from the old wg because they are
        # bound to it, so we have to redo the linking against the tasks of restart_wg here
        nodes = {}
        for link in failed_wg.links:
            ld = link.to_dict()
            if ld['to_node'] == name:
                socket = restart_wg.tasks[ld['from_node']].outputs[ld['from_socket']]
                nodes[socket._name] = socket
        restart_wg.add_task(identifier, name=name, command=skipped_task.inputs.command.value, arguments=skipped_task.inputs.arguments.value, outputs=skipped_task.inputs.outputs.value, nodes=nodes)
        # alternative: add the links explicitly instead of passing `nodes`
        #for link in failed_wg.links:
        #    ld = link.to_dict()
        #    if ld['to_node'] == name:
        #        restart_wg.add_link(restart_wg.tasks[ld['from_node']].outputs[ld['from_socket']],
        #                            restart_wg.tasks[ld['to_node']].inputs[ld['to_socket']])
        # failed attempt, the sockets are still bound to the old wg:
        #nodes={name: socket for name, socket in skipped_task.inputs.nodes._sockets.items()}
restart_wg.run()
```

Note that the user can only fix bugs in objects that are RemoteData, since everything else will be retrieved from the database. We could only bypass this by enforcing RemoteData for the namelists (which could cause other issues, as we internally want to adapt them between tasks?).
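The fiddly part of the PoC is deciding which tasks have to run again and rebuilding their incoming links in an order where the upstream task already exists in `restart_wg` (the PoC relies on the iteration order of `task_processes` for this). That bookkeeping can be prototyped without AiiDA. The following is a minimal sketch, not the workgraph API: `Link` and `plan_restart` are hypothetical names, `Link` only mirrors the keys of `link.to_dict()` used in the PoC, and it assumes that any state other than `'FAILED'`/`'SKIPPED'` (e.g. `'FINISHED'`) means the task is done. It uses `graphlib.TopologicalSorter` from the standard library (Python 3.9+) so that a task is only created after the re-run tasks it links from.

```python
from dataclasses import dataclass
from graphlib import TopologicalSorter

# Assumption: these are the states whose tasks must be re-created in the restart workgraph.
RERUN_STATES = ("FAILED", "SKIPPED")


@dataclass(frozen=True)
class Link:
    """Hypothetical stand-in for the dict returned by `link.to_dict()` in the PoC."""
    from_node: str
    from_socket: str
    to_node: str
    to_socket: str


def plan_restart(states: dict[str, str], links: list[Link]) -> tuple[list[str], dict[str, list[Link]]]:
    """Return the tasks to re-create (upstream first) and the incoming links of each.

    Links coming from tasks that are not re-run are still reported, because those
    sockets cannot be relinked inside the new workgraph and need special handling.
    """
    rerun = {name for name, state in states.items() if state in RERUN_STATES}
    incoming: dict[str, list[Link]] = {name: [] for name in rerun}
    # every re-run task is a node of the dependency graph, initially without predecessors
    sorter = TopologicalSorter({name: set() for name in rerun})
    for link in links:
        if link.to_node not in rerun:
            continue
        incoming[link.to_node].append(link)
        if link.from_node in rerun:
            # the upstream task has to exist in the new workgraph before this one is added
            sorter.add(link.to_node, link.from_node)
    return list(sorter.static_order()), incoming


# Usage with made-up task names: task0 finished, task1 failed, task2 was skipped.
states = {"task0": "FINISHED", "task1": "FAILED", "task2": "SKIPPED"}
links = [
    Link("task0", "out", "task1", "out"),
    Link("task1", "out", "task2", "out"),
]
order, incoming = plan_restart(states, links)
print(order)  # ['task1', 'task2']
```

Keeping the link from the finished `task0` in `incoming` is a deliberate choice: the PoC does not re-create finished tasks in `restart_wg`, so such a link cannot be rebuilt with `restart_wg.tasks[...]`, and its already stored output would presumably have to be passed in some other way.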