Skip to content

Commit 124dd51

Browse files
make some updates
1 parent 17f5daa commit 124dd51

File tree

9 files changed

+27
-7
lines changed

9 files changed

+27
-7
lines changed

csseg/configs/rcil/base_cfg.py

+1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
'loss_seg': {'MIBUnbiasedCrossEntropyLoss': {'scale_factor': 1.0, 'reduction': 'mean', 'ignore_index': 255}}
3333
},
3434
'distillation': {'scale_factor': 1.0, 'spp_scales': [4, 8, 12, 16, 20, 24]},
35+
'distillation_mib': {'scale_factor': 100, 'alpha': 1.0},
3536
}
3637
}
3738
# RUNNER_CFG

csseg/modules/models/decoders/rcilaspphead.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,10 @@ def __init__(self, in_channels, feats_channels, out_channels, dilations, pooling
5757
nn.LeakyReLU(0.01),
5858
)
5959
# initialize parameters
60-
assert norm_cfg['activation'] == 'identity'
61-
self.initparams(actname2torchactname(act_cfg['type']), act_cfg.get('negative_slope'))
60+
if hasattr(self.bottleneck_bn[0], 'activation'):
61+
self.initparams(self.bottleneck_bn[0].activation, self.bottleneck_bn[0].activation_param)
62+
else:
63+
self.initparams(actname2torchactname(act_cfg['type']), act_cfg.get('negative_slope'))
6264
'''initparams'''
6365
def initparams(self, nonlinearity, param=None):
6466
gain = nn.init.calculate_gain(nonlinearity, param)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
'''
2+
Function:
3+
Implementation of ResNetUCD
4+
Author:
5+
Zhenchao Jin
6+
'''
7+
from .resnetmib import ResNetMIB as ResNetUCD

csseg/modules/models/segmentors/caf.py

Whitespace-only changes.

csseg/modules/models/segmentors/ewf.py

Whitespace-only changes.

csseg/modules/runners/caf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'''
22
Function:
3-
Implementation of "Class Similarity Weighted Knowledge Distillation for Continual Semantic Segmentation"
3+
Implementation of "Continual attentive fusion for incremental learning in semantic segmentation"
44
Author:
55
Zhenchao Jin
66
'''

csseg/modules/runners/ewf.py

Whitespace-only changes.

csseg/modules/runners/rcil.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch.nn.functional as F
1313
import torch.distributed as dist
1414
from apex import amp
15+
from .mib import MIBRunner
1516
from .base import BaseRunner
1617

1718

@@ -153,7 +154,7 @@ def train(self, cur_epoch):
153154
seg_targets=seg_targets,
154155
losses_cfgs=seg_losses_cfgs,
155156
)
156-
# --calculate distillation losses
157+
# --calculate pod distillation losses
157158
pod_total_loss, pod_losses_log_dict = 0, {}
158159
if self.history_segmentor is not None:
159160
distillation_feats = outputs['distillation_feats']
@@ -165,8 +166,16 @@ def train(self, cur_epoch):
165166
dataset_type=self.runner_cfg['dataset_cfg']['type'],
166167
**losses_cfgs['distillation']
167168
)
168-
# --merge two losses
169-
loss_total = pod_total_loss + seg_total_loss
169+
# --calculate mib distillation losses
170+
kd_total_loss, kd_losses_log_dict = 0, {}
171+
if self.history_segmentor is not None:
172+
kd_total_loss, kd_losses_log_dict = MIBRunner.featuresdistillation(
173+
history_distillation_feats=F.interpolate(history_outputs['seg_logits'], size=images.shape[2:], mode="bilinear", align_corners=self.segmentor.module.align_corners),
174+
distillation_feats=F.interpolate(outputs['seg_logits'], size=images.shape[2:], mode="bilinear", align_corners=self.segmentor.module.align_corners),
175+
**losses_cfgs['distillation_mib']
176+
)
177+
# --merge three losses
178+
loss_total = pod_total_loss + kd_total_loss + seg_total_loss
170179
# --perform back propagation
171180
with amp.scale_loss(loss_total, self.optimizer) as scaled_loss_total:
172181
scaled_loss_total.backward()
@@ -175,6 +184,7 @@ def train(self, cur_epoch):
175184
self.scheduler.zerograd()
176185
# --logging training loss info
177186
seg_losses_log_dict.update(pod_losses_log_dict)
187+
seg_losses_log_dict.update(kd_losses_log_dict)
178188
seg_losses_log_dict.pop('loss_total')
179189
seg_losses_log_dict['loss_total'] = loss_total.item()
180190
losses_log_dict = self.loggingtraininginfo(seg_losses_log_dict, losses_log_dict, init_losses_log_dict)

csseg/modules/runners/reminder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
'''
22
Function:
3-
Implementation of "Continual attentive fusion for incremental learning in semantic segmentation"
3+
Implementation of "Class Similarity Weighted Knowledge Distillation for Continual Semantic Segmentation"
44
Author:
55
Zhenchao Jin
66
'''

0 commit comments

Comments
 (0)