
Commit 44f1a34

Merge pull request #2281 from huggingface/convnext_zepto
convnext zepto, rmsnorm experiments
2 parents e3242a5 + 545bd40

File tree

5 files changed, +89 -11 lines


timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same

timm/layers/create_norm.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 _NORM_MAP = dict(
@@ -22,6 +22,7 @@
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
     rmsnorm=RmsNorm,
+    rmsnorm2d=RmsNorm2d,
     frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
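
With rmsnorm2d registered in _NORM_MAP, the 2-d RMS norm resolves from its string name like any other norm. A minimal usage sketch, assuming the get_norm_layer / create_norm_layer helpers this module exposes through timm.layers:

import torch
from timm.layers import get_norm_layer, create_norm_layer

norm_cls = get_norm_layer('rmsnorm2d')  # resolves to RmsNorm2d via _NORM_MAP
norm = norm_cls(64)                     # 64 channels, eps defaults to 1e-6

# create_norm_layer resolves and instantiates in one step
norm = create_norm_layer('rmsnorm2d', 64)
x = torch.randn(2, 64, 56, 56)
assert norm(x).shape == x.shape         # NCHW in, NCHW out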

timm/layers/norm.py

Lines changed: 38 additions & 0 deletions
@@ -152,3 +152,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
         x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         return x
+
+
+class RmsNorm2d(nn.Module):
+    """ RmsNorm w/ fast (apex) norm if available
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC, norm runs over trailing channel dim
+        # NOTE fast norm fallback needs our rms norm impl, so both paths go through here.
+        # Since there is no built-in PyTorch impl, always use APEX RmsNorm if it is installed.
+        x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = x.permute(0, 3, 1, 2)  # NHWC -> NCHW
+        return x
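
For reference, the permute round-trip in RmsNorm2d.forward amounts to RMS-normalizing each spatial position's channel vector of an NCHW tensor. A pure-PyTorch sketch of the equivalent computation (illustrative only, not the fast_rms_norm path the commit actually calls):

import torch

def rms_norm2d_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: (N, C, H, W), weight: (C,)
    x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC
    x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)  # divide by RMS over C
    x = x * weight  # per-channel affine scale (no bias or mean-centering, unlike LayerNorm)
    return x.permute(0, 3, 1, 2)  # NHWC -> NCHW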

timm/models/convnext.py

Lines changed: 46 additions & 7 deletions
@@ -45,7 +45,7 @@
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
-    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
+    LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -289,24 +289,27 @@ def __init__(
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
-        if norm_layer is None:
-            norm_layer = LayerNorm2d
-            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+        use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
+        if norm_layer is None or use_rms:
+            norm_layer = RmsNorm2d if use_rms else LayerNorm2d
+            norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
             if norm_eps is not None:
                 norm_layer = partial(norm_layer, eps=norm_eps)
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
         else:
             assert conv_mlp,\
                 'If a norm_layer is specified, conv MLP must be used so all norms expect rank-4, channels-first input'
+            norm_layer = get_norm_layer(norm_layer)
             norm_layer_cl = norm_layer
             if norm_eps is not None:
                 norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        act_layer = get_act_layer(act_layer)
 
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []
 
-        assert stem_type in ('patch', 'overlap', 'overlap_tiered')
+        assert stem_type in ('patch', 'overlap', 'overlap_tiered', 'overlap_act')
         if stem_type == 'patch':
             # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
             self.stem = nn.Sequential(
@@ -316,11 +319,12 @@ def __init__(
             stem_stride = patch_size
         else:
             mid_chs = make_divisible(dims[0] // 2) if 'tiered' in stem_type else dims[0]
-            self.stem = nn.Sequential(
+            self.stem = nn.Sequential(*filter(None, [
                 nn.Conv2d(in_chans, mid_chs, kernel_size=3, stride=2, padding=1, bias=conv_bias),
+                act_layer() if 'act' in stem_type else None,
                 nn.Conv2d(mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias),
                 norm_layer(dims[0]),
-            )
+            ]))
             stem_stride = 4
 
         self.stages = nn.Sequential()
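
The new 'overlap_act' stem reuses the overlap stem but slips an activation between the two stride-2 convs; the filter(None, ...) idiom drops the optional module when the stem type doesn't ask for it. The same pattern in isolation (names here are mine, for illustration):

import torch.nn as nn

def overlap_stem(in_chans: int, out_chs: int, with_act: bool = False) -> nn.Sequential:
    # filter(None, [...]) removes None entries, so optional layers can be
    # listed inline without branching around the Sequential construction.
    return nn.Sequential(*filter(None, [
        nn.Conv2d(in_chans, out_chs, kernel_size=3, stride=2, padding=1),
        nn.GELU() if with_act else None,
        nn.Conv2d(out_chs, out_chs, kernel_size=3, stride=2, padding=1),
    ]))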
@@ -592,6 +596,13 @@ def _cfgv2(url='', **kwargs):
         hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
 
+    'convnext_zepto_rms.ra4_e3600_r224_in1k': _cfg(
+        hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
+    'convnext_zepto_rms_ols.untrained': _cfg(
+        # hf_hub_id='timm/',
+        mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_atto.d2_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
         hf_hub_id='timm/',
@@ -600,6 +611,9 @@ def _cfgv2(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
         hf_hub_id='timm/',
         test_input_size=(3, 288, 288), test_crop_pct=0.95),
+    'convnext_atto_rms.untrained': _cfg(
+        #hf_hub_id='timm/',
+        test_input_size=(3, 256, 256), test_crop_pct=0.95),
     'convnext_femto.d1_in1k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
         hf_hub_id='timm/',
@@ -968,6 +982,23 @@ def _cfgv2(url='', **kwargs):
 })
 
 
+@register_model
+def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant, a step smaller than atto, w/ RmsNorm2d in place of LayerNorm2d
+    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm zepto variant w/ overlapping stem + activation ('overlap_act')
+    model_args = dict(
+        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
+    model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_atto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
@@ -984,6 +1015,14 @@ def convnext_atto_ols(pretrained=False, **kwargs) -> ConvNeXt:
     return model
 
 
+@register_model
+def convnext_atto_rms(pretrained=False, **kwargs) -> ConvNeXt:
+    # timm atto variant w/ RmsNorm2d in place of LayerNorm2d
+    model_args = dict(depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), conv_mlp=True, norm_layer='rmsnorm2d')
+    model = _create_convnext('convnext_atto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def convnext_femto(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant
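
With the entrypoints registered above, the new variants build by name; only convnext_zepto_rms has pretrained weights at this commit (the other two cfgs are tagged untrained). A quick usage sketch:

import timm
import torch

model = timm.create_model('convnext_zepto_rms', pretrained=True).eval()
model_ols = timm.create_model('convnext_zepto_rms_ols')  # untrained at this commit

x = torch.randn(1, 3, 224, 224)
logits = model(x)  # (1, 1000) logits for the in1k head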

timm/models/vision_transformer.py

Lines changed: 2 additions & 2 deletions
@@ -2019,7 +2019,7 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
     'test_vit3.r160_in1k': _cfg(
-        #hf_hub_id='timm/',
+        hf_hub_id='timm/',
         input_size=(3, 160, 160), crop_pct=0.95),
 }
 
@@ -3238,7 +3238,7 @@ def test_vit3(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT Test
     """
     model_args = dict(
-        patch_size=16, embed_dim=96, depth=10, num_heads=3, mlp_ratio=2,
+        patch_size=16, embed_dim=96, depth=9, num_heads=3, mlp_ratio=2,
         class_token=False, reg_tokens=1, global_pool='map', init_values=1e-5)
     model = _create_vision_transformer('test_vit3', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
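
The depth change (10 -> 9) brings the test_vit3 definition in line with the r160_in1k weights whose hub cfg is enabled above. A quick sanity check of the registered config (illustrative):

import timm

model = timm.create_model('test_vit3', pretrained=False)
assert len(model.blocks) == 9  # transformer depth after this commit
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')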
