Skip to content

Commit 32f1fd8

Browse files
committed
fix
1 parent 9023ef9 commit 32f1fd8

3 files changed

Lines changed: 12 additions & 5 deletions

File tree

configs/hash/CSQ_D/easy/cifar10/32-bits.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ train:
4545
schdr:
4646
key: CosineAnnealingLR
4747
params:
48-
T_max: 100
48+
T_max: 500
4949
gpu:
5050
gpus: 1
5151
vRam: -1

configs/quantization/PQNet/easy/cifar10/32-bits.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,7 @@ train:
4343
schdr:
4444
key: CosineAnnealingLR
4545
params:
46-
T_max: 100
47-
eta_min: 1.e-8
46+
T_max: 500
4847
gpu:
4948
gpus: 1
5049
vRam: -1

modfire/criterion/csq.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ def __init__(self, bits) -> None:
102102
self._net = nn.ModuleList(ffnNet() for _ in range(bits // 8))
103103
self._bitFlip = CSQ_D._randomBitFlip(bits, int(bits // 32) ** 2)
104104

105+
def reset(self):
106+
for ffnNet in self._net:
107+
for net in ffnNet:
108+
if hasattr(net, "reset_parameters"):
109+
net.reset_parameters()
110+
105111
def forward(self, x, flip):
106112
if flip:
107113
x = self._bitFlip(x)
@@ -132,14 +138,16 @@ def BitFlip(self, numBitsToFlip: int):
132138
self.mapper._bitFlip.BitFlip = numBitsToFlip
133139
self.bitFlip.BitFlip = numBitsToFlip
134140

135-
def resetPermIdx(self):
141+
def reset(self):
136142
# reset permIdx
137143
self.permIdx.data.copy_(torch.randperm(self.m * 8, device=self.permIdx.device))
144+
# reset params
145+
self.mapper.reset()
138146

139147
def forward(self, x: torch.Tensor, y: torch.Tensor):
140148
self._ticker += 1
141149
if self._ticker % int(self.m * 16) == 0:
142-
self.resetPermIdx()
150+
self.reset()
143151
# X are permuted on last dim according to permIdx
144152
x = x[:, self.permIdx]
145153

0 commit comments

Comments
 (0)