@@ -354,6 +354,7 @@ def __init__(
354354 self .dim = dim
355355 self .max_res = max_res
356356 self .temperature = temperature
357+ self .linear_bands = linear_bands
357358 self .in_pixels = in_pixels
358359 self .feat_shape = feat_shape
359360 self .ref_feat_shape = ref_feat_shape
@@ -383,17 +384,7 @@ def __init__(
383384 self .pos_embed_cos = None
384385 else :
385386 # cache full sin/cos embeddings if shape provided up front
386- emb_sin , emb_cos = build_rotary_pos_embed (
387- feat_shape = feat_shape ,
388- dim = dim ,
389- max_res = max_res ,
390- linear_bands = linear_bands ,
391- in_pixels = in_pixels ,
392- ref_feat_shape = self .ref_feat_shape ,
393- grid_offset = self .grid_offset ,
394- grid_indexing = self .grid_indexing ,
395- temperature = self .temperature ,
396- )
387+ emb_sin , emb_cos = self ._get_pos_embed_values (feat_shape )
397388 self .bands = None
398389 self .register_buffer (
399390 'pos_embed_sin' ,
@@ -406,6 +397,30 @@ def __init__(
406397 persistent = False ,
407398 )
408399
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the rotary sin/cos embeddings for ``feat_shape``.

    Every setting other than the feature shape is read from attributes
    stored on the module, so this helper can be reused whenever the
    feature shape changes.

    Args:
        feat_shape: Feature shape forwarded to ``build_rotary_pos_embed``.

    Returns:
        A 2-tuple holding the two values produced by
        ``build_rotary_pos_embed`` (treated as sin, then cos).
    """
    sin_part, cos_part = build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
    return sin_part, cos_part
413+
def update_feat_shape(self, feat_shape: List[int]):
    """Recompute the cached sin/cos embeddings for a new feature shape.

    Nothing happens when no feature shape was set on the module
    (``self.feat_shape is None``) or when ``feat_shape`` equals the current
    one. Otherwise the embeddings are rebuilt, converted with ``.to`` to the
    device and dtype of the currently cached values, and stored together
    with the new shape.

    Args:
        feat_shape: The feature shape to switch to.
    """
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # no cache to refresh, or the cache already matches this shape
        return

    assert self.pos_embed_sin is not None
    assert self.pos_embed_cos is not None

    sin_new, cos_new = self._get_pos_embed_values(feat_shape)
    sin_old, cos_old = self.pos_embed_sin, self.pos_embed_cos
    self.pos_embed_sin = sin_new.to(sin_old.device, sin_old.dtype)
    self.pos_embed_cos = cos_new.to(cos_old.device, cos_old.dtype)
    self.feat_shape = feat_shape
423+
409424 def get_embed (self , shape : Optional [List [int ]] = None ):
410425 if shape is not None and self .bands is not None :
411426 # rebuild embeddings every call, use if target shape changes
@@ -453,6 +468,7 @@ def __init__(
453468 self .max_res = max_res
454469 self .temperature = temperature
455470 self .in_pixels = in_pixels
471+ self .linear_bands = linear_bands
456472 self .feat_shape = feat_shape
457473 self .ref_feat_shape = ref_feat_shape
458474 self .grid_offset = grid_offset
@@ -480,27 +496,40 @@ def __init__(
480496 self .pos_embed = None
481497 else :
482498 # cache full sin/cos embeddings if shape provided up front
483- embeds = build_rotary_pos_embed (
484- feat_shape = feat_shape ,
485- dim = dim ,
486- max_res = max_res ,
487- linear_bands = linear_bands ,
488- in_pixels = in_pixels ,
489- ref_feat_shape = self .ref_feat_shape ,
490- grid_offset = self .grid_offset ,
491- grid_indexing = self .grid_indexing ,
492- temperature = self .temperature ,
493- )
494499 self .bands = None
495500 self .register_buffer (
496501 'pos_embed' ,
497- torch . cat ( embeds , - 1 ),
502+ self . _get_pos_embed_values ( feat_shape = feat_shape ),
498503 persistent = False ,
499504 )
500505
def _get_pos_embed_values(self, feat_shape: List[int]):
    """Build the concatenated rotary embedding for ``feat_shape``.

    Every setting other than the feature shape is read from attributes
    stored on the module. The values returned by
    ``build_rotary_pos_embed`` are joined along the last dimension.

    Args:
        feat_shape: Feature shape forwarded to ``build_rotary_pos_embed``.

    Returns:
        The result of ``torch.cat`` over the built embeddings (last dim).
    """
    sin_cos_parts = build_rotary_pos_embed(
        feat_shape=feat_shape,
        dim=self.dim,
        max_res=self.max_res,
        temperature=self.temperature,
        linear_bands=self.linear_bands,
        in_pixels=self.in_pixels,
        ref_feat_shape=self.ref_feat_shape,
        grid_offset=self.grid_offset,
        grid_indexing=self.grid_indexing,
    )
    return torch.cat(sin_cos_parts, -1)
519+
def update_feat_shape(self, feat_shape: List[int]):
    """Recompute the cached concatenated embedding for a new feature shape.

    Nothing happens when no feature shape was set on the module
    (``self.feat_shape is None``) or when ``feat_shape`` equals the current
    one. Otherwise the embedding is rebuilt, converted with ``.to`` to the
    device and dtype of the currently cached value, and stored together
    with the new shape.

    Args:
        feat_shape: The feature shape to switch to.
    """
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # no cache to refresh, or the cache already matches this shape
        return

    assert self.pos_embed is not None

    rebuilt = self._get_pos_embed_values(feat_shape)
    cached = self.pos_embed
    self.pos_embed = rebuilt.to(device=cached.device, dtype=cached.dtype)
    self.feat_shape = feat_shape
529+
501530 def get_embed (self , shape : Optional [List [int ]] = None ):
502531 if shape is not None and self .bands is not None :
503- # rebuild embeddings every call, use if target shape changes
532+ # rebuild embeddings from cached bands every call, use if target shape changes
504533 embeds = build_rotary_pos_embed (
505534 shape ,
506535 self .bands ,
@@ -684,6 +713,7 @@ def __init__(
684713
685714 head_dim = dim // num_heads
686715 assert head_dim % 4 == 0 , f"head_dim must be divisible by 4, got { head_dim } "
716+
687717 freqs = init_random_2d_freqs (
688718 head_dim ,
689719 depth ,
@@ -692,18 +722,32 @@ def __init__(
692722 rotate = True ,
693723 ) # (2, depth, num_heads, head_dim//2)
694724 self .freqs = nn .Parameter (freqs )
725+
695726 if feat_shape is not None :
696727 # cache pre-computed grid
697- t_x , t_y = get_mixed_grid (
698- feat_shape ,
699- grid_indexing = grid_indexing ,
700- device = self .freqs .device
701- )
728+ t_x , t_y = self ._get_grid_values (feat_shape )
702729 self .register_buffer ('t_x' , t_x , persistent = False )
703730 self .register_buffer ('t_y' , t_y , persistent = False )
704731 else :
705732 self .t_x = self .t_y = None
706733
def _get_grid_values(self, feat_shape: Optional[List[int]]):
    """Build the two grid tensors for ``feat_shape``.

    Delegates to ``get_mixed_grid`` using the module's stored
    ``grid_indexing`` and the device of ``self.freqs``.

    Args:
        feat_shape: Feature shape forwarded to ``get_mixed_grid``.

    Returns:
        A 2-tuple with the two values produced by ``get_mixed_grid``.
    """
    grid_x, grid_y = get_mixed_grid(
        feat_shape,
        grid_indexing=self.grid_indexing,
        device=self.freqs.device,
    )
    return grid_x, grid_y
741+
def update_feat_shape(self, feat_shape: Optional[List[int]]):
    """Recompute the cached grid tensors for a new feature shape.

    Nothing happens when no feature shape was set on the module
    (``self.feat_shape is None``) or when ``feat_shape`` equals the current
    one. Otherwise both grids are rebuilt, converted with ``.to`` to the
    device and dtype of the currently cached values, and stored together
    with the new shape.

    Args:
        feat_shape: The feature shape to switch to.
    """
    # NOTE(review): feat_shape is annotated Optional, but a None value is
    # passed straight to the grid builder here -- confirm callers never pass
    # None once a feature shape has been set.
    if self.feat_shape is None or feat_shape == self.feat_shape:
        # no cache to refresh, or the cache already matches this shape
        return

    assert self.t_x is not None
    assert self.t_y is not None

    grid_x, grid_y = self._get_grid_values(feat_shape)
    prev_x, prev_y = self.t_x, self.t_y
    self.t_x = grid_x.to(prev_x.device, prev_x.dtype)
    self.t_y = grid_y.to(prev_y.device, prev_y.dtype)
    self.feat_shape = feat_shape
750+
707751 def get_embed (self , shape : Optional [List [int ]] = None ) -> torch .Tensor :
708752 """Generate rotary embeddings for the given spatial shape.
709753
0 commit comments