Skip to content

Commit 42368fe

Browse files
author
Study-is-happy
committed
v141
1 parent 2e88459 commit 42368fe

13 files changed

+22897
-341
lines changed

LICENSE

Lines changed: 0 additions & 201 deletions
This file was deleted.

NeuFlow/backbone_v8.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

NeuFlow/neuflow.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22
import torch.nn.functional as F
33

4-
from NeuFlow import backbone_v8
4+
from NeuFlow import backbone_v7
55
from NeuFlow import transformer
66
from NeuFlow import matching
77
from NeuFlow import corr
@@ -17,7 +17,7 @@ class NeuFlow(torch.nn.Module):
1717
def __init__(self):
1818
super(NeuFlow, self).__init__()
1919

20-
self.backbone = backbone_v8.CNNEncoder(config.feature_dim_s16, config.context_dim_s16, config.feature_dim_s8, config.context_dim_s8)
20+
self.backbone = backbone_v7.CNNEncoder(config.feature_dim_s16, config.context_dim_s16, config.feature_dim_s8, config.context_dim_s8)
2121

2222
self.cross_attn_s16 = transformer.FeatureAttention(config.feature_dim_s16+config.context_dim_s16, num_layers=2, ffn=True, ffn_dim_expansion=1, post_norm=True)
2323

@@ -30,16 +30,18 @@ def __init__(self):
3030

3131
self.merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.feature_dim_s16 + config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
3232
torch.nn.GELU(),
33-
torch.nn.Conv2d(config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False))
33+
torch.nn.Conv2d(config.feature_dim_s8, config.feature_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
34+
torch.nn.BatchNorm2d(config.feature_dim_s8))
3435

3536
self.context_merge_s8 = torch.nn.Sequential(torch.nn.Conv2d(config.context_dim_s16 + config.context_dim_s8, config.context_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
3637
torch.nn.GELU(),
37-
torch.nn.Conv2d(config.context_dim_s8, config.context_dim_s8, kernel_size=3, stride=1, padding=1, bias=False))
38+
torch.nn.Conv2d(config.context_dim_s8, config.context_dim_s8, kernel_size=3, stride=1, padding=1, bias=False),
39+
torch.nn.BatchNorm2d(config.context_dim_s8))
3840

3941
self.refine_s16 = refine.Refine(config.context_dim_s16, config.iter_context_dim_s16, num_layers=5, levels=1, radius=4, inter_dim=128)
4042
self.refine_s8 = refine.Refine(config.context_dim_s8, config.iter_context_dim_s8, num_layers=5, levels=1, radius=4, inter_dim=96)
4143

42-
self.conv_s8 = backbone_v8.ConvBlock(3, config.feature_dim_s1, kernel_size=8, stride=8, padding=0)
44+
self.conv_s8 = backbone_v7.ConvBlock(3, config.feature_dim_s1, kernel_size=8, stride=8, padding=0)
4345
self.upsample_s8 = upsample.UpSample(config.feature_dim_s1, upsample_factor=8)
4446

4547
for p in self.parameters():
@@ -70,7 +72,7 @@ def split_features(self, features, context_dim, feature_dim):
7072

7173
return features, torch.relu(context)
7274

73-
def forward(self, img0, img1, iters_s16=3, iters_s8=7):
75+
def forward(self, img0, img1, iters_s16=2, iters_s8=7):
7476

7577
flow_list = []
7678

@@ -122,7 +124,6 @@ def forward(self, img0, img1, iters_s16=3, iters_s8=7):
122124

123125
context_s16 = F.interpolate(context_s16, scale_factor=2, mode='nearest')
124126

125-
context_s8 = torch.zeros_like(context_s8)
126127
context_s8 = self.context_merge_s8(torch.cat([context_s8, context_s16], dim=1))
127128

128129
iter_context_s8 = self.init_iter_context_s8

NeuFlow/refine.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ def __init__(self, context_dim, iter_context_dim, num_layers, levels, radius, in
2727

2828
self.conv3 = torch.nn.Conv2d(inter_dim, iter_context_dim+2, kernel_size=3, stride=1, padding=1, padding_mode='zeros', bias=True)
2929

30-
self.hidden_act = torch.nn.Tanh()
30+
# self.hidden_act = torch.nn.Tanh()
31+
self.hidden_act = torch.nn.Hardtanh(min_val=-4.0, max_val=4.0)
3132
# self.hidden_norm = torch.nn.BatchNorm2d(feature_dim)
3233

3334
def init_bhwd(self, batch_size, height, width, device, amp):

NeuFlow/transformer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, feature_dim, num_layers, ffn=True, ffn_dim_expansion=1, post_
6666
self.post_norm = post_norm
6767

6868
if self.post_norm:
69-
self.norm = torch.nn.LayerNorm(feature_dim)
69+
self.norm = torch.nn.BatchNorm2d(feature_dim)
7070

7171
def forward(self, concat_features0):
7272

@@ -79,11 +79,11 @@ def forward(self, concat_features0):
7979
concat_features0 = layer(concat_features0, concat_features1)
8080
concat_features1 = torch.cat(concat_features0.chunk(chunks=2, dim=0)[::-1], dim=0)
8181

82-
if self.post_norm:
83-
concat_features0 = self.norm(concat_features0)
84-
8582
# reshape back
8683
concat_features0 = concat_features0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W]
84+
85+
if self.post_norm:
86+
concat_features0 = self.norm(concat_features0)
8787

8888
return concat_features0
8989

0 commit comments

Comments
 (0)