Skip to content

Commit 99486d3

Browse files
[ADD/DEL] log rule attn
1. Change the return value of sharedgroupgru. 2. Delete relational_memory. 3. Log rule_attn_probs in altscoff/rim. 4. Log rule_attn_probs in test and logging. 5. Change input_attn in logging to use get.
1 parent a831734 commit 99486d3

File tree

6 files changed

+53
-386
lines changed

6 files changed

+53
-386
lines changed

.vscode/launch.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
"--experiment_name",
6666
"SPRITES_SASBD_444",
6767
"--cfg_json",
68-
"configs/mmnist/rim_sasbd.json",
68+
"configs/mmnist/altscoff_sasbd.json",
6969
"--dataset_dir",
7070
"data",
7171
"--decode_hidden",

group_operations.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,8 @@ def forward(self, input, h):
454454
455455
Outputs:
456456
`hnext`: [N, num_hidden, single_hidden_size],
457-
`attn`: [N, num_OFs, n_templates] (num_bloccks==k==num_object_files)
457+
`attn_sm`: [N, num_hidden, num_rules] from softmax
458+
`attn_gsm`: [N, num_OFs, n_templates] (num_blocks==k==num_object_files) from gumbel_softmax
458459
"""
459460

460461
#self.blockify_params()
@@ -484,15 +485,16 @@ def forward(self, input, h):
484485
else:
485486
write_key = self.rule_embeddings # [1, num_rules, kdim]
486487

487-
att = torch.nn.functional.gumbel_softmax(torch.matmul(h_read, write_key.permute(0, 2, 1)), tau=0.5, hard=True) # Shape: [N*num_hidden, 1, num_rules]
488+
att_logits = torch.matmul(h_read, write_key.permute(0, 2, 1))
489+
att = torch.nn.functional.gumbel_softmax(att_logits, tau=0.5, hard=True) # Shape: [N*num_hidden, 1, num_rules]
488490

489491
#print('hnext shape before att', hnext.shape)
490492
hnext = torch.bmm(att, hnext) # [N*num_hidden, 1, num_rules], [N*num_hidden, num_rules, hidden_size] -> [N*num_hidden, 1, hidden_size]
491493
hnext = hnext.mean(dim=1) # [N*num_hidden, hidden_size]
492494
hnext = hnext.reshape((bs, self.num_hidden, self.hidden_size)) # [N, num_hidden, hidden_size]
493495
#print('shapes', hnext.shape, cnext.shape)
494496

495-
return hnext, att.data.reshape(bs,self.num_hidden,self.num_rules)
497+
return hnext, nn.Softmax(-1)(att_logits).data.reshape(bs,self.num_hidden,self.num_rules), att.data.reshape(bs,self.num_hidden,self.num_rules)
496498

497499
class SharedBlockLSTM(nn.Module):
498500
"""Dynamic sharing of parameters between blocks(RIM's)

0 commit comments

Comments
 (0)