@@ -354,6 +354,7 @@ def __init__(
354
354
self .dim = dim
355
355
self .max_res = max_res
356
356
self .temperature = temperature
357
+ self .linear_bands = linear_bands
357
358
self .in_pixels = in_pixels
358
359
self .feat_shape = feat_shape
359
360
self .ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
383
384
self .pos_embed_cos = None
384
385
else :
385
386
# cache full sin/cos embeddings if shape provided up front
386
- emb_sin , emb_cos = build_rotary_pos_embed (
387
- feat_shape = feat_shape ,
388
- dim = dim ,
389
- max_res = max_res ,
390
- linear_bands = linear_bands ,
391
- in_pixels = in_pixels ,
392
- ref_feat_shape = self .ref_feat_shape ,
393
- grid_offset = self .grid_offset ,
394
- grid_indexing = self .grid_indexing ,
395
- temperature = self .temperature ,
396
- )
387
+ emb_sin , emb_cos = self ._get_pos_embed_values (feat_shape )
397
388
self .bands = None
398
389
self .register_buffer (
399
390
'pos_embed_sin' ,
@@ -406,6 +397,30 @@ def __init__(
406
397
persistent = False ,
407
398
)
408
399
400
+ def _get_pos_embed_values (self , feat_shape : List [int ]):
401
+ emb_sin , emb_cos = build_rotary_pos_embed (
402
+ feat_shape = feat_shape ,
403
+ dim = self .dim ,
404
+ max_res = self .max_res ,
405
+ temperature = self .temperature ,
406
+ linear_bands = self .linear_bands ,
407
+ in_pixels = self .in_pixels ,
408
+ ref_feat_shape = self .ref_feat_shape ,
409
+ grid_offset = self .grid_offset ,
410
+ grid_indexing = self .grid_indexing ,
411
+ )
412
+ return emb_sin , emb_cos
413
+
414
+ def update_feat_shape (self , feat_shape : List [int ]):
415
+ if self .feat_shape is not None and feat_shape != self .feat_shape :
416
+ # only update if feat_shape was set and different from previous value
417
+ assert self .pos_embed_sin is not None
418
+ assert self .pos_embed_cos is not None
419
+ emb_sin , emb_cos = self ._get_pos_embed_values (feat_shape )
420
+ self .pos_embed_sin = emb_sin .to (self .pos_embed_sin .device , self .pos_embed_sin .dtype )
421
+ self .pos_embed_cos = emb_cos .to (self .pos_embed_cos .device , self .pos_embed_cos .dtype )
422
+ self .feat_shape = feat_shape
423
+
409
424
def get_embed (self , shape : Optional [List [int ]] = None ):
410
425
if shape is not None and self .bands is not None :
411
426
# rebuild embeddings every call, use if target shape changes
@@ -453,6 +468,7 @@ def __init__(
453
468
self .max_res = max_res
454
469
self .temperature = temperature
455
470
self .in_pixels = in_pixels
471
+ self .linear_bands = linear_bands
456
472
self .feat_shape = feat_shape
457
473
self .ref_feat_shape = ref_feat_shape
458
474
self .grid_offset = grid_offset
@@ -480,27 +496,40 @@ def __init__(
480
496
self .pos_embed = None
481
497
else :
482
498
# cache full sin/cos embeddings if shape provided up front
483
- embeds = build_rotary_pos_embed (
484
- feat_shape = feat_shape ,
485
- dim = dim ,
486
- max_res = max_res ,
487
- linear_bands = linear_bands ,
488
- in_pixels = in_pixels ,
489
- ref_feat_shape = self .ref_feat_shape ,
490
- grid_offset = self .grid_offset ,
491
- grid_indexing = self .grid_indexing ,
492
- temperature = self .temperature ,
493
- )
494
499
self .bands = None
495
500
self .register_buffer (
496
501
'pos_embed' ,
497
- torch . cat ( embeds , - 1 ),
502
+ self . _get_pos_embed_values ( feat_shape = feat_shape ),
498
503
persistent = False ,
499
504
)
500
505
506
+ def _get_pos_embed_values (self , feat_shape : List [int ]):
507
+ embeds = build_rotary_pos_embed (
508
+ feat_shape = feat_shape ,
509
+ dim = self .dim ,
510
+ max_res = self .max_res ,
511
+ temperature = self .temperature ,
512
+ linear_bands = self .linear_bands ,
513
+ in_pixels = self .in_pixels ,
514
+ ref_feat_shape = self .ref_feat_shape ,
515
+ grid_offset = self .grid_offset ,
516
+ grid_indexing = self .grid_indexing ,
517
+ )
518
+ return torch .cat (embeds , - 1 )
519
+
520
+ def update_feat_shape (self , feat_shape : List [int ]):
521
+ if self .feat_shape is not None and feat_shape != self .feat_shape :
522
+ # only update if feat_shape was set and different from previous value
523
+ assert self .pos_embed is not None
524
+ self .pos_embed = self ._get_pos_embed_values (feat_shape ).to (
525
+ device = self .pos_embed .device ,
526
+ dtype = self .pos_embed .dtype ,
527
+ )
528
+ self .feat_shape = feat_shape
529
+
501
530
def get_embed (self , shape : Optional [List [int ]] = None ):
502
531
if shape is not None and self .bands is not None :
503
- # rebuild embeddings every call, use if target shape changes
532
+ # rebuild embeddings from cached bands every call, use if target shape changes
504
533
embeds = build_rotary_pos_embed (
505
534
shape ,
506
535
self .bands ,
@@ -684,6 +713,7 @@ def __init__(
684
713
685
714
head_dim = dim // num_heads
686
715
assert head_dim % 4 == 0 , f"head_dim must be divisible by 4, got { head_dim } "
716
+
687
717
freqs = init_random_2d_freqs (
688
718
head_dim ,
689
719
depth ,
@@ -692,18 +722,32 @@ def __init__(
692
722
rotate = True ,
693
723
) # (2, depth, num_heads, head_dim//2)
694
724
self .freqs = nn .Parameter (freqs )
725
+
695
726
if feat_shape is not None :
696
727
# cache pre-computed grid
697
- t_x , t_y = get_mixed_grid (
698
- feat_shape ,
699
- grid_indexing = grid_indexing ,
700
- device = self .freqs .device
701
- )
728
+ t_x , t_y = self ._get_grid_values (feat_shape )
702
729
self .register_buffer ('t_x' , t_x , persistent = False )
703
730
self .register_buffer ('t_y' , t_y , persistent = False )
704
731
else :
705
732
self .t_x = self .t_y = None
706
733
734
+ def _get_grid_values (self , feat_shape : Optional [List [int ]]):
735
+ t_x , t_y = get_mixed_grid (
736
+ feat_shape ,
737
+ grid_indexing = self .grid_indexing ,
738
+ device = self .freqs .device
739
+ )
740
+ return t_x , t_y
741
+
742
+ def update_feat_shape (self , feat_shape : Optional [List [int ]]):
743
+ if self .feat_shape is not None and feat_shape != self .feat_shape :
744
+ assert self .t_x is not None
745
+ assert self .t_y is not None
746
+ t_x , t_y = self ._get_grid_values (feat_shape )
747
+ self .t_x = t_x .to (self .t_x .device , self .t_x .dtype )
748
+ self .t_y = t_y .to (self .t_y .device , self .t_y .dtype )
749
+ self .feat_shape = feat_shape
750
+
707
751
def get_embed (self , shape : Optional [List [int ]] = None ) -> torch .Tensor :
708
752
"""Generate rotary embeddings for the given spatial shape.
709
753
0 commit comments