2
2
3
3
#include " op.h"
4
4
5
- #include < cassert>
6
5
#include < memory>
7
6
#include < string>
8
7
#include < vector>
@@ -26,10 +25,10 @@ void check_all_parameters(
26
25
void get_operator_from_registry_and_execute () {
27
26
auto & ops = torch::jit::getAllOperatorsFor (
28
27
torch::jit::Symbol::fromQualString (" custom::op" ));
29
- assert (ops.size () == 1 );
28
+ AT_ASSERT (ops.size () == 1 );
30
29
31
30
auto & op = ops.front ();
32
- assert (op->schema ().name == " custom::op" );
31
+ AT_ASSERT (op->schema ().name == " custom::op" );
33
32
34
33
torch::jit::Stack stack;
35
34
torch::jit::push (stack, torch::ones (5 ), 2.0 , 3 );
@@ -39,57 +38,57 @@ void get_operator_from_registry_and_execute() {
39
38
40
39
const auto manual = custom_op (torch::ones (5 ), 2.0 , 3 );
41
40
42
- assert (output.size () == 3 );
41
+ AT_ASSERT (output.size () == 3 );
43
42
for (size_t i = 0 ; i < output.size (); ++i) {
44
- assert (output[i].allclose (torch::ones (5 ) * 2 ));
45
- assert (output[i].allclose (manual[i]));
43
+ AT_ASSERT (output[i].allclose (torch::ones (5 ) * 2 ));
44
+ AT_ASSERT (output[i].allclose (manual[i]));
46
45
}
47
46
}
48
47
49
48
void load_serialized_module_with_custom_op_and_execute (
50
49
const std::string& path_to_exported_script_module) {
51
50
std::shared_ptr<torch::jit::script::Module> module =
52
51
torch::jit::load (path_to_exported_script_module);
53
- assert (module != nullptr );
52
+ AT_ASSERT (module != nullptr );
54
53
55
54
std::vector<torch::jit::IValue> inputs;
56
55
inputs.push_back (torch::ones (5 ));
57
56
auto output = module->forward (inputs).toTensor ();
58
57
59
- assert (output.allclose (torch::ones (5 ) + 1 ));
58
+ AT_ASSERT (output.allclose (torch::ones (5 ) + 1 ));
60
59
}
61
60
62
61
void test_argument_checking_for_serialized_modules (
63
62
const std::string& path_to_exported_script_module) {
64
63
std::shared_ptr<torch::jit::script::Module> module =
65
64
torch::jit::load (path_to_exported_script_module);
66
- assert (module != nullptr );
65
+ AT_ASSERT (module != nullptr );
67
66
68
67
try {
69
68
module->forward ({torch::jit::IValue (1 ), torch::jit::IValue (2 )});
70
- assert (false );
69
+ AT_ASSERT (false );
71
70
} catch (const c10::Error& error) {
72
- assert (
71
+ AT_ASSERT (
73
72
std::string (error.what_without_backtrace ())
74
73
.find (" Expected at most 1 argument(s) for operator 'forward', "
75
74
" but received 2 argument(s)" ) == 0 );
76
75
}
77
76
78
77
try {
79
78
module->forward ({torch::jit::IValue (5 )});
80
- assert (false );
79
+ AT_ASSERT (false );
81
80
} catch (const c10::Error& error) {
82
- assert (
81
+ AT_ASSERT (
83
82
std::string (error.what_without_backtrace ())
84
83
.find (" Expected value of type Dynamic for argument 'input' in "
85
84
" position 0, but instead got value of type int" ) == 0 );
86
85
}
87
86
88
87
try {
89
88
module->forward ({});
90
- assert (false );
89
+ AT_ASSERT (false );
91
90
} catch (const c10::Error& error) {
92
- assert (
91
+ AT_ASSERT (
93
92
std::string (error.what_without_backtrace ())
94
93
.find (" forward() is missing value for argument 'input'" ) == 0 );
95
94
}
0 commit comments