@@ -291,14 +291,18 @@ def forward(self, x):
291
291
return x
292
292
293
293
294
+ def is_stem_deep (stem_type ):
295
+ return any ([s in stem_type for s in ('deep' , 'tiered' )])
296
+
297
+
294
298
def create_resnetv2_stem (
295
299
in_chs , out_chs = 64 , stem_type = '' , preact = True ,
296
300
conv_layer = StdConv2d , norm_layer = partial (GroupNormAct , num_groups = 32 )):
297
301
stem = OrderedDict ()
298
302
assert stem_type in ('' , 'fixed' , 'same' , 'deep' , 'deep_fixed' , 'deep_same' , 'tiered' )
299
303
300
304
# NOTE conv padding mode can be changed by overriding the conv_layer def
301
- if any ([ s in stem_type for s in ( 'deep' , 'tiered' )] ):
305
+ if is_stem_deep ( stem_type ):
302
306
# A 3 deep 3x3 conv stack as in ResNet V1D models
303
307
if 'tiered' in stem_type :
304
308
stem_chs = (3 * out_chs // 8 , out_chs // 2 ) # 'T' resnets in resnet.py
@@ -350,7 +354,7 @@ def __init__(
350
354
stem_chs = make_div (stem_chs * wf )
351
355
self .stem = create_resnetv2_stem (
352
356
in_chans , stem_chs , stem_type , preact , conv_layer = conv_layer , norm_layer = norm_layer )
353
- stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv' ) if preact else 'stem.norm'
357
+ stem_feat = ('stem.conv3' if is_stem_deep ( stem_type ) else 'stem.conv' ) if preact else 'stem.norm'
354
358
self .feature_info .append (dict (num_chs = stem_chs , reduction = 2 , module = stem_feat ))
355
359
356
360
prev_chs = stem_chs
0 commit comments