@@ -36,15 +36,13 @@ def __init__(
         The total number of positional features will be max_freqs^2.
         """
         super().__init__()
-        self.dtype = dtype
+        self.dtype = dtype
         self.max_freqs = max_freqs
         self.hidden_size_input = hidden_size_input
 
         # A linear layer to project the concatenated input features and
         # positional encodings to the final output dimension.
-        self.embedder = nn.Sequential(
-            operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
-        )
+        self.embedder = operations.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype, device=device)
 
     @lru_cache(maxsize=4)
     def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
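For review context: the removed nn.Sequential held exactly one layer, so unwrapping it does not change what the module computes; the only observable difference is that the parameters are registered as embedder.weight / embedder.bias instead of embedder.0.weight / embedder.0.bias. Whether that key rename matters for existing checkpoints depends on how weights are loaded elsewhere in the repo, which this hunk does not show. A minimal sketch of both points, using plain torch.nn.Linear as a stand-in for operations.Linear:

# Minimal sketch: a single-layer nn.Sequential and a bare Linear compute the same thing,
# but register their parameters under different state_dict keys.
# torch.nn.Linear stands in for the repo's operations.Linear here.
import torch
import torch.nn as nn

in_channels, max_freqs, hidden_size_input = 3, 8, 64
in_features = in_channels + max_freqs**2

wrapped = nn.Sequential(nn.Linear(in_features, hidden_size_input))
plain = nn.Linear(in_features, hidden_size_input)

# Copy the weights across so the two modules are numerically identical.
plain.load_state_dict({k.removeprefix("0."): v for k, v in wrapped.state_dict().items()})

x = torch.randn(2, in_features)
assert torch.equal(wrapped(x), plain(x))
print(list(wrapped.state_dict()))  # ['0.weight', '0.bias']
print(list(plain.state_dict()))    # ['weight', 'bias']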
@@ -101,7 +99,7 @@ def fetch_pos(self, patch_size: int, device: torch.device, dtype: torch.dtype) -
 
         return dct
 
-    def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Tensor:
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         """
         Forward pass for the embedder.
 
@@ -117,16 +115,11 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Te
         # Infer the patch side length from the number of pixels (P^2).
         patch_size = int(P2**0.5)
 
-        # Possibly run the operation with a different dtype.
         input_dtype = inputs.dtype
-        if embedder_dtype != input_dtype or self.dtype != input_dtype:
-            embedder = self.embedder.to(dtype=embedder_dtype)
-            inputs = inputs.to(dtype=embedder_dtype)
-        else:
-            embedder = self.embedder
+        inputs = inputs.to(dtype=self.dtype)
 
         # Fetch the pre-computed or cached positional embeddings.
-        dct = self.fetch_pos(patch_size, inputs.device, embedder_dtype)
+        dct = self.fetch_pos(patch_size, inputs.device, self.dtype)
 
         # Repeat the positional embeddings for each item in the batch.
         dct = dct.repeat(B, 1, 1)
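A side effect worth noting for this hunk: fetch_pos is memoized with @lru_cache(maxsize=4), keyed on (patch_size, device, dtype) (plus self). Passing the fixed self.dtype keeps that key stable across calls, while the old per-call embedder_dtype could create one cache entry per dtype it was called with. A standalone sketch of that caching behaviour, with a stand-in table instead of the real DCT construction:

# Standalone sketch of the memoization pattern; the zeros table stands in for the real DCT features.
from functools import lru_cache

import torch


@lru_cache(maxsize=4)
def fetch_pos(patch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    print(f"building table: patch_size={patch_size}, dtype={dtype}")
    return torch.zeros(patch_size * patch_size, 16, device=device, dtype=dtype)


dev = torch.device("cpu")

fetch_pos(16, dev, torch.float32)   # miss: table is built
fetch_pos(16, dev, torch.float32)   # hit: a fixed dtype reuses the cached table
fetch_pos(16, dev, torch.bfloat16)  # miss: a different dtype is a different cache key
print(fetch_pos.cache_info())       # CacheInfo(hits=1, misses=2, maxsize=4, currsize=2)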
@@ -136,10 +129,7 @@ def forward(self, inputs: torch.Tensor, embedder_dtype: torch.dtype) -> torch.Te
         inputs = torch.cat((inputs, dct), dim=-1)
 
         # Project the combined tensor to the target hidden size.
-        inputs = embedder(inputs)
-
-        # No-op if already the same dtype.
-        return inputs.to(dtype=input_dtype)
+        return self.embedder(inputs).to(dtype=input_dtype)
 
 
 class NerfGLUBlock(nn.Module):
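Putting the forward-path changes together: cast the inputs once to the module's own dtype, concatenate them with the batch-repeated positional table, project through the single Linear, and cast back to the caller's dtype on return. Below is a condensed, self-contained sketch of that flow; the shapes, the random stand-in for the DCT table, and the use of plain nn.Linear instead of operations.Linear are assumptions for illustration, not code from the repo.

# Condensed sketch of the simplified forward path (illustrative shapes; random positional table).
import torch
import torch.nn as nn


class TinyNerfEmbedder(nn.Module):
    def __init__(self, in_channels: int, hidden_size_input: int, max_freqs: int, dtype=torch.bfloat16):
        super().__init__()
        self.dtype = dtype
        self.max_freqs = max_freqs
        self.embedder = nn.Linear(in_channels + max_freqs**2, hidden_size_input, dtype=dtype)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Assumed input layout: (B, P^2, C) -- a batch of flattened patches.
        B, P2, C = inputs.shape
        input_dtype = inputs.dtype
        inputs = inputs.to(dtype=self.dtype)  # single cast to the module's own dtype

        # Stand-in for the cached DCT table from fetch_pos: (P^2, max_freqs^2).
        dct = torch.randn(P2, self.max_freqs**2, device=inputs.device, dtype=self.dtype)
        dct = dct.repeat(B, 1, 1)  # -> (B, P^2, max_freqs^2)

        inputs = torch.cat((inputs, dct), dim=-1)  # -> (B, P^2, C + max_freqs^2)
        return self.embedder(inputs).to(dtype=input_dtype)  # project, then cast back


emb = TinyNerfEmbedder(in_channels=3, hidden_size_input=64, max_freqs=8)
x = torch.randn(2, 16 * 16, 3)   # float32 inputs, bfloat16 module
y = emb(x)
print(y.shape, y.dtype)          # torch.Size([2, 256, 64]) torch.float32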