File tree Expand file tree Collapse file tree
quantization/PQNet/easy/cifar10 Expand file tree Collapse file tree Original file line number Diff line number Diff line change 4545 schdr :
4646 key : CosineAnnealingLR
4747 params :
48- T_max : 100
48+ T_max : 500
4949 gpu :
5050 gpus : 1
5151 vRam : -1
Original file line number Diff line number Diff line change 4343 schdr :
4444 key : CosineAnnealingLR
4545 params :
46- T_max : 100
47- eta_min : 1.e-8
46+ T_max : 500
4847 gpu :
4948 gpus : 1
5049 vRam : -1
Original file line number Diff line number Diff line change @@ -102,6 +102,12 @@ def __init__(self, bits) -> None:
102102 self ._net = nn .ModuleList (ffnNet () for _ in range (bits // 8 ))
103103 self ._bitFlip = CSQ_D ._randomBitFlip (bits , int (bits // 32 ) ** 2 )
104104
105+ def reset (self ):
106+ for ffnNet in self ._net :
107+ for net in ffnNet :
108+ if hasattr (net , "reset_parameters" ):
109+ net .reset_parameters ()
110+
105111 def forward (self , x , flip ):
106112 if flip :
107113 x = self ._bitFlip (x )
@@ -132,14 +138,16 @@ def BitFlip(self, numBitsToFlip: int):
132138 self .mapper ._bitFlip .BitFlip = numBitsToFlip
133139 self .bitFlip .BitFlip = numBitsToFlip
134140
135- def resetPermIdx (self ):
141+ def reset (self ):
136142 # reset permIdx
137143 self .permIdx .data .copy_ (torch .randperm (self .m * 8 , device = self .permIdx .device ))
144+ # reset params
145+ self .mapper .reset ()
138146
139147 def forward (self , x : torch .Tensor , y : torch .Tensor ):
140148 self ._ticker += 1
141149 if self ._ticker % int (self .m * 16 ) == 0 :
142- self .resetPermIdx ()
150+ self .reset ()
143151 # X are permuted on last dim according to permIdx
144152 x = x [:, self .permIdx ]
145153
You can’t perform that action at this time.
0 commit comments