1
1
import torch
2
2
import torch .nn .functional as F
3
3
4
- from NeuFlow import backbone_v8
4
+ from NeuFlow import backbone_v7
5
5
from NeuFlow import transformer
6
6
from NeuFlow import matching
7
7
from NeuFlow import corr
@@ -17,7 +17,7 @@ class NeuFlow(torch.nn.Module):
17
17
def __init__ (self ):
18
18
super (NeuFlow , self ).__init__ ()
19
19
20
- self .backbone = backbone_v8 .CNNEncoder (config .feature_dim_s16 , config .context_dim_s16 , config .feature_dim_s8 , config .context_dim_s8 )
20
+ self .backbone = backbone_v7 .CNNEncoder (config .feature_dim_s16 , config .context_dim_s16 , config .feature_dim_s8 , config .context_dim_s8 )
21
21
22
22
self .cross_attn_s16 = transformer .FeatureAttention (config .feature_dim_s16 + config .context_dim_s16 , num_layers = 2 , ffn = True , ffn_dim_expansion = 1 , post_norm = True )
23
23
@@ -30,16 +30,18 @@ def __init__(self):
30
30
31
31
self .merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .feature_dim_s16 + config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
32
32
torch .nn .GELU (),
33
- torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
33
+ torch .nn .Conv2d (config .feature_dim_s8 , config .feature_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
34
+ torch .nn .BatchNorm2d (config .feature_dim_s8 ))
34
35
35
36
self .context_merge_s8 = torch .nn .Sequential (torch .nn .Conv2d (config .context_dim_s16 + config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
36
37
torch .nn .GELU (),
37
- torch .nn .Conv2d (config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ))
38
+ torch .nn .Conv2d (config .context_dim_s8 , config .context_dim_s8 , kernel_size = 3 , stride = 1 , padding = 1 , bias = False ),
39
+ torch .nn .BatchNorm2d (config .context_dim_s8 ))
38
40
39
41
self .refine_s16 = refine .Refine (config .context_dim_s16 , config .iter_context_dim_s16 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 128 )
40
42
self .refine_s8 = refine .Refine (config .context_dim_s8 , config .iter_context_dim_s8 , num_layers = 5 , levels = 1 , radius = 4 , inter_dim = 96 )
41
43
42
- self .conv_s8 = backbone_v8 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
44
+ self .conv_s8 = backbone_v7 .ConvBlock (3 , config .feature_dim_s1 , kernel_size = 8 , stride = 8 , padding = 0 )
43
45
self .upsample_s8 = upsample .UpSample (config .feature_dim_s1 , upsample_factor = 8 )
44
46
45
47
for p in self .parameters ():
@@ -70,7 +72,7 @@ def split_features(self, features, context_dim, feature_dim):
70
72
71
73
return features , torch .relu (context )
72
74
73
- def forward (self , img0 , img1 , iters_s16 = 3 , iters_s8 = 7 ):
75
+ def forward (self , img0 , img1 , iters_s16 = 2 , iters_s8 = 7 ):
74
76
75
77
flow_list = []
76
78
@@ -122,7 +124,6 @@ def forward(self, img0, img1, iters_s16=3, iters_s8=7):
122
124
123
125
context_s16 = F .interpolate (context_s16 , scale_factor = 2 , mode = 'nearest' )
124
126
125
- context_s8 = torch .zeros_like (context_s8 )
126
127
context_s8 = self .context_merge_s8 (torch .cat ([context_s8 , context_s16 ], dim = 1 ))
127
128
128
129
iter_context_s8 = self .init_iter_context_s8
0 commit comments