Skip to content

Commit 766b4d3

Browse files
committed
Fix features for resnetv2_50t
1 parent e8045e7 commit 766b4d3

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

timm/models/resnetv2.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -291,14 +291,18 @@ def forward(self, x):
291291
return x
292292

293293

294+
def is_stem_deep(stem_type):
295+
return any([s in stem_type for s in ('deep', 'tiered')])
296+
297+
294298
def create_resnetv2_stem(
295299
in_chs, out_chs=64, stem_type='', preact=True,
296300
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
297301
stem = OrderedDict()
298302
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
299303

300304
# NOTE conv padding mode can be changed by overriding the conv_layer def
301-
if any([s in stem_type for s in ('deep', 'tiered')]):
305+
if is_stem_deep(stem_type):
302306
# A 3 deep 3x3 conv stack as in ResNet V1D models
303307
if 'tiered' in stem_type:
304308
stem_chs = (3 * out_chs // 8, out_chs // 2) # 'T' resnets in resnet.py
@@ -350,7 +354,7 @@ def __init__(
350354
stem_chs = make_div(stem_chs * wf)
351355
self.stem = create_resnetv2_stem(
352356
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
353-
stem_feat = ('stem.conv3' if 'deep' in stem_type else 'stem.conv') if preact else 'stem.norm'
357+
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
354358
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
355359

356360
prev_chs = stem_chs

0 commit comments

Comments
 (0)