@@ -817,72 +817,33 @@ class sliceScatter(torch.nn.Module):
817
817
def __init__ (self , * args , ** kwargs ) -> None :
818
818
super ().__init__ (* args , ** kwargs )
819
819
820
- def forward (self , x , src , dim , start = None , end = None , step = 1 ):
821
- y = torch .ops .aten .slice_scatter (x , src , dim , start , end , step )
820
+ def forward (self , x , src ):
821
+ y = torch .ops .aten .slice_scatter (x , src , 1 , 6 , None , 1 )
822
822
return y
823
823
824
- # Operations expected to be removed in the traced graph after decompositions
825
- expected_ops = {
826
- torch .ops .aten .scatter .src ,
827
- }
828
- unexpected_ops = {torch .ops .aten .select_scatter }
829
-
830
- a = torch .zeros (8 , 8 ).cuda ()
831
- b = torch .ones (8 , 2 ).cuda ()
832
-
833
- # 0-D tensors for dynamic scalar values
834
- start = torch .tensor (1 , dtype = torch .int64 ).cuda ()
835
- end = torch .tensor (6 , dtype = torch .int64 ).cuda ()
836
- step = torch .tensor (1 , dtype = torch .int64 ).cuda ()
837
-
838
- # Mark scalar tensors as dynamic (note: shape = ())
839
- torch ._dynamo .mark_dynamic (start , (), min = 1 , max = 3 )
840
- torch ._dynamo .mark_dynamic (end , (), min = 4 , max = 6 )
841
-
842
- inputs = (a , b , start , end , None , step )
843
824
fx_graph = torch .fx .symbolic_trace (sliceScatter ())
844
- unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
845
- fx_graph ,
846
- inputs ,
847
- expected_ops = expected_ops ,
848
- unexpected_ops = unexpected_ops ,
849
- min_block_size = 1 ,
850
- )
851
825
852
- self .assertEqual (
853
- len (unexpected_ops_seen ),
854
- 0 ,
855
- f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
856
- )
857
-
858
- self .assertEqual (
859
- len (expected_ops_unseen ),
860
- 0 ,
861
- f"The following expected ops were not encountered: { expected_ops_unseen } " ,
826
+ dim1 = torch .export .Dim ("dim1" , min = 8 , max = 10 )
827
+ dynamic_shapes = {
828
+ "x" : [torch .export .Dim .STATIC , dim1 ],
829
+ "src" : [torch .export .Dim .STATIC , None ],
830
+ }
831
+ inputs = (torch .zeros (8 , 8 ).cuda (), torch .ones (8 , 2 ).cuda ())
832
+ exported_program = torch .export .export (
833
+ sliceScatter (), tuple (inputs ), dynamic_shapes = dynamic_shapes
862
834
)
863
-
835
+ fx_graph = exported_program .module ()
836
+ inputs = [
837
+ torch_tensorrt .Input (
838
+ min_shape = [8 , 8 ], opt_shape = [8 , 10 ], max_shape = [8 , 10 ]
839
+ ),
840
+ torch_tensorrt .Input (min_shape = [8 , 2 ], opt_shape = [8 , 2 ], max_shape = [8 , 2 ]),
841
+ ]
864
842
torch ._dynamo .reset ()
865
-
866
- # Validate that the results between Torch and Torch-TRT are similar
867
- optimized_model = torch_tensorrt .compile (
868
- fx_graph ,
869
- "torch_compile" ,
870
- inputs ,
871
- min_block_size = 1 ,
872
- truncate_double = True ,
873
- pass_through_build_failures = True ,
874
- )
875
- optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
876
- torch_model_results = fx_graph (* inputs ).detach ().cpu ()
877
-
878
- max_diff = float (
879
- torch .max (torch .abs (optimized_model_results - torch_model_results ))
880
- )
881
- self .assertAlmostEqual (
882
- max_diff ,
883
- 0 ,
884
- DECIMALS_OF_AGREEMENT ,
885
- f"Slice_scatter TRT outputs don't match with the original model." ,
843
+ trt_model = torch_tensorrt .dynamo .compile (exported_program , inputs )
844
+ inputs = (torch .zeros (8 , 8 ).cuda (), torch .ones (8 , 2 ).cuda ())
845
+ torch .testing .assert_close (
846
+ trt_model (* inputs ), fx_graph (* inputs ), rtol = RTOL , atol = ATOL
886
847
)
887
848
888
849
def test_lowering_select_scatter_dimZero_module (self ):
0 commit comments