@@ -136,7 +136,7 @@ def __init__(self, in_dim, out_dim, dropout=0.0):
136136 if in_dim != out_dim else nn .Identity ())
137137
138138 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
139- h = self . shortcut ( x )
139+ old_x = x
140140 for layer in self .residual :
141141 if isinstance (layer , CausalConv3d ) and feat_cache is not None :
142142 idx = feat_idx [0 ]
@@ -156,7 +156,7 @@ def forward(self, x, feat_cache=None, feat_idx=[0]):
156156 feat_idx [0 ] += 1
157157 else :
158158 x = layer (x )
159- return x + h
159+ return x + self . shortcut ( old_x )
160160
161161
162162def patchify (x , patch_size ):
@@ -327,7 +327,7 @@ def __init__(self,
327327 self .downsamples = nn .Sequential (* downsamples )
328328
329329 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
330- x_copy = x . clone ()
330+ x_copy = x
331331 for module in self .downsamples :
332332 x = module (x , feat_cache , feat_idx )
333333
@@ -369,7 +369,7 @@ def __init__(self,
369369 self .upsamples = nn .Sequential (* upsamples )
370370
371371 def forward (self , x , feat_cache = None , feat_idx = [0 ], first_chunk = False ):
372- x_main = x . clone ()
372+ x_main = x
373373 for module in self .upsamples :
374374 x_main = module (x_main , feat_cache , feat_idx )
375375 if self .avg_shortcut is not None :
0 commit comments