Add support for passing & returning memref of bool types
Support for passing and returning memrefs of bool element type
is added to the RefBackend.

Signed-off-by: Prashant Kumar <[email protected]>
Prashant Kumar committed Dec 8, 2021
1 parent 9958cf0 commit c598e01
Showing 4 changed files with 71 additions and 4 deletions.
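For context before the per-file diffs (a sketch for illustration, not part of the commit): a torch.bool tensor crosses the RefBackend ABI boundary as a memref of i1 elements and comes back to Python as a NumPy array of dtype np.bool_. The Python-side dtype round trip behind that contract:

    import numpy as np
    import torch

    # torch.bool corresponds to NumPy's np.bool_, the dtype produced by the
    # new consume_return_mri1 callback added further down.
    t = torch.tensor([True, False, True])
    a = t.numpy()
    assert a.dtype == np.bool_
    # Converting back preserves dtype and values.
    assert torch.from_numpy(a).dtype == torch.bool
    assert torch.from_numpy(a).equal(t)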
54 changes: 54 additions & 0 deletions e2e_testing/torchscript/basic.py
@@ -1003,3 +1003,57 @@ def forward(self):
 @register_test_case(module_factory=lambda: ZerosModuleFalsePinMemory())
 def ZerosModuleFalsePinMemory_basic(module, tu: TestUtils):
     module.forward()
+
+
+class BoolTensorReturnFalseModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnFalseModule())
+def BoolTensorReturnFalseModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([0, 0], dtype=torch.bool))
+
+
+class BoolTensorReturnTrueModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnTrueModule())
+def BoolTensorReturnTrueModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([1, 1, 1, 1, 1], dtype=torch.bool))
+
+
+class BoolTensorReturnMixedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.bool, True),
+    ])
+    def forward(self, a):
+        return a
+
+
+@register_test_case(module_factory=lambda: BoolTensorReturnMixedModule())
+def BoolTensorReturnMixedModule_basic(module, tu: TestUtils):
+    module.forward(torch.tensor([[1, 0], [0, 1]], dtype=torch.bool))
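These tests use the harness's factory/invoker pattern. A rough sketch of what @register_test_case does (illustrative only; the actual registry lives in the torch-mlir e2e test framework, and the names here are assumptions):

    # Illustrative sketch: the decorator records a (module_factory, invoker)
    # pair in a registry that the test runner later iterates over.
    GLOBAL_TEST_REGISTRY = []

    def register_test_case(module_factory):
        def decorator(invoker):
            GLOBAL_TEST_REGISTRY.append((module_factory, invoker))
            return invoker
        return decorator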
3 changes: 3 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
@@ -40,4 +40,7 @@
     "AddCMulModule_basic",
     "AddCDivModule_basic",
     "SqueezeModule_broadcast",
+    "BoolTensorReturnFalseModule_basic",
+    "BoolTensorReturnTrueModule_basic",
+    "BoolTensorReturnMixedModule_basic",
 }
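(The enclosing set's name sits outside the shown hunk; presumably it lists tests expected to fail on a backend or configuration that does not yet handle bool tensors, which is why the new tests are registered there.)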
9 changes: 6 additions & 3 deletions lib/RefBackend/RefBackend.cpp
@@ -56,6 +56,8 @@ static bool isArgMemRefTypeValid(Type type) {
         return true;
       if (integerTy.isSignlessInteger(32))
         return true;
+      if (integerTy.isSignlessInteger(1))
+        return true;
     }
   }
   return false;
@@ -128,7 +130,7 @@ static LogicalResult mungeFunction(
     auto type = arg.getType();
     if (!isArgMemRefTypeValid(type))
       return emitError(arg.getLoc(),
-                       "argument must be a memref of f32, f64, i32, i64");
+                       "argument must be a memref of f32, f64, i32, i64, i1");
     auto cast = b.create<memref::CastOp>(arg.getLoc(), arg, type);
     arg.replaceAllUsesExcept(cast, cast);
     arg.setType(getAbiTypeForMemRef(type));
@@ -163,7 +165,7 @@ static LogicalResult mungeFunction(
   if (supportedConsumeFuncReturnFuncs.find(funcName) == supportedFuncsEnd) {
     op.emitError(
         "must have one return value of memref types or scalar types "
-        "of i32, i64, f32, f64 or three return values of memref f32");
+        "of i32, i64, f32, f64, i1, or three return values of memref f32");
     isSupported = false;
   }

@@ -182,6 +184,7 @@

 static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   std::set<std::string> funcNames;
+  Type mri1 = UnrankedMemRefType::get(b.getI1Type(), 0);
   Type mri32 = UnrankedMemRefType::get(b.getI32Type(), 0);
   Type mri64 = UnrankedMemRefType::get(b.getI64Type(), 0);
   Type mrf32 = UnrankedMemRefType::get(b.getF32Type(), 0);
@@ -191,7 +194,7 @@ static std::set<std::string> getSupportedConsumeFuncReturnFuncs(OpBuilder &b) {
   Type f64 = b.getF64Type();

   SmallVector<TypeRange> supportedReturnTypes = {
-      mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};
+      mri1, mri32, mri64, mrf32, mrf64, i64, f32, f64, {mrf32, mrf32, mrf32}};

   llvm::for_each(supportedReturnTypes, [&](TypeRange &types) {
     funcNames.insert(getConsumeReturnFunctionNameForReturnTypes(types));
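The runtime symbol for each supported signature comes from a simple name mangling. Judging from the names registered in the Python runtime below (refbackend_consume_func_return_mri1, ..._mrf32_mrf32_mrf32), getConsumeReturnFunctionNameForReturnTypes is presumably equivalent to this sketch (the C++ helper is authoritative):

    # Presumed mangling: one suffix per return type, joined by underscores.
    def consume_func_name(type_suffixes):
        return "refbackend_consume_func_return_" + "_".join(type_suffixes)

    assert consume_func_name(["mri1"]) == "refbackend_consume_func_return_mri1"
    assert (consume_func_name(["mrf32", "mrf32", "mrf32"])
            == "refbackend_consume_func_return_mrf32_mrf32_mrf32")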
@@ -24,7 +24,7 @@


 def checkArgTypeIsSupported(ty):
-    SUPPORTED = [np.float32, np.float64, np.int32, np.int64]
+    SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
     assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"
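A quick illustration of the updated check (np.float16 is just an arbitrary dtype that remains unsupported):

    import numpy as np

    # Copy of the function as changed above, for a self-contained demo.
    def checkArgTypeIsSupported(ty):
        SUPPORTED = [np.float32, np.float64, np.int32, np.int64, np.bool_]
        assert ty in SUPPORTED, f"Only numpy arrays with dtypes in {SUPPORTED} are supported"

    checkArgTypeIsSupported(np.bool_)        # accepted after this change
    try:
        checkArgTypeIsSupported(np.float16)  # still rejected
    except AssertionError as err:
        print(err)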


@@ -33,6 +33,10 @@ def __init__(self, module):
         self.ee = ExecutionEngine(module)
         self.result = None

+        @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
+        def consume_return_mri1(a):
+            self.result = unranked_memref_to_numpy(a, np.bool_)
+
         @ctypes.CFUNCTYPE(None, ctypes.POINTER(UnrankedMemRefDescriptor))
         def consume_return_mri32(a):
             self.result = unranked_memref_to_numpy(a, np.int32)
@@ -70,6 +74,9 @@ def consume_return_mrf32_mrf32_mrf32(arg0, arg1, arg2):
                 arg1,
                 np.float32), unranked_memref_to_numpy(arg2, np.float32)

+        self.ee.register_runtime("refbackend_consume_func_return_mri1",
+                                 consume_return_mri1)
+
         self.ee.register_runtime("refbackend_consume_func_return_mri32",
                                  consume_return_mri32)
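The mri1 callback follows the same pattern as the existing ones: a Python closure is wrapped as a C function pointer, the JIT-compiled code calls it with an unranked memref descriptor, and unranked_memref_to_numpy converts that descriptor into a NumPy array (np.bool_ in the new case). A self-contained sketch of just the ctypes mechanism, with a plain int standing in for the descriptor:

    import ctypes

    result = None

    # A Python closure exposed as a C function pointer, the shape that
    # register_runtime expects; the real callbacks take
    # POINTER(UnrankedMemRefDescriptor) instead of c_int.
    @ctypes.CFUNCTYPE(None, ctypes.c_int)
    def consume(value):
        global result
        result = bool(value)

    consume(1)  # CFUNCTYPE wrappers are also callable directly from Python.
    assert result is True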

