Skip to content

Commit cfd18e1

Browse files
kimishpatelpytorchmergebot
authored andcommitted
[Pytorch][Ondevice quantization] Add device side API to convert model (pytorch#83807)
Summary: This diff adds device side API which will convert the model to its quantized equivalent. THe input model must have been prepared AOT for quantization. API is implemented by: - Running reset obervers - Running observe method - Running quantize method - And replacing method, e.g. forward, with its quantized equivalent. Test Plan: test/quantization/jit/test_ondevice_quantization.py Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818) Pull Request resolved: pytorch#83807 Approved by: https://github.com/iseeyuan
1 parent eebdcb5 commit cfd18e1

File tree

10 files changed

+256
-43
lines changed

10 files changed

+256
-43
lines changed

buckbuild.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,7 @@ def define_buck_targets(
14171417
"torch/csrc/autograd/VariableTypeManual.cpp",
14181418
"torch/csrc/autograd/FunctionsManual.cpp",
14191419
"torch/csrc/api/src/data/datasets/mnist.cpp",
1420+
"torch/csrc/jit/mobile/quantization.cpp",
14201421
"torch/csrc/jit/mobile/train/export_data.cpp",
14211422
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
14221423
"torch/csrc/jit/mobile/train/random.cpp",

build_variables.bzl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,7 @@ torch_mobile_core = [
564564
"torch/csrc/jit/mobile/observer.cpp",
565565
"torch/csrc/jit/mobile/parse_bytecode.cpp",
566566
"torch/csrc/jit/mobile/parse_operators.cpp",
567+
"torch/csrc/jit/mobile/quantization.cpp",
567568
"torch/csrc/jit/mobile/upgrader_mobile.cpp",
568569
"torch/csrc/jit/runtime/register_prim_ops.cpp",
569570
"torch/csrc/jit/runtime/register_special_ops.cpp",
@@ -612,6 +613,7 @@ libtorch_extra_sources = libtorch_core_jit_sources + [
612613
"torch/csrc/jit/mobile/observer.cpp",
613614
"torch/csrc/jit/mobile/parse_bytecode.cpp",
614615
"torch/csrc/jit/mobile/parse_operators.cpp",
616+
"torch/csrc/jit/mobile/quantization.cpp",
615617
"torch/csrc/jit/mobile/train/export_data.cpp",
616618
"torch/csrc/jit/mobile/train/optim/sgd.cpp",
617619
"torch/csrc/jit/mobile/train/random.cpp",

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
560560
${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
561561
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
562562
${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
563+
${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp
563564
${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
564565
${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
565566
${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp

test/quantization/jit/test_ondevice_quantization.py

Lines changed: 78 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Owner(s): ["oncall: quantization"]
33

44
import torch
5+
import torch._C_flatbuffer
56

67
from torch.ao.quantization import (
78
default_dynamic_qconfig,
@@ -22,11 +23,13 @@
2223
LinearAddModel,
2324
)
2425

25-
from torch.jit.mobile import _load_for_lite_interpreter
26+
from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
2627

2728
from torch.testing import FileCheck
29+
from torch.utils import bundled_inputs as bundled_inputs
2830

2931
import io
32+
from typing import Dict
3033

3134
class myMod(torch.nn.Module):
3235
def __init__(self, weight):
@@ -396,7 +399,7 @@ def _check_against_ref_dynamic_ptq(self, model):
396399
self.assertTrue(thrown)
397400

398401

399-
def _check_serialization_deserialization(self, model):
402+
def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
400403
model.eval()
401404
inputs = model.get_example_inputs()
402405
ref_m = torch.jit.script(model)
@@ -410,27 +413,40 @@ def _check_serialization_deserialization(self, model):
410413
ref_m = torch.jit.load(buffer)
411414
ref_output = ref_m(*inputs)
412415

413-
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
414-
buffer = io.BytesIO()
415-
torch.jit.save(m, buffer)
416-
buffer.seek(0)
417-
m = torch.jit.load(buffer)
418-
m.reset_observers_forward()
419-
m.observe_forward(*inputs)
420-
m.quantize_forward(*inputs)
421-
output = m.quantized_forward(*inputs)
422-
self.assertTrue(torch.allclose(ref_output, output))
423-
424-
# check for lite interpreter
425-
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
426-
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
427-
buffer.seek(0)
428-
m = _load_for_lite_interpreter(buffer) # Error here
429-
m.run_method("reset_observers_forward")
430-
m.run_method("observe_forward", *inputs)
431-
m.run_method("quantize_forward", *inputs)
432-
output = m.run_method("quantized_forward", *inputs)
433-
self.assertTrue(torch.allclose(ref_output, output))
416+
if not check_device_side_api:
417+
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
418+
buffer = io.BytesIO()
419+
torch.jit.save(m, buffer)
420+
buffer.seek(0)
421+
m = torch.jit.load(buffer)
422+
m.reset_observers_forward()
423+
m.observe_forward(*inputs)
424+
m.quantize_forward(*inputs)
425+
output = m.quantized_forward(*inputs)
426+
self.assertTrue(torch.allclose(ref_output, output))
427+
else:
428+
# check for lite interpreter
429+
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
430+
first_input, = inputs
431+
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
432+
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
433+
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
434+
buffer.seek(0)
435+
m = _load_for_lite_interpreter(buffer) # Error here
436+
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
437+
self.assertFalse(m.find_method("quantized_forward"))
438+
self.assertFalse(m.find_method("quantize_forward"))
439+
self.assertFalse(m.find_method("observe_forward"))
440+
self.assertFalse(m.find_method("reset_observers_forward"))
441+
output = m(*inputs)
442+
self.assertTrue(torch.allclose(ref_output, output))
443+
444+
# Now serialize to flabuffer and load from fb and check
445+
dict: Dict[str, str] = {}
446+
bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, dict)
447+
m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(bytes))
448+
fb_output = m(*inputs)
449+
self.assertTrue(torch.allclose(ref_output, fb_output))
434450

435451
model.eval()
436452
inputs = model.get_example_inputs()
@@ -445,27 +461,41 @@ def _check_serialization_deserialization(self, model):
445461
ref_m = torch.jit.load(buffer)
446462
ref_output = ref_m(*inputs)
447463

448-
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
449-
buffer = io.BytesIO()
450-
torch.jit.save(m, buffer)
451-
buffer.seek(0)
452-
m = torch.jit.load(buffer)
453-
m.reset_observers_forward()
454-
m.observe_forward(*inputs)
455-
m.quantize_forward(*inputs)
456-
output = m.quantized_forward(*inputs)
457-
self.assertTrue(torch.allclose(ref_output, output))
464+
if not check_device_side_api:
465+
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
466+
buffer = io.BytesIO()
467+
torch.jit.save(m, buffer)
468+
buffer.seek(0)
469+
m = torch.jit.load(buffer)
470+
m.reset_observers_forward()
471+
m.observe_forward(*inputs)
472+
m.quantize_forward(*inputs)
473+
output = m.quantized_forward(*inputs)
474+
self.assertTrue(torch.allclose(ref_output, output))
475+
else:
476+
# check for lite interpreter
477+
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
478+
first_input, = inputs
479+
rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
480+
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
481+
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
482+
buffer.seek(0)
483+
m = _load_for_lite_interpreter(buffer) # Error here
484+
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
485+
self.assertFalse(m.find_method("quantized_forward"))
486+
self.assertFalse(m.find_method("quantize_forward"))
487+
self.assertFalse(m.find_method("observe_forward"))
488+
self.assertFalse(m.find_method("reset_observers_forward"))
489+
output = m(*inputs)
490+
self.assertTrue(torch.allclose(ref_output, output))
458491

459-
# check for lite interpreter
460-
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
461-
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
462-
buffer.seek(0)
463-
m = _load_for_lite_interpreter(buffer) # Error here
464-
m.run_method("reset_observers_forward")
465-
m.run_method("observe_forward", *inputs)
466-
m.run_method("quantize_forward", *inputs)
467-
output = m.run_method("quantized_forward", *inputs)
468-
self.assertTrue(torch.allclose(ref_output, output))
492+
493+
def _check_serialization_deserialization(self, model):
494+
self._check_serdes_and_device_side_api_helper(model, False)
495+
496+
497+
def _check_device_side_api(self, model):
498+
self._check_serdes_and_device_side_api_helper(model, True)
469499

470500

471501
def test_quantize_forward(self):
@@ -492,3 +522,8 @@ def test_against_offdevice_dynamic_ptq(self):
492522
def test_serialization_deserialization(self):
493523
model = MyConvLinearModule()
494524
self._check_serialization_deserialization(model)
525+
526+
527+
def test_device_side_api(self):
528+
model = MyConvLinearModule()
529+
self._check_device_side_api(model)

torch/_C/__init__.pyi.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def _jit_get_emit_hooks() -> Tuple[Callable, Callable]: ...
298298
def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
299299
def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
300300
def _export_operator_list(module: LiteScriptModule): ...
301+
def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
301302
def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
302303
def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
303304
def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...

torch/csrc/jit/mobile/module.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,50 @@ Method Module::get_method(const std::string& name) const {
4343
AT_ERROR("Method '", name, "' is not defined.");
4444
}
4545

46+
bool Module::compareMethodSchemas(
47+
const std::string& name_1,
48+
const std::string& name_2) {
49+
c10::optional<c10::FunctionSchema> schema_1, schema_2;
50+
for (const auto& fn : cu_->methods()) {
51+
if (fn->name() == name_1) {
52+
schema_1 = fn->getSchema();
53+
}
54+
if (fn->name() == name_2) {
55+
schema_2 = fn->getSchema();
56+
}
57+
}
58+
if (schema_1.has_value() && schema_2.has_value()) {
59+
return (schema_1 == schema_2);
60+
}
61+
return false;
62+
}
63+
64+
void Module::unsafeRemoveMethod(const std::string& basename) {
65+
int64_t i = 0;
66+
for (; i < cu_->methods().size(); ++i) {
67+
if ((cu_->methods()[i])->name() == basename) {
68+
break;
69+
}
70+
}
71+
object_->type()->unsafeRemoveMethod(basename);
72+
cu_->unsafeRemoveFunction(i);
73+
}
74+
75+
void Module::unsafeCopyMethod(
76+
const std::string& new_method_name,
77+
const Function& to_be_copied) {
78+
TORCH_CHECK(
79+
!find_method(new_method_name).has_value(),
80+
"Trying to replace existing method.");
81+
const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
82+
c10::QualifiedName qualified_method_name(
83+
tobe_copied_name.prefix(), new_method_name);
84+
std::unique_ptr<Function> new_fn = std::make_unique<Function>(
85+
qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
86+
object_->type()->addMethod(new_fn.get());
87+
cu_->register_function(std::move(new_fn));
88+
}
89+
4690
c10::optional<Method> Module::find_method(const std::string& basename) const {
4791
for (const auto& fn : cu_->methods()) {
4892
if (fn->name() == basename) {

torch/csrc/jit/mobile/module.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include <torch/csrc/jit/mobile/debug_info.h>
44
#include <torch/csrc/jit/mobile/function.h>
55
#include <torch/csrc/jit/mobile/method.h>
6+
#include <torch/csrc/jit/mobile/quantization.h>
67

78
namespace torch {
89
namespace jit {
@@ -42,6 +43,10 @@ class CompilationUnit {
4243
Function* find_function(const c10::QualifiedName& qn);
4344
const Function* find_function(const c10::QualifiedName& qn) const;
4445

46+
void unsafeRemoveFunction(const int64_t index) {
47+
methods_.erase(methods_.begin() + index);
48+
}
49+
4550
private:
4651
std::vector<std::unique_ptr<Function>> methods_;
4752
};
@@ -71,6 +76,7 @@ class TORCH_API Module {
7176
return get_method("forward")(std::move(inputs));
7277
}
7378
c10::optional<Method> find_method(const std::string& basename) const;
79+
7480
const std::string name() const {
7581
return object_->name();
7682
}
@@ -152,6 +158,18 @@ class TORCH_API Module {
152158
}
153159

154160
private:
161+
friend class quantization::PTQQuanizationHelper;
162+
163+
bool compareMethodSchemas(
164+
const std::string& name_1,
165+
const std::string& name_2);
166+
167+
void unsafeRemoveMethod(const std::string& basename);
168+
169+
void unsafeCopyMethod(
170+
const std::string& new_method_name,
171+
const Function& to_be_copied);
172+
155173
c10::intrusive_ptr<c10::ivalue::Object> object_;
156174
std::unordered_map<std::string, std::string> metadata_;
157175
std::shared_ptr<CompilationUnit> cu_;
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include <ATen/Context.h>
2+
#include <torch/csrc/jit/mobile/module.h>
3+
#include <torch/csrc/jit/mobile/quantization.h>
4+
5+
namespace torch {
6+
namespace jit {
7+
namespace mobile {
8+
namespace quantization {
9+
10+
void PTQQuanizationHelper::quantize_dynamic(
11+
torch::jit::mobile::Module& m,
12+
const std::string& method_name) {
13+
at::globalContext().setReleaseWeightsWhenPrepacking(false);
14+
std::string reset_observers_method_name = "reset_observers_" + method_name;
15+
std::string observe_method_name = "observe_" + method_name;
16+
std::string quantize_method_name = "quantize_" + method_name;
17+
std::string quantized_method_name = "quantized_" + method_name;
18+
19+
TORCH_CHECK(
20+
m.find_method(reset_observers_method_name).has_value(),
21+
"PTQ ready module must have",
22+
reset_observers_method_name,
23+
" method.");
24+
TORCH_CHECK(
25+
m.find_method(observe_method_name),
26+
"PTQ ready module must have",
27+
reset_observers_method_name,
28+
" method.");
29+
TORCH_CHECK(
30+
m.find_method(quantize_method_name),
31+
"PTQ ready module must have",
32+
quantize_method_name,
33+
" method.");
34+
TORCH_CHECK(
35+
m.find_method(quantized_method_name),
36+
"PTQ ready module must have",
37+
quantized_method_name,
38+
" method.");
39+
TORCH_CHECK(
40+
m.find_method("get_all_bundled_inputs"),
41+
"PTQ ready module must have get_all_bundled_inputs method.");
42+
43+
auto inputs = m.run_method("get_all_bundled_inputs")
44+
.toList()
45+
.get(0)
46+
.toTupleRef()
47+
.elements()
48+
.vec();
49+
m.get_method(reset_observers_method_name)({});
50+
m.get_method(observe_method_name)(inputs);
51+
m.get_method(quantize_method_name)(inputs);
52+
53+
m.compareMethodSchemas(method_name, quantized_method_name);
54+
m.unsafeRemoveMethod(method_name);
55+
const Function& to_be_copied =
56+
m.find_method(quantized_method_name).value().function();
57+
m.unsafeCopyMethod(method_name, to_be_copied);
58+
m.unsafeRemoveMethod(quantized_method_name);
59+
m.unsafeRemoveMethod(quantize_method_name);
60+
m.unsafeRemoveMethod(observe_method_name);
61+
m.unsafeRemoveMethod(reset_observers_method_name);
62+
}
63+
} // namespace quantization
64+
} // namespace mobile
65+
} // namespace jit
66+
} // namespace torch

0 commit comments

Comments
 (0)