Skip to content

Commit 280a90c

Browse files
authored
Support set_input_size() in EVA models (#2554)
* Support set_input_size() in EVA models
* Small fix for unused non-cat RotaryEmbedding module
1 parent 19f2bfb commit 280a90c

File tree

2 files changed

+102
-29
lines changed

2 files changed

+102
-29
lines changed

timm/layers/pos_embed_sincos.py

Lines changed: 73 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,7 @@ def __init__(
354354
self.dim = dim
355355
self.max_res = max_res
356356
self.temperature = temperature
357+
self.linear_bands = linear_bands
357358
self.in_pixels = in_pixels
358359
self.feat_shape = feat_shape
359360
self.ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
383384
self.pos_embed_cos = None
384385
else:
385386
# cache full sin/cos embeddings if shape provided up front
386-
emb_sin, emb_cos = build_rotary_pos_embed(
387-
feat_shape=feat_shape,
388-
dim=dim,
389-
max_res=max_res,
390-
linear_bands=linear_bands,
391-
in_pixels=in_pixels,
392-
ref_feat_shape=self.ref_feat_shape,
393-
grid_offset=self.grid_offset,
394-
grid_indexing=self.grid_indexing,
395-
temperature=self.temperature,
396-
)
387+
emb_sin, emb_cos = self._get_pos_embed_values(feat_shape)
397388
self.bands = None
398389
self.register_buffer(
399390
'pos_embed_sin',
@@ -406,6 +397,30 @@ def __init__(
406397
persistent=False,
407398
)
408399

400+
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the (sin, cos) rotary position embedding tensors for a feature grid.

    All band/grid configuration comes from attributes cached on the module at
    construction time; only the target grid shape varies per call.

    Args:
        feat_shape: Spatial shape (e.g. (H, W)) of the feature grid to embed.

    Returns:
        Tuple of (emb_sin, emb_cos) tensors.
    """
    # build_rotary_pos_embed already returns the (sin, cos) pair, pass it through
    return build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
413+
414+
def update_feat_shape(self, feat_shape: List[int]):
    """Recompute the cached sin/cos buffers for a new feature grid shape.

    No-op unless a shape was cached at construction and the new shape differs
    from the previously cached one.
    """
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # nothing cached up-front, or shape unchanged — leave buffers alone
        return
    assert self.pos_embed_sin is not None
    assert self.pos_embed_cos is not None
    new_sin, new_cos = self._get_pos_embed_values(feat_shape)
    # preserve the device/dtype of the existing buffers
    self.pos_embed_sin = new_sin.to(self.pos_embed_sin.device, self.pos_embed_sin.dtype)
    self.pos_embed_cos = new_cos.to(self.pos_embed_cos.device, self.pos_embed_cos.dtype)
    self.feat_shape = feat_shape
423+
409424
def get_embed(self, shape: Optional[List[int]] = None):
410425
if shape is not None and self.bands is not None:
411426
# rebuild embeddings every call, use if target shape changes
@@ -453,6 +468,7 @@ def __init__(
453468
self.max_res = max_res
454469
self.temperature = temperature
455470
self.in_pixels = in_pixels
471+
self.linear_bands = linear_bands
456472
self.feat_shape = feat_shape
457473
self.ref_feat_shape = ref_feat_shape
458474
self.grid_offset = grid_offset
@@ -480,27 +496,40 @@ def __init__(
480496
self.pos_embed = None
481497
else:
482498
# cache full sin/cos embeddings if shape provided up front
483-
embeds = build_rotary_pos_embed(
484-
feat_shape=feat_shape,
485-
dim=dim,
486-
max_res=max_res,
487-
linear_bands=linear_bands,
488-
in_pixels=in_pixels,
489-
ref_feat_shape=self.ref_feat_shape,
490-
grid_offset=self.grid_offset,
491-
grid_indexing=self.grid_indexing,
492-
temperature=self.temperature,
493-
)
494499
self.bands = None
495500
self.register_buffer(
496501
'pos_embed',
497-
torch.cat(embeds, -1),
502+
self._get_pos_embed_values(feat_shape=feat_shape),
498503
persistent=False,
499504
)
500505

506+
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the concatenated (sin ++ cos) rotary position embedding for a grid.

    Band/grid configuration is read from module attributes; only the target
    grid shape varies per call.

    Args:
        feat_shape: Spatial shape (e.g. (H, W)) of the feature grid to embed.

    Returns:
        Single tensor with sin and cos embeddings concatenated on the last dim.
    """
    sin_cos = build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
    # fuse the pair into one tensor so downstream use is a single buffer
    return torch.cat(sin_cos, dim=-1)
519+
520+
def update_feat_shape(self, feat_shape: List[int]):
    """Recompute the cached concatenated pos embed for a new feature grid shape.

    No-op unless a shape was cached at construction and the new shape differs
    from the previously cached one.
    """
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # nothing cached up-front, or shape unchanged — keep the current buffer
        return
    assert self.pos_embed is not None
    new_embed = self._get_pos_embed_values(feat_shape)
    # preserve the device/dtype of the existing buffer
    self.pos_embed = new_embed.to(device=self.pos_embed.device, dtype=self.pos_embed.dtype)
    self.feat_shape = feat_shape
529+
501530
def get_embed(self, shape: Optional[List[int]] = None):
502531
if shape is not None and self.bands is not None:
503-
# rebuild embeddings every call, use if target shape changes
532+
# rebuild embeddings from cached bands every call, use if target shape changes
504533
embeds = build_rotary_pos_embed(
505534
shape,
506535
self.bands,
@@ -684,6 +713,7 @@ def __init__(
684713

685714
head_dim = dim // num_heads
686715
assert head_dim % 4 == 0, f"head_dim must be divisible by 4, got {head_dim}"
716+
687717
freqs = init_random_2d_freqs(
688718
head_dim,
689719
depth,
@@ -692,18 +722,32 @@ def __init__(
692722
rotate=True,
693723
) # (2, depth, num_heads, head_dim//2)
694724
self.freqs = nn.Parameter(freqs)
725+
695726
if feat_shape is not None:
696727
# cache pre-computed grid
697-
t_x, t_y = get_mixed_grid(
698-
feat_shape,
699-
grid_indexing=grid_indexing,
700-
device=self.freqs.device
701-
)
728+
t_x, t_y = self._get_grid_values(feat_shape)
702729
self.register_buffer('t_x', t_x, persistent=False)
703730
self.register_buffer('t_y', t_y, persistent=False)
704731
else:
705732
self.t_x = self.t_y = None
706733

734+
def _get_grid_values(self, feat_shape: Optional[List[int]]):
    """Compute the (t_x, t_y) coordinate grids for the mixed-mode rotary embed.

    The grids are placed on the same device as the learned ``freqs`` parameter.

    Args:
        feat_shape: Spatial shape of the feature grid.

    Returns:
        Tuple of (t_x, t_y) grid tensors.
    """
    # get_mixed_grid returns the (t_x, t_y) pair directly
    return get_mixed_grid(
        feat_shape,
        grid_indexing=self.grid_indexing,
        device=self.freqs.device,
    )
741+
742+
def update_feat_shape(self, feat_shape: Optional[List[int]]):
    """Recompute the cached coordinate grids for a new feature grid shape.

    No-op unless a grid was cached at construction and the new shape differs
    from the previously cached one.
    """
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # nothing cached up-front, or shape unchanged — keep the current grids
        return
    assert self.t_x is not None
    assert self.t_y is not None
    grid_x, grid_y = self._get_grid_values(feat_shape)
    # preserve the device/dtype of the existing grid buffers
    self.t_x = grid_x.to(self.t_x.device, self.t_x.dtype)
    self.t_y = grid_y.to(self.t_y.device, self.t_y.dtype)
    self.feat_shape = feat_shape
750+
707751
def get_embed(self, shape: Optional[List[int]] = None) -> torch.Tensor:
708752
"""Generate rotary embeddings for the given spatial shape.
709753

timm/models/eva.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,35 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
723723
self.global_pool = global_pool
724724
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
725725

726+
def set_input_size(
        self,
        img_size: Optional[Tuple[int, int]] = None,
        patch_size: Optional[Tuple[int, int]] = None,
) -> None:
    """Update the input image resolution and patch size.

    Resamples the absolute position embedding (if present) to the new patch
    grid and refreshes any cached rotary embedding state.

    Args:
        img_size: New input resolution, if None current resolution is used.
        patch_size: New patch size, if None existing patch size is used.
    """
    old_grid_size = self.patch_embed.grid_size
    self.patch_embed.set_input_size(img_size=img_size, patch_size=patch_size)

    if self.pos_embed is not None:
        # prefix (class/register) tokens only occupy pos embed rows when
        # they are embedded jointly with the patch tokens
        prefix = 0 if self.no_embed_class else self.num_prefix_tokens
        target_len = self.patch_embed.num_patches + prefix
        if target_len != self.pos_embed.shape[1]:
            resampled = resample_abs_pos_embed(
                self.pos_embed,
                new_size=self.patch_embed.grid_size,
                old_size=old_grid_size,
                num_prefix_tokens=prefix,
                verbose=True,
            )
            self.pos_embed = nn.Parameter(resampled)

    if self.rope is not None:
        # rotary embeddings cache per-shape state; keep it in sync with the grid
        self.rope.update_feat_shape(self.patch_embed.grid_size)
754+
726755
def _pos_embed(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
727756
if self.dynamic_img_size:
728757
B, H, W, C = x.shape

0 commit comments

Comments
 (0)