Skip to content

Commit 17f5daa

Browse files
update rcil
1 parent dfd306f commit 17f5daa

File tree

4 files changed

+16
-13
lines changed

4 files changed

+16
-13
lines changed

csseg/configs/rcil/base_cfg.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
'depth': 101,
1111
'outstride': 16,
1212
'out_indices': (0, 1, 2, 3),
13-
'norm_cfg': {'type': 'InPlaceABNSync', 'activation': 'identity'},
14-
'act_cfg': {'type': 'LeakyReLU', 'negative_slope': 0.01, 'inplace': False},
13+
'norm_cfg': {'type': 'InPlaceABNSync', 'activation': 'leaky_relu', 'activation_param': 1.0},
14+
'act_cfg': None,
1515
'pretrained': True,
1616
},
1717
'decoder_cfg': {
@@ -21,8 +21,8 @@
2121
'out_channels': 256,
2222
'dilations': (1, 6, 12, 18),
2323
'pooling_size': 32,
24-
'norm_cfg': {'type': 'InPlaceABNSync', 'activation': 'identity'},
25-
'act_cfg': {'type': 'LeakyReLU', 'negative_slope': 0.01, 'inplace': False},
24+
'norm_cfg': {'type': 'InPlaceABNSync', 'activation': 'leaky_relu', 'activation_param': 1.0},
25+
'act_cfg': None,
2626
},
2727
'losses_cfgs': {
2828
'segmentation_init': {

csseg/modules/models/decoders/rcilaspphead.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,14 @@ def __init__(self, in_channels, feats_channels, out_channels, dilations, pooling
4747
self.global_branch = nn.Sequential(
4848
nn.Conv2d(in_channels, feats_channels, kernel_size=1, stride=1, padding=0, bias=False),
4949
BuildNormalization(placeholder=feats_channels, norm_cfg=norm_cfg),
50-
BuildActivation(act_cfg=act_cfg),
50+
nn.LeakyReLU(0.01),
5151
nn.Conv2d(feats_channels, feats_channels, kernel_size=1, stride=1, padding=0, bias=False),
5252
)
5353
# output project
5454
self.bottleneck_conv = nn.Conv2d(feats_channels * len(dilations), out_channels, kernel_size=1, stride=1, padding=0, bias=False)
5555
self.bottleneck_bn = nn.Sequential(
5656
BuildNormalization(placeholder=out_channels, norm_cfg=norm_cfg),
57-
BuildActivation(act_cfg=act_cfg),
57+
nn.LeakyReLU(0.01),
5858
)
5959
# initialize parameters
6060
assert norm_cfg['activation'] == 'identity'
@@ -77,9 +77,9 @@ def forward(self, x):
7777
input_size = x.shape
7878
# feed to parallel convolutions branch1 and branch2
7979
outputs_branch1 = torch.cat([conv(x) for conv in self.parallel_convs_branch1], dim=1)
80-
outputs_branch1 = self.parallel_bn_branch1[0](outputs_branch1)
80+
outputs_branch1 = self.parallel_bn_branch1(outputs_branch1)
8181
outputs_branch2 = torch.cat([conv(x) for conv in self.parallel_convs_branch2], dim=1)
82-
outputs_branch2 = self.parallel_bn_branch2[0](outputs_branch2)
82+
outputs_branch2 = self.parallel_bn_branch2(outputs_branch2)
8383
# merge
8484
r = torch.rand(1, outputs_branch1.shape[1], 1, 1, dtype=torch.float32)
8585
if not self.training: r[:, :, :, :] = 1.0
@@ -91,7 +91,7 @@ def forward(self, x):
9191
weight_branch2[(r < 0.66) & (r >= 0.33)] = 2.
9292
weight_branch2[r >= 0.66] = 1.
9393
outputs = outputs_branch1 * weight_branch1.type_as(outputs_branch1) * 0.5 + outputs_branch2 * weight_branch2.type_as(outputs_branch2) * 0.5
94-
outputs = self.parallel_bn_branch1[1](outputs)
94+
outputs = F.leaky_relu(outputs, negative_slope=0.01)
9595
outputs = self.bottleneck_conv(outputs)
9696
# feed to global branch
9797
global_feats = self.globalpooling(x)

csseg/modules/models/encoders/resnetplop.py

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm
1818
)
1919
'''forward'''
2020
def forward(self, x):
21+
if isinstance(x, tuple): x = x[0]
2122
identity = x
2223
out = self.conv1(x)
2324
out = self.bn1(out)

csseg/modules/models/encoders/resnetrcil.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import re
88
import torch
99
import torch.nn as nn
10+
import torch.nn.functional as F
1011
from .bricks import BuildNormalization
1112
from .resnet import ResNet, BasicBlock, Bottleneck
1213

@@ -23,10 +24,11 @@ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, norm
2324
self.bn2_branch2 = BuildNormalization(placeholder=planes, norm_cfg=shortcut_norm_cfg)
2425
'''forward'''
2526
def forward(self, x):
27+
if isinstance(x, tuple): x = x[0]
2628
identity = x
2729
out = self.conv1(x)
2830
out = self.bn1(out)
29-
out = self.relu(out)
31+
out = F.leaky_relu(out, 0.01)
3032
out_branch1 = self.conv2(out)
3133
out_branch1 = self.bn2(out_branch1)
3234
out_branch2 = self.conv2_branch2(out)
@@ -41,7 +43,7 @@ def forward(self, x):
4143
weight_branch2[(r < 0.66) & (r >= 0.33)] = 2.
4244
weight_branch2[r >= 0.66] = 1.
4345
out = out_branch1 * weight_branch1.type_as(out_branch1) * 0.5 + out_branch2 * weight_branch2.type_as(out_branch2) * 0.5
44-
out = self.relu(out)
46+
out = F.leaky_relu(out, 0.01)
4547
if self.downsample is not None: identity = self.downsample(x)
4648
out = out + identity
4749
distillation = out
@@ -65,7 +67,7 @@ def forward(self, x):
6567
identity = x
6668
out = self.conv1(x)
6769
out = self.bn1(out)
68-
out = self.relu(out)
70+
out = F.leaky_relu(out, 0.01)
6971
out_branch1 = self.conv2(out)
7072
out_branch1 = self.bn2(out_branch1)
7173
out_branch2 = self.conv2_branch2(out)
@@ -80,7 +82,7 @@ def forward(self, x):
8082
weight_branch2[(r < 0.66) & (r >= 0.33)] = 2.
8183
weight_branch2[r >= 0.66] = 1.
8284
out = out_branch1 * weight_branch1.type_as(out_branch1) * 0.5 + out_branch2 * weight_branch2.type_as(out_branch2) * 0.5
83-
out = self.relu(out)
85+
out = F.leaky_relu(out, 0.01)
8486
out = self.conv3(out)
8587
out = self.bn3(out)
8688
if self.downsample is not None: identity = self.downsample(x)

0 commit comments

Comments
 (0)