__all__ = ['ResNetV2']  # model_registry will add each entrypoint fn to this


+class PreActBasic(nn.Module):
+    """Pre-activation basic block (not in typical 'v2' implementations)
+    """
+
+    def __init__(
+            self,
+            in_chs,
+            out_chs=None,
+            bottle_ratio=1.0,
+            stride=1,
+            dilation=1,
+            first_dilation=None,
+            groups=1,
+            act_layer=None,
+            conv_layer=None,
+            norm_layer=None,
+            proj_layer=None,
+            drop_path_rate=0.,
+    ):
+        super().__init__()
+        first_dilation = first_dilation or dilation
+        conv_layer = conv_layer or StdConv2d
+        norm_layer = norm_layer or partial(GroupNormAct, num_groups=32)
+        out_chs = out_chs or in_chs
+        mid_chs = make_divisible(out_chs * bottle_ratio)
+
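+        # a projection shortcut is only needed when the output shape differs (stride, dilation, or channels)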
+        if proj_layer is not None and (stride != 1 or first_dilation != dilation or in_chs != out_chs):
+            self.downsample = proj_layer(
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
+        else:
+            self.downsample = None
+
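+        # pre-activation ordering: norm (with the activation folded into the norm layer) precedes each conv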
+        self.norm1 = norm_layer(in_chs)
+        self.conv1 = conv_layer(in_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
+        self.norm2 = norm_layer(mid_chs)
+        self.conv2 = conv_layer(mid_chs, out_chs, 3, dilation=dilation, groups=groups)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
+
+    def zero_init_last(self):
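+        # zero the last conv weight so the residual branch is a no-op at init and only the shortcut passes through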
+        nn.init.zeros_(self.conv2.weight)
+
+    def forward(self, x):
+        x_preact = self.norm1(x)
+
+        # shortcut branch
+        shortcut = x
+        if self.downsample is not None:
+            shortcut = self.downsample(x_preact)
+
+        # residual branch
+        x = self.conv1(x_preact)
+        x = self.conv2(self.norm2(x))
+        x = self.drop_path(x)
+        return x + shortcut
+

class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.
@@ -80,8 +143,15 @@ def __init__(

        if proj_layer is not None:
            self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, preact=True,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                first_dilation=first_dilation,
+                preact=True,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
        else:
            self.downsample = None

@@ -140,8 +210,14 @@ def __init__(

        if proj_layer is not None:
            self.downsample = proj_layer(
-                in_chs, out_chs, stride=stride, dilation=dilation, preact=False,
-                conv_layer=conv_layer, norm_layer=norm_layer)
+                in_chs,
+                out_chs,
+                stride=stride,
+                dilation=dilation,
+                preact=False,
+                conv_layer=conv_layer,
+                norm_layer=norm_layer,
+            )
        else:
            self.downsample = None

@@ -339,6 +415,8 @@ def __init__(
            stem_type='',
            avg_down=False,
            preact=True,
+            basic=False,
+            bottle_ratio=0.25,
            act_layer=nn.ReLU,
            norm_layer=partial(GroupNormAct, num_groups=32),
            conv_layer=StdConv2d,
@@ -390,7 +468,11 @@ def __init__(
        curr_stride = 4
        dilation = 1
        block_dprs = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(layers)).split(layers)]
-        block_fn = PreActBottleneck if preact else Bottleneck
+        if preact:
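+            # basic=True selects the two-conv PreActBasic block (used by the new resnetv2_18/34 variants)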
+            block_fn = PreActBasic if basic else PreActBottleneck
+        else:
+            assert not basic
+            block_fn = Bottleneck
        self.stages = nn.Sequential()
        for stage_idx, (d, c, bdpr) in enumerate(zip(layers, channels, block_dprs)):
            out_chs = make_divisible(c * wf)
@@ -404,6 +486,7 @@ def __init__(
                stride=stride,
                dilation=dilation,
                depth=d,
+                bottle_ratio=bottle_ratio,
                avg_down=avg_down,
                act_layer=act_layer,
                conv_layer=conv_layer,
@@ -613,6 +696,14 @@ def _cfg(url='', **kwargs):
        hf_hub_id='timm/',
        num_classes=21843, custom_load=True),

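+    # '.untrained': no pretrained weights are available for these new basic-block variants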
+    'resnetv2_18.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_18d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
+    'resnetv2_34.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95),
+    'resnetv2_34d.untrained': _cfg(
+        interpolation='bicubic', crop_pct=0.95, first_conv='stem.conv1'),
    'resnetv2_50.a1h_in1k': _cfg(
        hf_hub_id='timm/',
        interpolation='bicubic', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
@@ -679,6 +770,42 @@ def resnetv2_152x4_bit(pretrained=False, **kwargs) -> ResNetV2:
        'resnetv2_152x4_bit', pretrained=pretrained, layers=[3, 8, 36, 3], width_factor=4, **kwargs)


+@register_model
+def resnetv2_18(pretrained=False, **kwargs) -> ResNetV2:
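+    # ResNet18 in the pre-activation (v2) layout: basic blocks, bottle_ratio=1.0 keeps block width constant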
+    model_args = dict(
+        layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_18', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_18d(pretrained=False, **kwargs) -> ResNetV2:
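+    # 'd' variant: deep 3-conv stem (stem_type='deep') plus avg-pool downsampling in the shortcut (avg_down=True)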
+    model_args = dict(
+        layers=[2, 2, 2, 2], channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_18d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d
+    )
+    return _create_resnetv2('resnetv2_34', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
+@register_model
+def resnetv2_34d(pretrained=False, **kwargs) -> ResNetV2:
+    model_args = dict(
+        layers=(3, 4, 6, 3), channels=(64, 128, 256, 512), basic=True, bottle_ratio=1.0,
+        conv_layer=create_conv2d, norm_layer=BatchNormAct2d, stem_type='deep', avg_down=True
+    )
+    return _create_resnetv2('resnetv2_34d', pretrained=pretrained, **dict(model_args, **kwargs))
+
+
@register_model
def resnetv2_50(pretrained=False, **kwargs) -> ResNetV2:
    model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
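
A minimal sanity check for the new entrypoints, assuming the diff above is applied to timm's resnetv2.py. It uses timm's standard create_model factory; the expected output shape assumes the default 1000-class head:

    import timm
    import torch

    # build the new basic-block pre-activation model (no pretrained weights published yet)
    model = timm.create_model('resnetv2_18', pretrained=False)
    model.eval()

    # a single 224x224 RGB image -> 1000-class logits
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])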