30
30
from tfx .orchestration import metadata
31
31
from tfx .orchestration import node_proto_view
32
32
from tfx .orchestration .experimental .core import constants
33
+ from tfx .orchestration .experimental .core import env
33
34
from tfx .orchestration .experimental .core import mlmd_state
34
35
from tfx .orchestration .experimental .core import task as task_lib
35
36
from tfx .orchestration import mlmd_connection_manager as mlmd_cm
@@ -548,21 +549,41 @@ def register_executions_from_existing_executions(
548
549
contexts = metadata_handle .store .get_contexts_by_execution (
549
550
existing_executions [0 ].id
550
551
)
551
- return execution_lib .put_executions (
552
+ executions = execution_lib .put_executions (
552
553
metadata_handle ,
553
554
new_executions ,
554
555
contexts ,
555
556
input_artifacts_maps = input_artifacts ,
556
557
)
557
558
559
+ pipeline_asset = metadata_handle .store .pipeline_asset
560
+ if pipeline_asset :
561
+ env .get_env ().create_pipeline_run_node_executions (
562
+ pipeline_asset .owner ,
563
+ pipeline_asset .name ,
564
+ pipeline ,
565
+ node .node_info .id ,
566
+ executions ,
567
+ )
568
+ else :
569
+ logging .warning (
570
+ 'Pipeline asset %s not found in MLMD. Unable to create pipeline run'
571
+ ' node executions.' ,
572
+ pipeline_asset ,
573
+ )
574
+ return executions
575
+
558
576
577
+ # TODO(b/349654866): make pipeline and node_id non-optional.
559
578
def register_executions (
560
579
metadata_handle : metadata .Metadata ,
561
580
execution_type : metadata_store_pb2 .ExecutionType ,
562
581
contexts : Sequence [metadata_store_pb2 .Context ],
563
582
input_and_params : Sequence [InputAndParam ],
583
+ pipeline : Optional [pipeline_pb2 .Pipeline ] = None ,
584
+ node_id : Optional [str ] = None ,
564
585
) -> Sequence [metadata_store_pb2 .Execution ]:
565
- """Registers multiple executions in MLMD .
586
+ """Registers multiple executions in storage backends .
566
587
567
588
Along with the execution:
568
589
- the input artifacts will be linked to the executions.
@@ -575,6 +596,8 @@ def register_executions(
575
596
input_and_params: A list of InputAndParams, which includes input_dicts
576
597
(dictionaries of artifacts. One execution will be registered for each of
577
598
the input_dict) and corresponding exec_properties.
599
+ pipeline: Optional. The pipeline proto.
600
+ node_id: Optional. The node id of the executions to be registered.
578
601
579
602
Returns:
580
603
A list of MLMD executions that are registered in MLMD, with id populated.
@@ -603,21 +626,41 @@ def register_executions(
603
626
executions .append (execution )
604
627
605
628
if len (executions ) == 1 :
606
- return [
629
+ new_executions = [
607
630
execution_lib .put_execution (
608
631
metadata_handle ,
609
632
executions [0 ],
610
633
contexts ,
611
634
input_artifacts = input_and_params [0 ].input_artifacts ,
612
635
)
613
636
]
637
+ else :
638
+ new_executions = execution_lib .put_executions (
639
+ metadata_handle ,
640
+ executions ,
641
+ contexts ,
642
+ [
643
+ input_and_param .input_artifacts
644
+ for input_and_param in input_and_params
645
+ ],
646
+ )
614
647
615
- return execution_lib .put_executions (
616
- metadata_handle ,
617
- executions ,
618
- contexts ,
619
- [input_and_param .input_artifacts for input_and_param in input_and_params ],
620
- )
648
+ pipeline_asset = metadata_handle .store .pipeline_asset
649
+ if pipeline_asset and pipeline and node_id :
650
+ env .get_env ().create_pipeline_run_node_executions (
651
+ pipeline_asset .owner ,
652
+ pipeline_asset .name ,
653
+ pipeline ,
654
+ node_id ,
655
+ new_executions ,
656
+ )
657
+ else :
658
+ logging .warning (
659
+ 'Skipping creating pipeline run node executions for pipeline asset %s.' ,
660
+ pipeline_asset ,
661
+ )
662
+
663
+ return new_executions
621
664
622
665
623
666
def update_external_artifact_type (
0 commit comments