2
2
# Owner(s): ["oncall: quantization"]
3
3
4
4
import torch
5
+ import torch ._C_flatbuffer
5
6
6
7
from torch .ao .quantization import (
7
8
default_dynamic_qconfig ,
22
23
LinearAddModel ,
23
24
)
24
25
25
- from torch .jit .mobile import _load_for_lite_interpreter
26
+ from torch .jit .mobile import _load_for_lite_interpreter , LiteScriptModule
26
27
27
28
from torch .testing import FileCheck
29
+ from torch .utils import bundled_inputs as bundled_inputs
28
30
29
31
import io
32
+ from typing import Dict
30
33
31
34
class myMod (torch .nn .Module ):
32
35
def __init__ (self , weight ):
@@ -396,7 +399,7 @@ def _check_against_ref_dynamic_ptq(self, model):
396
399
self .assertTrue (thrown )
397
400
398
401
399
- def _check_serialization_deserialization (self , model ):
402
+ def _check_serdes_and_device_side_api_helper (self , model , check_device_side_api = False ):
400
403
model .eval ()
401
404
inputs = model .get_example_inputs ()
402
405
ref_m = torch .jit .script (model )
@@ -410,27 +413,40 @@ def _check_serialization_deserialization(self, model):
410
413
ref_m = torch .jit .load (buffer )
411
414
ref_output = ref_m (* inputs )
412
415
413
- m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
414
- buffer = io .BytesIO ()
415
- torch .jit .save (m , buffer )
416
- buffer .seek (0 )
417
- m = torch .jit .load (buffer )
418
- m .reset_observers_forward ()
419
- m .observe_forward (* inputs )
420
- m .quantize_forward (* inputs )
421
- output = m .quantized_forward (* inputs )
422
- self .assertTrue (torch .allclose (ref_output , output ))
423
-
424
- # check for lite interpreter
425
- m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
426
- buffer = io .BytesIO (m ._save_to_buffer_for_lite_interpreter ())
427
- buffer .seek (0 )
428
- m = _load_for_lite_interpreter (buffer ) # Error here
429
- m .run_method ("reset_observers_forward" )
430
- m .run_method ("observe_forward" , * inputs )
431
- m .run_method ("quantize_forward" , * inputs )
432
- output = m .run_method ("quantized_forward" , * inputs )
433
- self .assertTrue (torch .allclose (ref_output , output ))
416
+ if not check_device_side_api :
417
+ m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
418
+ buffer = io .BytesIO ()
419
+ torch .jit .save (m , buffer )
420
+ buffer .seek (0 )
421
+ m = torch .jit .load (buffer )
422
+ m .reset_observers_forward ()
423
+ m .observe_forward (* inputs )
424
+ m .quantize_forward (* inputs )
425
+ output = m .quantized_forward (* inputs )
426
+ self .assertTrue (torch .allclose (ref_output , output ))
427
+ else :
428
+ # check for lite interpreter
429
+ m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
430
+ first_input , = inputs
431
+ rand_input = bundled_inputs .bundle_randn (first_input .size (), dtype = first_input .dtype )
432
+ m = bundled_inputs .bundle_inputs (m , inputs = [(rand_input , )])
433
+ buffer = io .BytesIO (m ._save_to_buffer_for_lite_interpreter ())
434
+ buffer .seek (0 )
435
+ m = _load_for_lite_interpreter (buffer ) # Error here
436
+ torch ._C ._quantize_ondevice_ptq_dynamic (m ._c , "forward" )
437
+ self .assertFalse (m .find_method ("quantized_forward" ))
438
+ self .assertFalse (m .find_method ("quantize_forward" ))
439
+ self .assertFalse (m .find_method ("observe_forward" ))
440
+ self .assertFalse (m .find_method ("reset_observers_forward" ))
441
+ output = m (* inputs )
442
+ self .assertTrue (torch .allclose (ref_output , output ))
443
+
444
+ # Now serialize to flatbuffer, load back from flatbuffer, and check
445
+ dict : Dict [str , str ] = {}
446
+ bytes = torch ._C_flatbuffer ._save_mobile_module_to_bytes (m ._c , dict )
447
+ m = LiteScriptModule (torch ._C_flatbuffer ._load_mobile_module_from_bytes (bytes ))
448
+ fb_output = m (* inputs )
449
+ self .assertTrue (torch .allclose (ref_output , fb_output ))
434
450
435
451
model .eval ()
436
452
inputs = model .get_example_inputs ()
@@ -445,27 +461,41 @@ def _check_serialization_deserialization(self, model):
445
461
ref_m = torch .jit .load (buffer )
446
462
ref_output = ref_m (* inputs )
447
463
448
- m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
449
- buffer = io .BytesIO ()
450
- torch .jit .save (m , buffer )
451
- buffer .seek (0 )
452
- m = torch .jit .load (buffer )
453
- m .reset_observers_forward ()
454
- m .observe_forward (* inputs )
455
- m .quantize_forward (* inputs )
456
- output = m .quantized_forward (* inputs )
457
- self .assertTrue (torch .allclose (ref_output , output ))
464
+ if not check_device_side_api :
465
+ m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
466
+ buffer = io .BytesIO ()
467
+ torch .jit .save (m , buffer )
468
+ buffer .seek (0 )
469
+ m = torch .jit .load (buffer )
470
+ m .reset_observers_forward ()
471
+ m .observe_forward (* inputs )
472
+ m .quantize_forward (* inputs )
473
+ output = m .quantized_forward (* inputs )
474
+ self .assertTrue (torch .allclose (ref_output , output ))
475
+ else :
476
+ # check for lite interpreter
477
+ m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
478
+ first_input , = inputs
479
+ rand_input = bundled_inputs .bundle_randn (first_input .size (), dtype = first_input .dtype )
480
+ m = bundled_inputs .bundle_inputs (m , inputs = [(rand_input , )])
481
+ buffer = io .BytesIO (m ._save_to_buffer_for_lite_interpreter ())
482
+ buffer .seek (0 )
483
+ m = _load_for_lite_interpreter (buffer ) # Error here
484
+ torch ._C ._quantize_ondevice_ptq_dynamic (m ._c , "forward" )
485
+ self .assertFalse (m .find_method ("quantized_forward" ))
486
+ self .assertFalse (m .find_method ("quantize_forward" ))
487
+ self .assertFalse (m .find_method ("observe_forward" ))
488
+ self .assertFalse (m .find_method ("reset_observers_forward" ))
489
+ output = m (* inputs )
490
+ self .assertTrue (torch .allclose (ref_output , output ))
458
491
459
- # check for lite interpreter
460
- m = OnDevicePTQUtils .ptq_dynamic_quantize (model , qconfig_dict )
461
- buffer = io .BytesIO (m ._save_to_buffer_for_lite_interpreter ())
462
- buffer .seek (0 )
463
- m = _load_for_lite_interpreter (buffer ) # Error here
464
- m .run_method ("reset_observers_forward" )
465
- m .run_method ("observe_forward" , * inputs )
466
- m .run_method ("quantize_forward" , * inputs )
467
- output = m .run_method ("quantized_forward" , * inputs )
468
- self .assertTrue (torch .allclose (ref_output , output ))
492
+
493
+ def _check_serialization_deserialization (self , model ):
494
+ self ._check_serdes_and_device_side_api_helper (model , False )
495
+
496
+
497
+ def _check_device_side_api (self , model ):
498
+ self ._check_serdes_and_device_side_api_helper (model , True )
469
499
470
500
471
501
def test_quantize_forward (self ):
@@ -492,3 +522,8 @@ def test_against_offdevice_dynamic_ptq(self):
492
522
def test_serialization_deserialization (self ):
493
523
model = MyConvLinearModule ()
494
524
self ._check_serialization_deserialization (model )
525
+
526
+
527
+ def test_device_side_api (self ):
528
+ model = MyConvLinearModule ()
529
+ self ._check_device_side_api (model )
0 commit comments