@@ -489,3 +489,42 @@ func.func @shape_calc_with_two_uses(%arg0: !torch.vtensor<[2],f32>) -> !torch.vt
   return %arg0 : !torch.vtensor<[2],f32>
 }
+
+// CHECK-LABEL: func.func @unflat_shape_partial_dyn
+// CHECK-DAG: %[[INT768:.*]] = torch.constant.int 768
+// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0
+// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1
+// CHECK-DAG: %[[INT4:.*]] = torch.constant.int 4
+// CHECK: } shapes {
+// CHECK: %[[SZE0:.*]] = torch.aten.size.int %arg0, %[[INT0]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK: %[[SZE1:.*]] = torch.aten.size.int %arg0, %[[INT1]] : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+// CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[SZE0]], %[[SZE1]], %[[INT4]], %[[INT768]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: torch.shape.calculate.yield.shapes %[[LIST]] : !torch.list<int>
+// CHECK: } : !torch.vtensor<[?,?,4,768],f32>
+func.func @unflat_shape_partial_dyn(%arg0: !torch.vtensor<[?,?,3072],f32>) -> !torch.vtensor<[?,?,4,?],f32> {
+  %int768 = torch.constant.int 768
+  %int3072 = torch.constant.int 3072
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %int1 = torch.constant.int 1
+  %none = torch.constant.none
+  %int-1 = torch.constant.int -1
+  %int2 = torch.constant.int 2
+  %int4 = torch.constant.int 4
+  %0 = torch.prim.ListConstruct %int4, %int-1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.shape.calculate {
+    %2 = torch.aten.unflatten.int %arg0, %int2, %0 : !torch.vtensor<[?,?,3072],f32>, !torch.int, !torch.list<int> -> !torch.vtensor<[?,?,4,?],f32>
+    torch.shape.calculate.yield %2 : !torch.vtensor<[?,?,4,?],f32>
+  } shapes {
+    %2 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %3 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
+    %4 = torch.prim.ListConstruct %2, %3, %int3072 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
+    %5 = torch.prim.ListConstruct %int4, %int768 : (!torch.int, !torch.int) -> !torch.list<int>
+    %6 = torch.aten.slice.t %4, %none, %int2, %int1 : !torch.list<int>, !torch.none, !torch.int, !torch.int -> !torch.list<int>
+    %7 = torch.aten.add.t %6, %5 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    %8 = torch.aten.slice.t %4, %int3, %none, %int1 : !torch.list<int>, !torch.int, !torch.none, !torch.int -> !torch.list<int>
+    %9 = torch.aten.add.t %7, %8 : !torch.list<int>, !torch.list<int> -> !torch.list<int>
+    torch.shape.calculate.yield.shapes %9 : !torch.list<int>
+  } : !torch.vtensor<[?,?,4,?],f32>
+  return %1 : !torch.vtensor<[?,?,4,?],f32>
+}
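The input's shapes region still computes the output shape the long way (two size queries, a partial sizes list, then slice.t/add.t concatenation). The result the CHECK lines expect follows from the unflatten arithmetic: dim 2 has static size 3072 and is split with sizes [4, -1], so the -1 resolves to 3072 / 4 = 768, leaving only the two leading dynamic sizes to be queried. A minimal sketch of that simplified shapes region, reconstructed from the CHECK directives above (the %sze0/%sze1/%list names are illustrative stand-ins for the FileCheck captures, not literal IR):

  } shapes {
    %sze0 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
    %sze1 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?,3072],f32>, !torch.int -> !torch.int
    %list = torch.prim.ListConstruct %sze0, %sze1, %int4, %int768 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    torch.shape.calculate.yield.shapes %list : !torch.list<int>
  } : !torch.vtensor<[?,?,4,768],f32>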