Skip to content

Commit

Permalink
Adjust bound check to be the same as PyTorch native (i.e. stricter) (#…
Browse files Browse the repository at this point in the history
…2755)

prims.expand expects the start and end dimensions to be strictly less
than the rank of the tensor.
  • Loading branch information
newling authored Jan 15, 2024
1 parent 87389f0 commit f78ec78
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6514,15 +6514,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int0 = torch.constant.int 0\n"
" %int1 = torch.constant.int 1\n"
" %0 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
" %1 = torch.aten.lt.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %1 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %2 = torch.aten.len.t %arg0 : !torch.list<int> -> !torch.int\n"
" %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n"
" %3 = torch.aten.lt.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,21 @@ def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]:

def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]:
# Obtained through trial and error on a few examples in PyTorch:
assert start <= len(a), "start out of bounds"
assert end <= len(a), "end out of bounds"
assert start < len(a), "start out of bounds"
assert end < len(a), "end out of bounds"
assert start >= 0, "start out of bounds"
assert end >= 0, "end out of bounds"
assert start <= end, "start must be less than or equal to end"

# Example:
# Examples:
#
# torch._prims.collapse(torch.empty(2,3,4), 1,2).shape
# is
# torch.Size([2, 12])
#
# torch._prims.collapse(torch.empty(2,3,4), 1,3).shape
# gives
# --> 524 assert idx >= 0 and idx < rank or idx == 0

collapsed: List[int] = []
for i in range(start):
Expand Down

0 comments on commit f78ec78

Please sign in to comment.