Skip to content

Commit 0a67435

Browse files
[ADD] calc acl and mcl
1. consistency_measure 2. metric 3. new test func 4. fixed val mmnist set 5. logging (should have no prob)
1 parent dde126d commit 0a67435

File tree

11 files changed

+143
-20
lines changed

11 files changed

+143
-20
lines changed

.vscode/launch.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,11 @@
7171
"--decode_hidden",
7272
"false",
7373
"--num_slots",
74-
"3"
74+
"3",
75+
"--task",
76+
"MMNIST",
77+
"--use_val_set",
78+
"false",
7579
// "--batch_size",
7680
// "8"
7781
],

argument_parser.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ def _to_int(foo):
9595
else: #['1', '2', '3']
9696
return [int(foo) for foo in foo]
9797
return _to_int(list__)
98-
print(mmnist_num_obj('1,2,3,4;1;0'))
9998

10099
def argument_parser():
101100
"""Function to parse all the arguments"""
@@ -118,7 +117,7 @@ def argument_parser():
118117
parser.add_argument('--ball_trainset', type=str2ballset, default=None, help='train set for ball task')
119118
parser.add_argument('--ball_testset', type=str2ballset, default=None, help='test set for ball task')
120119

121-
parser.add_argument('--mmnist_num_objects', '--num_objects', '--num_obj', type=mmnist_num_obj, default=[[2],[2],[1,2,3]],
120+
parser.add_argument('--mmnist_num_objects', '--num_objects', '--num_obj', type=mmnist_num_obj, default=[[2],[2],[2]],
122121
help='number of objects in the MMNIST task (train/test/val). default: 2;2;1,2,3')
123122

124123
# Training Settings
@@ -138,6 +137,7 @@ def argument_parser():
138137
parser.add_argument('--test_frequency', type=int, default=10,
139138
metavar="Frequency at which we log the intermediate variables of the model",
140139
help='Just type in a positive integer')
140+
parser.add_argument('--use_val_set', type=str2bool, default=False)
141141
parser.add_argument('--path_to_load_model', type=str, default="",
142142
metavar='Relative Path to load the model',
143143
help='Relative Path to load the model. If this is empty, no model'
@@ -297,7 +297,8 @@ def argument_parser():
297297
elif args.task == 'VOR':
298298
args.mot_gt_file = os.path.join(args.dataset_dir, 'gt_jsons', 'vor_test.json')
299299

300-
300+
if args.use_val_set == True and args.task != 'MMNIST':
301+
raise NotImplementedError
301302

302303
return args
303304

datasets/MovingMNIST.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ class MovingMNIST(data.Dataset):
4545
("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
4646
("mnist_test_seq.npy", "be083ec986bfe91a449d63653c411eb2"),
4747
]
48+
val_dataset = 'mmnist_val.pt'
4849
def __init__(self, root, train=True, n_frames_input=10, n_frames_output=10, num_objects=[2],
4950
static_prob=-1,
5051
download=False,
5152
transform=None,
52-
length=int(1e4),):
53+
length=int(1e4),
54+
val=False):
5355
'''
5456
Args:
5557
`root`: Root directory of the dataset (mnist dataset and moving mnist test set)
@@ -69,6 +71,10 @@ def __init__(self, root, train=True, n_frames_input=10, n_frames_output=10, num_
6971
super(MovingMNIST, self).__init__()
7072
self.root = root
7173
self.is_train = train
74+
if not self.is_train:
75+
self.is_val = val
76+
else:
77+
self.is_val = False
7278

7379
if download:
7480
self.download()
@@ -79,12 +85,22 @@ def __init__(self, root, train=True, n_frames_input=10, n_frames_output=10, num_
7985
self.dataset = None
8086
if train:
8187
self.mnist, self.mnist_label = load_mnist(root)
88+
elif self.is_val:
89+
if num_objects[0] != 2:
90+
self.mnist, self.mnist_label = load_mnist(root)
91+
else:
92+
self.dataset = torch.load(os.path.join(root, self.val_dataset))
8293
else:
8394
if num_objects[0] != 2:
8495
self.mnist, self.mnist_label = load_mnist(root)
8596
else:
8697
self.dataset = load_fixed_set(root, False)
87-
self.length = length if self.dataset is None else self.dataset.shape[1]
98+
if self.dataset is None:
99+
self.length = length
100+
elif self.is_val:
101+
self.length = len(self.dataset)
102+
else:
103+
self.length = self.dataset.shape[1]
88104

89105
self.num_objects = num_objects
90106
self.n_frames_input = n_frames_input
@@ -180,6 +196,9 @@ def __getitem__(self, idx):
180196
num_digits = random.choice(self.num_objects)
181197
# Generate data on the fly
182198
images, ind_images, labels = self.generate_moving_mnist(num_digits)
199+
elif self.is_val:
200+
labels, input, output, ind_images = *self.dataset[idx],
201+
return labels, input, output, ind_images
183202
else:
184203
images = self.dataset[:, idx, ...]
185204

datasets/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,11 +74,12 @@ def setup_dataloader(args):
7474
)
7575
val_set = MovingMNIST(
7676
root=args.dataset_dir,
77-
train=True,
77+
train=False,
7878
n_frames_input=10,
7979
n_frames_output=10,
8080
num_objects=args.mmnist_num_objects[2],# 1 2 3
81-
download=True
81+
download=True,
82+
val=True,
8283
)
8384
elif args.task == 'BBALL':
8485
train_set = BouncingBall(root=args.dataset_dir, train=True, length=20, filename=args.ball_trainset)
@@ -158,7 +159,7 @@ def setup_dataloader(args):
158159
dataset=val_set,
159160
batch_size=args.batch_size,
160161
shuffle=True,
161-
num_workers=4 if not DEBUG else 0,
162+
num_workers=0 if not DEBUG else 0,
162163
worker_init_fn=seed_worker,
163164
generator=g,
164165
)

gen_mmnist.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import torch
2+
import numpy as np
3+
from datasets import MovingMNIST
4+
import argparse
5+
from tqdm import tqdm
6+
7+
# ind_image = np.load('ind_images.npy')
8+
9+
val_set = MovingMNIST(
10+
root='data',
11+
train=True,
12+
n_frames_input=10,
13+
n_frames_output=10,
14+
num_objects=[2],# 1 2 3
15+
download=False,
16+
length=2000
17+
)
18+
19+
# list_labels, list_input, list_output, list_ind_images = [], [], [], []
20+
# for idx in tqdm(range(len(val_set))):
21+
# labels,input,output, ind_images = val_set[idx]
22+
# list_labels.append(labels)
23+
# list_input.append(input)
24+
# list_output.append(output)
25+
# list_ind_images.append(ind_images)
26+
27+
# tensors = torch.stack(list_labels), torch.stack(list_input), torch.stack(list_output), torch.stack(list_ind_images)
28+
# names = ['labels.npy', 'input.npy', 'output.npy', 'ind_images.npy']
29+
# for name, tensor in zip(names, tensors):
30+
# np.save(name, tensor.numpy())
31+
32+
dataset = []
33+
for idx in tqdm(range(len(val_set))):
34+
dataset.append(val_set[idx])
35+
36+
torch.save(dataset, 'mmnist_val.pt')
37+
...

networks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -742,7 +742,7 @@ def forward(self, x, h_prev, M_prev=None):
742742
if "SEP" in self.decoder_type:
743743
curr_dec_out_, curr_channels, curr_alpha_mask = self.decoder(encoded_input)
744744
next_dec_out_, next_channels, next_alpha_mask = self.decoder(pred_latent)
745-
if self.do_logging:
745+
if self.do_logging or True: # always log ind_output
746746
blocked_out_ = next_channels*next_alpha_mask
747747
self.hidden_features['individual_output'] = blocked_out_.detach()
748748
self.hidden_features['individual_recons'] = (curr_channels*curr_alpha_mask).detach()
@@ -752,7 +752,7 @@ def forward(self, x, h_prev, M_prev=None):
752752
else:
753753
if "SEP" in self.decoder_type:
754754
next_dec_out_, next_channels, next_alpha_mask = self.decoder(h_new)
755-
if self.do_logging:
755+
if self.do_logging or True: # always log ind_output
756756
blocked_out_ = next_channels*next_alpha_mask
757757
self.hidden_features['individual_output'] = blocked_out_.detach()
758758
else:

test_mmnist.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from utils.visualize import VecStack, make_grid_video, plot_heatmap, mplfig_to_video
1717
from utils.logging import log_stats, enable_logging, setup_wandb_columns
1818
from utils.metric import f1_score, gen_masks, get_mot_metrics
19+
from utils.metric import consistency_measure
1920
from tqdm import tqdm
2021
import wandb
2122
from utils import util
@@ -46,7 +47,7 @@ def get_grad_norm(model):
4647
return total_norm
4748

4849
# @torch.no_grad()
49-
def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_columns=None):
50+
def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_columns=None, calc_csty=False):
5051
'''test(model, test_loader, args, loss_fn, writer, rollout)'''
5152
start_time = time()
5253
# wandb table
@@ -90,6 +91,8 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
9091
ssim = 0.
9192
most_used_units = []
9293
pred_list = []
94+
epoch_avr_len = 0.
95+
epoch_max_len = 0.
9396
id_counter = 0
9497
for batch_idx, data in enumerate(tqdm(test_loader) if __name__ == "__main__" else test_loader): # tqdm doesn't work here?
9598
if args.task == 'MMNIST':
@@ -131,6 +134,7 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
131134
data.shape[3],
132135
data.shape[4])
133136
) # (BS, num_blocks, T, C, H, W)
137+
ind_pred = torch.empty((data.shape[0], args.num_hidden, data.shape[1]-rollout_start, data.shape[2], data.shape[3], data.shape[4]))
134138
reconstruction = []
135139
individual_recons = []
136140
soft_masks = [] # list of batches of masks
@@ -171,6 +175,8 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
171175
f1 += f1_frame
172176

173177
prediction[:, frame+1, :, :, :] = preds
178+
if frame >= rollout_start:
179+
ind_pred[:, :, frame-rollout_start, :, :, :] = model.hidden_features['individual_output']
174180
if do_logging:
175181
blocked_prediction[:, 0, frame+1, :, :, :] = preds # dim == 6
176182
blocked_prediction[:, 1:, frame+1, :, :, :] = model.hidden_features['individual_output']
@@ -255,6 +261,14 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
255261
soft_masks=torch.stack(soft_masks, dim=1).cpu(), # [BS, T, K, H, W]
256262
)
257263

264+
# calculate consistency
265+
avr_len, max_len = None, None
266+
if 'SEP' in args.decoder_type and calc_csty:
267+
avr_len, max_len = consistency_measure(ind_pred, ind_digits[:, :, rollout_start:, ...],
268+
corr_padding=(1,1), output_ids=False, reduction='mean', exclude_background=True)
269+
epoch_avr_len += avr_len
270+
epoch_max_len += max_len
271+
258272
if not rollout:
259273
ssim += pt_ssim.ssim(data[:,1:,:,:,:].reshape((-1,1,data.shape[3],data.shape[4])), # data.shape = (batch, frame, 1, height, width)
260274
prediction[:,1:,:,:,:].reshape((-1,1,data.shape[3],data.shape[4])))
@@ -284,6 +298,8 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
284298
epoch_recon_loss /= len(test_loader)
285299
epoch_pred_loss /= len(test_loader)
286300
epoch_mseloss = epoch_mseloss / (batch_idx+1)
301+
epoch_avr_len /= len(test_loader)
302+
epoch_max_len /= len(test_loader)
287303
ssim = ssim / (batch_idx+1)
288304
f1_avg = f1 / (batch_idx+1) / (data.shape[1]-1)
289305

@@ -336,6 +352,9 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
336352
metrics['rule_attn_probs_sm'] = torch.stack(rule_attn_probs_sm, dim=1) # Shape: [N, T, num_hidden, num_rules]
337353
if len(rule_attn_probs_gsm) > 0:
338354
metrics['rule_attn_probs_gsm'] = torch.stack(rule_attn_probs_gsm, dim=1) # Shape: [N, T, num_hidden, num_rules]
355+
if 'SEP' in args.decoder_type and calc_csty:
356+
metrics['avr_len'] = epoch_avr_len
357+
metrics['max_len'] = epoch_max_len
339358

340359
# slot attention
341360
if args.use_slot_attention:
@@ -346,6 +365,7 @@ def test(model, test_loader, args, loss_fn, writer, rollout=True, epoch=0, log_c
346365
print('test runtime:', time() - start_time)
347366
return epoch_loss, epoch_recon_loss, epoch_pred_loss, prediction, data, metrics, test_table
348367

368+
349369
@torch.no_grad()
350370
def dec_rim_util(model, h):
351371
"""check the contribution of the (num_module)-th RIM
@@ -425,13 +445,14 @@ def main():
425445
# call test function
426446
test_loss, recon_loss, pred_loss, prediction, data, metrics, test_table = test(
427447
model = model,
428-
test_loader = test_loader,
448+
test_loader = val_loader if args.use_val_set else test_loader,
429449
args = args,
430450
loss_fn = loss_fn,
431451
writer = writer,
432452
rollout = True,
433453
epoch = epoch,
434454
log_columns = columns,
455+
calc_csty = True if args.use_val_set else False,
435456
)
436457
log_stats(
437458
args=args,

test_mmnist_val.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from argument_parser import argument_parser
2+
from datasets import setup_dataloader
3+
4+
5+
args = argument_parser()
6+
7+
foo, val_loader, bar = setup_dataloader(args)
8+
9+
print(len(val_loader.dataset))
10+
t = next(iter(val_loader))
11+
print(t[3].shape)
12+
...

train_mmnist.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from utils.logging import log_stats, setup_wandb_columns
1919
from datasets import setup_dataloader
2020
from tqdm import tqdm
21-
from test_mmnist import dec_rim_util, test
21+
from test_mmnist import test
2222

2323
import os
2424
from os import listdir
@@ -123,7 +123,7 @@ def main():
123123
columns = setup_wandb_columns(args) # artifact columns
124124

125125
# data setup
126-
train_loader, _, test_loader = setup_dataloader(args=args)
126+
train_loader, val_loader, test_loader = setup_dataloader(args=args)
127127

128128
# model setup
129129
model, optimizer, scheduler, loss_fn, start_epoch, train_batch_idx, best_mse = setup_model(args=args)
@@ -160,13 +160,14 @@ def main():
160160
"""test model accuracy and log intermediate variables here"""
161161
test_loss, test_recon_loss, test_pred_loss, prediction, data, metrics, test_table = test(
162162
model = model,
163-
test_loader = test_loader,
163+
test_loader = val_loader if args.use_val_set else test_loader,
164164
args = args,
165165
loss_fn = loss_fn,
166166
writer = writer,
167167
rollout = True,
168168
epoch = epoch,
169169
log_columns=columns if epoch%50==0 else None,
170+
calc_csty = True if args.use_val_set else False
170171
)
171172
log_stats(
172173
args=args,

utils/consistensy_measure/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,39 @@ def consistency_measure(
88
target_seq: torch.Tensor,
99
corr_padding: tuple=(0, 0),
1010
output_ids: bool=True,
11+
reduction: str='none',
12+
exclude_background: bool=True,
1113
):
1214
"""
1315
input:
1416
`input_seq`: [N, K1, T, C, H, W]
1517
`target_seq`: [N, K2, T, C, H, W]
1618
`corr_padding`: (h-wise, w-wise) padding for correlation operation
19+
return:
20+
`avr_len`, `max_len`, (`IDs`)
1721
"""
1822
input_seq = input_seq.permute(2, 0, 1, 3, 4, 5) # [T, N, K1, C, H, W]
1923
target_seq = target_seq.permute(2, 0, 1, 3, 4, 5) # [T, N, K2, C, H, W]
2024
IDs = []
2125
for t in range(input_seq.shape[0]):
2226
corr_coef = normalized_corr(input_seq[t], target_seq[t], padding=corr_padding) # [N, K1, K2]
2327
_, indices = torch.max(corr_coef, dim=-1) # indices, shape [N, K1,].
28+
bg_flag = input_seq[t].sum(dim=(-1, -2, -3)) < 0.1 * target_seq[t].sum(dim=(-1,-2,-3)).mean(-1, keepdim=True) # [N, K1]
29+
indices[bg_flag] = target_seq.shape[1]+1 # extra ID for background
2430
IDs.append(indices)
2531
IDs = torch.stack(IDs, dim=-1) # shape [N, 3, T]
2632
avr_len = average_consistent_length(IDs)
2733
max_len = maximum_consistent_length(IDs)
2834

35+
if reduction == 'mean':
36+
avr_len = avr_len.mean()
37+
max_len = max_len.mean()
38+
elif reduction == 'sum':
39+
avr_len = avr_len.sum()
40+
max_len = max_len.sum()
41+
else:
42+
pass
43+
2944
if output_ids:
3045
return avr_len, max_len, IDs
3146
else:

0 commit comments

Comments
 (0)