45
45
46
46
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD , OPENAI_CLIP_MEAN , OPENAI_CLIP_STD
47
47
from timm .layers import trunc_normal_ , AvgPool2dSame , DropPath , Mlp , GlobalResponseNormMlp , \
48
- LayerNorm2d , LayerNorm , create_conv2d , get_act_layer , make_divisible , to_ntuple
48
+ LayerNorm2d , LayerNorm , RmsNorm2d , RmsNorm , create_conv2d , get_act_layer , get_norm_layer , make_divisible , to_ntuple
49
49
from timm .layers import NormMlpClassifierHead , ClassifierHead
50
50
from ._builder import build_model_with_cfg
51
51
from ._features import feature_take_indices
@@ -289,24 +289,27 @@ def __init__(
289
289
super ().__init__ ()
290
290
assert output_stride in (8 , 16 , 32 )
291
291
kernel_sizes = to_ntuple (4 )(kernel_sizes )
292
- if norm_layer is None :
293
- norm_layer = LayerNorm2d
294
- norm_layer_cl = norm_layer if conv_mlp else LayerNorm
292
+ use_rms = isinstance (norm_layer , str ) and norm_layer .startswith ('rmsnorm' )
293
+ if norm_layer is None or use_rms :
294
+ norm_layer = RmsNorm2d if use_rms else LayerNorm2d
295
+ norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm )
295
296
if norm_eps is not None :
296
297
norm_layer = partial (norm_layer , eps = norm_eps )
297
298
norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
298
299
else :
299
300
assert conv_mlp ,\
300
301
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
302
+ norm_layer = get_norm_layer (norm_layer )
301
303
norm_layer_cl = norm_layer
302
304
if norm_eps is not None :
303
305
norm_layer_cl = partial (norm_layer_cl , eps = norm_eps )
306
+ act_layer = get_act_layer (act_layer )
304
307
305
308
self .num_classes = num_classes
306
309
self .drop_rate = drop_rate
307
310
self .feature_info = []
308
311
309
- assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' )
312
+ assert stem_type in ('patch' , 'overlap' , 'overlap_tiered' , 'overlap_act' )
310
313
if stem_type == 'patch' :
311
314
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
312
315
self .stem = nn .Sequential (
@@ -316,11 +319,12 @@ def __init__(
316
319
stem_stride = patch_size
317
320
else :
318
321
mid_chs = make_divisible (dims [0 ] // 2 ) if 'tiered' in stem_type else dims [0 ]
319
- self .stem = nn .Sequential (
322
+ self .stem = nn .Sequential (* filter ( None , [
320
323
nn .Conv2d (in_chans , mid_chs , kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
324
+ act_layer () if 'act' in stem_type else None ,
321
325
nn .Conv2d (mid_chs , dims [0 ], kernel_size = 3 , stride = 2 , padding = 1 , bias = conv_bias ),
322
326
norm_layer (dims [0 ]),
323
- )
327
+ ]) )
324
328
stem_stride = 4
325
329
326
330
self .stages = nn .Sequential ()
@@ -592,6 +596,13 @@ def _cfgv2(url='', **kwargs):
592
596
hf_hub_id = 'timm/' ,
593
597
crop_pct = 0.95 , test_input_size = (3 , 288 , 288 ), test_crop_pct = 1.0 ),
594
598
599
+ 'convnext_zepto_rms.ra4_e3600_r224_in1k' : _cfg (
600
+ hf_hub_id = 'timm/' ,
601
+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 )),
602
+ 'convnext_zepto_rms_ols.untrained' : _cfg (
603
+ # hf_hub_id='timm/',
604
+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
605
+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
595
606
'convnext_atto.d2_in1k' : _cfg (
596
607
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth' ,
597
608
hf_hub_id = 'timm/' ,
@@ -600,6 +611,9 @@ def _cfgv2(url='', **kwargs):
600
611
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth' ,
601
612
hf_hub_id = 'timm/' ,
602
613
test_input_size = (3 , 288 , 288 ), test_crop_pct = 0.95 ),
614
+ 'convnext_atto_rms.untrained' : _cfg (
615
+ #hf_hub_id='timm/',
616
+ test_input_size = (3 , 256 , 256 ), test_crop_pct = 0.95 ),
603
617
'convnext_femto.d1_in1k' : _cfg (
604
618
url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth' ,
605
619
hf_hub_id = 'timm/' ,
@@ -968,6 +982,23 @@ def _cfgv2(url='', **kwargs):
968
982
})
969
983
970
984
985
+ @register_model
986
+ def convnext_zepto_rms (pretrained = False , ** kwargs ) -> ConvNeXt :
987
+ # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
988
+ model_args = dict (depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'rmsnorm2d' )
989
+ model = _create_convnext ('convnext_zepto_rms' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
990
+ return model
991
+
992
+
993
+ @register_model
994
+ def convnext_zepto_rms_ols (pretrained = False , ** kwargs ) -> ConvNeXt :
995
+ # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
996
+ model_args = dict (
997
+ depths = (2 , 2 , 4 , 2 ), dims = (32 , 64 , 128 , 256 ), conv_mlp = True , norm_layer = 'rmsnorm2d' , stem_type = 'overlap_act' )
998
+ model = _create_convnext ('convnext_zepto_rms_ols' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
999
+ return model
1000
+
1001
+
971
1002
@register_model
972
1003
def convnext_atto (pretrained = False , ** kwargs ) -> ConvNeXt :
973
1004
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1015,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
984
1015
return model
985
1016
986
1017
1018
+ @register_model
1019
+ def convnext_atto_rms (pretrained = False , ** kwargs ) -> ConvNeXt :
1020
+ # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
1021
+ model_args = dict (depths = (2 , 2 , 6 , 2 ), dims = (40 , 80 , 160 , 320 ), conv_mlp = True , norm_layer = 'rmsnorm2d' )
1022
+ model = _create_convnext ('convnext_atto_rms' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
1023
+ return model
1024
+
1025
+
987
1026
@register_model
988
1027
def convnext_femto (pretrained = False , ** kwargs ) -> ConvNeXt :
989
1028
# timm femto variant
0 commit comments