@@ -812,6 +812,79 @@ def forward(self, x, src, dim, start, end, step):
812
812
f"Slice_scatter TRT outputs don't match with the original model." ,
813
813
)
814
814
815
+ def test_lowering_slice_scatter_dynamic_module (self ):
816
+ class sliceScatter (torch .nn .Module ):
817
+ def __init__ (self , * args , ** kwargs ) -> None :
818
+ super ().__init__ (* args , ** kwargs )
819
+
820
+ def forward (self , x , src , dim , start = None , end = None , step = 1 ):
821
+ y = torch .ops .aten .slice_scatter (x , src , dim , start , end , step )
822
+ return y
823
+
824
+ # Operations expected to be removed in the traced graph after decompositions
825
+ expected_ops = {
826
+ torch .ops .aten .scatter .src ,
827
+ }
828
+ unexpected_ops = {torch .ops .aten .select_scatter }
829
+
830
+ a = torch .zeros (8 , 8 ).cuda ()
831
+ b = torch .ones (8 , 2 ).cuda ()
832
+
833
+ # 0-D tensors for dynamic scalar values
834
+ start = torch .tensor (1 , dtype = torch .int64 ).cuda ()
835
+ end = torch .tensor (6 , dtype = torch .int64 ).cuda ()
836
+ step = torch .tensor (1 , dtype = torch .int64 ).cuda ()
837
+
838
+ # Mark scalar tensors as dynamic (note: shape = ())
839
+ torch ._dynamo .mark_dynamic (start , (), min = 1 , max = 3 )
840
+ torch ._dynamo .mark_dynamic (end , (), min = 4 , max = 6 )
841
+
842
+ inputs = (a , b , start , end , None , step )
843
+ fx_graph = torch .fx .symbolic_trace (sliceScatter ())
844
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
845
+ fx_graph ,
846
+ inputs ,
847
+ expected_ops = expected_ops ,
848
+ unexpected_ops = unexpected_ops ,
849
+ min_block_size = 1 ,
850
+ )
851
+
852
+ self .assertEqual (
853
+ len (unexpected_ops_seen ),
854
+ 0 ,
855
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
856
+ )
857
+
858
+ self .assertEqual (
859
+ len (expected_ops_unseen ),
860
+ 0 ,
861
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
862
+ )
863
+
864
+ torch ._dynamo .reset ()
865
+
866
+ # Validate that the results between Torch and Torch-TRT are similar
867
+ optimized_model = torch_tensorrt .compile (
868
+ fx_graph ,
869
+ "torch_compile" ,
870
+ inputs ,
871
+ min_block_size = 1 ,
872
+ truncate_double = True ,
873
+ pass_through_build_failures = True ,
874
+ )
875
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
876
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
877
+
878
+ max_diff = float (
879
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
880
+ )
881
+ self .assertAlmostEqual (
882
+ max_diff ,
883
+ 0 ,
884
+ DECIMALS_OF_AGREEMENT ,
885
+ f"Slice_scatter TRT outputs don't match with the original model." ,
886
+ )
887
+
815
888
def test_lowering_select_scatter_dimZero_module (self ):
816
889
class selectScatter (torch .nn .Module ):
817
890
def __init__ (self , * args , ** kwargs ) -> None :
0 commit comments