Commit 9ea19cb

Windows CI integration for custom ops (pytorch#12928)

goldsborough authored and facebook-github-bot committed

Summary: Resubmission of pytorch#11527 ezyang orionr

Pull Request resolved: pytorch#12928
Differential Revision: D10501342
Pulled By: goldsborough
fbshipit-source-id: 7ce74795aab2f13efeb38f56ce82f53055f5eade

1 parent af78d4c commit 9ea19cb

File tree: 11 files changed, +105 -59 lines

.jenkins/pytorch/win-test.sh (+25 -2)

@@ -87,14 +87,37 @@ call ci_scripts/setup_pytorch_env.bat
 cd test/ && python run_test.py --exclude nn --verbose && cd ..
 EOL
 
+cat >ci_scripts/test_custom_script_ops.bat <<EOL
+call ci_scripts/setup_pytorch_env.bat
+
+cd test/custom_operator
+
+:: Build the custom operator library.
+mkdir build
+cd build
+:: Note: Caffe2 does not support MSVC + CUDA + Debug mode (has to be Release mode)
+cmake -DCMAKE_PREFIX_PATH=%CD%\\..\\..\\torch -DCMAKE_BUILD_TYPE=Release -GNinja ..
+ninja -v
+cd ..
+
+:: Run tests Python-side and export a script module.
+python test_custom_ops.py -v
+python model.py --export-script-module="build/model.pt"
+:: Run tests C++-side and load the exported script module.
+cd build
+set PATH=C:\\Program Files\\NVIDIA Corporation\\NvToolsExt/bin/x64;%CD%\\..\\..\\torch\\lib;%PATH%
+test_custom_ops.exe model.pt
+
+EOL
+
 run_tests() {
   if [ -z "${JOB_BASE_NAME}" ] || [[ "${JOB_BASE_NAME}" == *-test ]]; then
-    ci_scripts/test_python_nn.bat && ci_scripts/test_python_all_except_nn.bat
+    ci_scripts/test_python_nn.bat && ci_scripts/test_python_all_except_nn.bat && ci_scripts/test_custom_script_ops.bat
   else
     if [[ "${JOB_BASE_NAME}" == *-test1 ]]; then
       ci_scripts/test_python_nn.bat
     elif [[ "${JOB_BASE_NAME}" == *-test2 ]]; then
-      ci_scripts/test_python_all_except_nn.bat
+      ci_scripts/test_python_all_except_nn.bat && ci_scripts/test_custom_script_ops.bat
     fi
   fi
 }

cmake/TorchConfig.cmake.in (+11 -3)

@@ -14,6 +14,8 @@
 #
 # torch
 
+include(FindPackageHandleStandardArgs)
+
 if ($ENV{TORCH_INSTALL_PREFIX})
   set(TORCH_INSTALL_PREFIX $ENV{TORCH_INSTALL_PREFIX})
 else()

@@ -37,7 +39,7 @@ endif()
 find_package(Caffe2 REQUIRED)
 
 find_library(TORCH_LIBRARY torch PATHS "${TORCH_INSTALL_PREFIX}/lib")
-add_library(torch SHARED IMPORTED)
+add_library(torch UNKNOWN IMPORTED)
 set(TORCH_LIBRARIES torch ${Caffe2_MAIN_LIBS})
 
 if (@USE_CUDA@)

@@ -67,11 +69,17 @@ if (@USE_CUDA@)
 endif()
 
 # When we build libtorch with the old GCC ABI, dependent libraries must too.
-set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")
+if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
+  set(TORCH_CXX_FLAGS "-D_GLIBCXX_USE_CXX11_ABI=@GLIBCXX_USE_CXX11_ABI@")
+endif()
 
 set_target_properties(torch PROPERTIES
     IMPORTED_LOCATION "${TORCH_LIBRARY}"
     INTERFACE_INCLUDE_DIRECTORIES "${TORCH_INCLUDE_DIRS}"
-    INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}"
     CXX_STANDARD 11
 )
+if (TORCH_CXX_FLAGS)
+  set_property(TARGET torch PROPERTY INTERFACE_COMPILE_OPTIONS "${TORCH_CXX_FLAGS}")
+endif()
+
+find_package_handle_standard_args(torch DEFAULT_MSG TORCH_LIBRARY TORCH_INCLUDE_DIRS)
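Two notes on this hunk. First, add_library(torch UNKNOWN IMPORTED) sidesteps the Windows rule that a SHARED imported target must set IMPORTED_IMPLIB (the .lib import library) in addition to IMPORTED_LOCATION (the .dll); with UNKNOWN, CMake simply links whatever find_library(TORCH_LIBRARY ...) found. Second, the GNU guard exists because -D_GLIBCXX_USE_CXX11_ABI only means anything to GCC's libstdc++, so exporting it unconditionally through TORCH_CXX_FLAGS pushed a meaningless define onto MSVC consumers. A minimal C++ sketch of what the macro controls, assuming a libstdc++ toolchain (it is inert elsewhere):

    #include <cstdio>
    #include <string>  // any libstdc++ header defines __GLIBCXX__ and the ABI macro

    int main() {
    #if defined(__GLIBCXX__)
      // 1 selects the C++11-conforming std::string/std::list; 0 selects the old
      // copy-on-write ABI. Mixing values across a shared-library boundary causes
      // unresolved symbols, which is why libtorch propagates its own setting.
      std::printf("libstdc++ dual ABI: %d\n", (int)_GLIBCXX_USE_CXX11_ABI);
    #else
      std::printf("not libstdc++: _GLIBCXX_USE_CXX11_ABI has no effect here\n");
    #endif
      return 0;
    }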

test/custom_operator/CMakeLists.txt (+3 -1)

@@ -5,7 +5,9 @@ project(custom_ops)
 find_package(Torch REQUIRED)
 
 add_library(custom_ops SHARED op.cpp)
-target_link_libraries(custom_ops ${TORCH_LIBRARIES})
+target_compile_features(custom_ops PUBLIC cxx_range_for)
+target_link_libraries(custom_ops "${TORCH_LIBRARIES}")
+target_compile_definitions(custom_ops PRIVATE custom_ops_EXPORTS)
 
 add_executable(test_custom_ops test_custom_ops.cpp)
 target_link_libraries(test_custom_ops custom_ops)

test/custom_operator/model.py (+8 -6)

@@ -4,13 +4,16 @@
 
 import torch
 
-
-SHARED_LIBRARY_EXTENSIONS = {'linux': 'so', 'darwin': 'dylib', 'win32': 'dll'}
+SHARED_LIBRARY_NAMES = {
+    'linux': 'libcustom_ops.so',
+    'darwin': 'libcustom_ops.dylib',
+    'win32': 'custom_ops.dll'
+}
 
 
 def get_custom_op_library_path():
-    extension = SHARED_LIBRARY_EXTENSIONS[sys.platform]
-    path = os.path.abspath('build/libcustom_ops.{}'.format(extension))
+    path = os.path.abspath('build/{}'.format(
+        SHARED_LIBRARY_NAMES[sys.platform]))
     assert os.path.exists(path), path
     return path
 

@@ -27,8 +30,7 @@ def forward(self, input):
 
 def main():
     parser = argparse.ArgumentParser(
-        description="Serialize a script module with custom ops"
-    )
+        description="Serialize a script module with custom ops")
     parser.add_argument("--export-script-module-to", required=True)
     options = parser.parse_args()
 
test/custom_operator/op.cpp (+2)

@@ -1,5 +1,7 @@
 #include <torch/script.h>
 
+#include "op.h"
+
 #include <cstddef>
 #include <vector>
 
test/custom_operator/op.h (+13 -1)

@@ -3,7 +3,19 @@
 #include <cstddef>
 #include <vector>
 
-TORCH_API std::vector<at::Tensor> custom_op(
+// clang-format off
+#  if defined(_WIN32)
+#    if defined(custom_ops_EXPORTS)
+#      define CUSTOM_OP_API __declspec(dllexport)
+#    else
+#      define CUSTOM_OP_API __declspec(dllimport)
+#    endif
+#  else
+#    define CUSTOM_OP_API
+#  endif
+// clang-format on
+
+CUSTOM_OP_API std::vector<at::Tensor> custom_op(
     at::Tensor tensor,
     double scalar,
     int64_t repeat);
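This replaces TORCH_API, which expands to __declspec(dllimport) in any translation unit that is not building libtorch itself, so the custom_ops DLL could never have exported custom_op through it. The custom_ops_EXPORTS define that flips the new macro is supplied by the target_compile_definitions line in the CMakeLists.txt hunk above (CMake also defines <target>_EXPORTS automatically for shared libraries). A generic sketch of the pattern, with hypothetical names (my_lib_EXPORTS, MY_LIB_API, my_function):

    // my_lib.h -- hypothetical header shared by the DLL and its clients.
    #if defined(_WIN32)
    #  if defined(my_lib_EXPORTS)
         // Compiling my_lib.dll itself: publish the symbol.
    #    define MY_LIB_API __declspec(dllexport)
    #  else
         // Compiling a client of my_lib.dll: import the symbol.
    #    define MY_LIB_API __declspec(dllimport)
    #  endif
    #else
       // ELF/Mach-O platforms export symbols by default; no annotation needed.
    #  define MY_LIB_API
    #endif

    MY_LIB_API int my_function(int x);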

test/custom_operator/test_custom_ops.cpp (+14 -15)

@@ -2,7 +2,6 @@
 
 #include "op.h"
 
-#include <cassert>
 #include <memory>
 #include <string>
 #include <vector>

@@ -26,10 +25,10 @@ void check_all_parameters(
 void get_operator_from_registry_and_execute() {
   auto& ops = torch::jit::getAllOperatorsFor(
       torch::jit::Symbol::fromQualString("custom::op"));
-  assert(ops.size() == 1);
+  AT_ASSERT(ops.size() == 1);
 
   auto& op = ops.front();
-  assert(op->schema().name == "custom::op");
+  AT_ASSERT(op->schema().name == "custom::op");
 
   torch::jit::Stack stack;
   torch::jit::push(stack, torch::ones(5), 2.0, 3);

@@ -39,57 +38,57 @@ void get_operator_from_registry_and_execute() {
 
   const auto manual = custom_op(torch::ones(5), 2.0, 3);
 
-  assert(output.size() == 3);
+  AT_ASSERT(output.size() == 3);
   for (size_t i = 0; i < output.size(); ++i) {
-    assert(output[i].allclose(torch::ones(5) * 2));
-    assert(output[i].allclose(manual[i]));
+    AT_ASSERT(output[i].allclose(torch::ones(5) * 2));
+    AT_ASSERT(output[i].allclose(manual[i]));
   }
 }
 
 void load_serialized_module_with_custom_op_and_execute(
     const std::string& path_to_exported_script_module) {
   std::shared_ptr<torch::jit::script::Module> module =
       torch::jit::load(path_to_exported_script_module);
-  assert(module != nullptr);
+  AT_ASSERT(module != nullptr);
 
   std::vector<torch::jit::IValue> inputs;
   inputs.push_back(torch::ones(5));
   auto output = module->forward(inputs).toTensor();
 
-  assert(output.allclose(torch::ones(5) + 1));
+  AT_ASSERT(output.allclose(torch::ones(5) + 1));
 }
 
 void test_argument_checking_for_serialized_modules(
     const std::string& path_to_exported_script_module) {
   std::shared_ptr<torch::jit::script::Module> module =
       torch::jit::load(path_to_exported_script_module);
-  assert(module != nullptr);
+  AT_ASSERT(module != nullptr);
 
   try {
     module->forward({torch::jit::IValue(1), torch::jit::IValue(2)});
-    assert(false);
+    AT_ASSERT(false);
   } catch (const c10::Error& error) {
-    assert(
+    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected at most 1 argument(s) for operator 'forward', "
                  "but received 2 argument(s)") == 0);
   }
 
   try {
     module->forward({torch::jit::IValue(5)});
-    assert(false);
+    AT_ASSERT(false);
   } catch (const c10::Error& error) {
-    assert(
+    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("Expected value of type Dynamic for argument 'input' in "
                  "position 0, but instead got value of type int") == 0);
   }
 
   try {
     module->forward({});
-    assert(false);
+    AT_ASSERT(false);
   } catch (const c10::Error& error) {
-    assert(
+    AT_ASSERT(
        std::string(error.what_without_backtrace())
            .find("forward() is missing value for argument 'input'") == 0);
   }
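The wholesale assert to AT_ASSERT swap matters because the Windows CI builds this test in Release mode (see the CMAKE_BUILD_TYPE=Release note in the batch script above), and <cassert> compiles assert away entirely when NDEBUG is defined, so every check here would have silently passed. AT_ASSERT evaluates its condition in all build types and throws a c10::Error on failure. A standalone sketch of the failure mode:

    #include <cassert>
    #include <cstdio>

    // Build with -DNDEBUG (what Release builds do) and the assert below
    // vanishes: broken() is never even called, and main() returns success.
    bool broken() {
      std::printf("checking...\n");
      return false;
    }

    int main() {
      assert(broken()); // no-op under NDEBUG; aborts otherwise
      std::printf("all tests passed?\n");
      return 0;
    }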

torch/csrc/jit/operator.h (+2 -2)

@@ -20,7 +20,7 @@
 
 namespace torch { namespace jit {
 
-FunctionSchema parseSchema(const std::string& schema);
+TORCH_API FunctionSchema parseSchema(const std::string& schema);
 
 using OperationCreator = std::function<Operation(Node*)>;
 

@@ -90,7 +90,7 @@ inline Operation getOperation(Node* node) {
   return getOperatorFor(node).getOperation(node);
 }
 
-void registerOperator(Operator&& op);
+TORCH_API void registerOperator(Operator&& op);
 
 // XXX: this function is meant to be used with string literals only!
 Operator& sig(const char *signature_literal);
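Without TORCH_API, these two free functions are defined inside the torch DLL but never exported from it, so a custom-operator library fails to link on Windows with unresolved externals. A hedged sketch of the consumer-side calls (only the signatures come from this hunk; the include path and schema string are illustrative):

    #include <torch/csrc/jit/operator.h>

    void illustrate_linkage() {
      // parseSchema crosses the DLL boundary: it is defined in torch.dll and
      // called from the custom-op module, so it must be dllexport'ed there.
      auto schema = torch::jit::parseSchema(
          "custom::op(Tensor tensor, float scalar, int repeat) -> Tensor[]");
      // registerOperator is reached the same way when a custom op registers
      // itself; without TORCH_API the MSVC linker reports LNK2019.
      (void)schema;
    }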

torch/csrc/jit/script/module.cpp (-24)

@@ -13,30 +13,6 @@ void placeholderCreator(Method&) {
   throw RecursiveMethodCallError();
 }
 
-static FunctionSchema defaultSchemaFor(const Method& method) {
-  std::vector<Argument> args;
-  std::vector<Argument> returns;
-  Graph& g = *method.graph();
-  size_t num_inputs = method.num_inputs();
-  for(size_t i = 0; i < num_inputs; ++i) {
-    const Value* v = g.inputs().at(i);
-    std::string name = v->hasUniqueName() ? v->uniqueName() : ("argument_" + std::to_string(i));
-    args.push_back({std::move(name), unshapedType(g.inputs()[i]->type())});
-  }
-  for(size_t i = 0; i < g.outputs().size(); ++i) {
-    returns.push_back({"", unshapedType(g.outputs()[i]->type())});
-  }
-  return { method.name(), std::move(args), std::move(returns) };
-}
-
-
-const FunctionSchema& Method::getSchema() const {
-  if(schema == nullptr) {
-    schema.reset(new FunctionSchema(defaultSchemaFor(*this)));
-  }
-  return *schema;
-}
-
 c10::optional<std::vector<Value*>> try_emit_call_to(
     Graph& graph,
     SourceRange loc,

torch/csrc/jit/script/module.h (+27 -4)

@@ -11,6 +11,7 @@
 
 #include <torch/csrc/api/include/torch/detail/ordered_dict.h>
 #include <torch/csrc/utils/memory.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <ATen/core/ArrayRef.h>
 #include "c10/util/Optional.h"

@@ -156,7 +157,12 @@ struct Method {
     return *this;
   }
 
-  const FunctionSchema& getSchema() const;
+  const FunctionSchema& getSchema() const {
+    if(schema == nullptr) {
+      schema.reset(new FunctionSchema(defaultSchemaFor(*this)));
+    }
+    return *schema;
+  }
 
   std::string pretty_print_schema() const {
     JIT_ASSERT(schema);

@@ -178,6 +184,23 @@ struct Method {
   }
 
  private:
+
+  static FunctionSchema defaultSchemaFor(const Method& method) {
+    std::vector<Argument> args;
+    std::vector<Argument> returns;
+    Graph& g = *method.graph();
+    size_t num_inputs = method.num_inputs();
+    for(size_t i = 0; i < num_inputs; ++i) {
+      const Value* v = g.inputs().at(i);
+      std::string name = v->hasUniqueName() ? v->uniqueName() : ("argument_" + std::to_string(i));
+      args.push_back({std::move(name), unshapedType(g.inputs()[i]->type())});
+    }
+    for(size_t i = 0; i < g.outputs().size(); ++i) {
+      returns.push_back({"", unshapedType(g.outputs()[i]->type())});
+    }
+    return { method.name(), std::move(args), std::move(returns) };
+  }
+
   std::string name_;
   std::shared_ptr<Graph> graph_; // for debugging and for inlining
   bool optimize;

@@ -368,7 +391,7 @@ struct Module {
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(
+  TORCH_API void to(
       at::Device device,
       at::ScalarType dtype,
      bool non_blocking = false);

@@ -379,15 +402,15 @@ struct Module {
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(at::ScalarType dtype, bool non_blocking = false);
+  TORCH_API void to(at::ScalarType dtype, bool non_blocking = false);
 
   /// Recursively moves all parameters to the given device.
   ///
   /// If `non_blocking` is true and the source is in pinned memory and
   /// destination is on the GPU or vice versa, the copy is performed
   /// asynchronously with respect to the host. Otherwise, the argument has no
   /// effect.
-  void to(at::Device device, bool non_blocking = false);
+  TORCH_API void to(at::Device device, bool non_blocking = false);
 
   /// Run a method from this module.
   ///
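The two changes in this header are complementary Windows strategies. Method::getSchema and defaultSchemaFor move into the header (mirroring their removal from module.cpp above), so every binary that includes them compiles its own copy and no symbol has to cross the DLL boundary; the Module::to overloads stay defined in a .cpp inside the torch DLL, so their declarations gain TORCH_API, brought in by the new WindowsTorchApiMacro.h include. A compact sketch with hypothetical names:

    // Stand-in for TORCH_API; see the export-macro sketch after op.h above.
    #if defined(_WIN32) && defined(building_the_dll) // hypothetical flag
    #  define EXAMPLE_API __declspec(dllexport)
    #elif defined(_WIN32)
    #  define EXAMPLE_API __declspec(dllimport)
    #else
    #  define EXAMPLE_API
    #endif

    struct Widget {
      // Strategy 1: defined in the header. Each consumer instantiates its
      // own copy, so nothing needs exporting (Method::getSchema above).
      int header_defined() const { return 42; }

      // Strategy 2: defined in a .cpp inside the DLL; only the exported
      // declaration is visible to clients (the Module::to overloads above).
      EXAMPLE_API void cpp_defined();
    };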

torch/script.h (-1)

@@ -3,6 +3,5 @@
 #include <torch/csrc/autograd/generated/variable_factories.h>
 #include <torch/csrc/jit/custom_operator.h>
 #include <torch/csrc/jit/import.h>
-#include <torch/csrc/WindowsTorchApiMacro.h>
 
 #include <ATen/ATen.h>
