Commit 1609009

Merge pull request #1 from Ugenteraan/dynamic_routing

stop gradient flow during dynamic routing coefficients calculation

2 parents: 2aad682 + 92bba28
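The change detaches u_hat before the routing-coefficient updates so that gradients flow only through the final routing pass. As background, here is a minimal illustration (not from the repo) of how `.detach()` cuts a tensor out of the autograd graph:

```python
import torch

# Minimal illustration: .detach() returns a tensor with the same data that is
# cut out of the autograd graph, so no gradient flows back through it.
w = torch.tensor([2.0], requires_grad=True)
u = w * 3.0          # depends on w; gradients can flow back to w
u_det = u.detach()   # same values, but no gradient path back to w

loss = (u_det * 5.0).sum()
# Calling loss.backward() here would raise an error: nothing in this graph
# requires grad. Mixing detached and attached terms keeps only the attached path:
loss2 = (u_det * u).sum()
loss2.backward()
print(w.grad)        # tensor([18.]) -- gradient flows only through the non-detached u
```

In the commit's terms: multiplying by the detached u_hat inside the routing loop means the coefficient iterations contribute no gradients, which is the intended behavior for the intermediate iterations.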

File tree: 1 file changed

deepcaps.py: 2 additions & 1 deletion
@@ -713,13 +713,14 @@ def forward(self, x):
         x = x.unsqueeze(2).unsqueeze(dim=4)
 
         u_hat = torch.matmul(self.W, x).squeeze()  # u_hat -> [batch_size, 32, 10, 32]
+        u_hat_detached = u_hat.detach()  # detach u_hat to stop gradient flow during the calculation of the dynamic routing coefficients
 
         # b_ij = torch.zeros((batch_size, self.num_routes, self.num_capsules, 1))
         b_ij = x.new(x.shape[0], self.num_routes, self.num_capsules, 1).zero_()
 
         for itr in range(self.routing_iters):
             c_ij = func.softmax(b_ij, dim=2)
-            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + self.bias
+            s_j = (c_ij * u_hat_detached).sum(dim=1, keepdim=True) + self.bias  # use the detached u_hat on every iteration except the final one
             v_j = squash(s_j, dim=-1)
 
             if itr < self.routing_iters-1:
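For context, the sketch below shows a self-contained version of the routing loop this hunk modifies. The `squash` implementation, the agreement update for `b_ij`, and the final-iteration branch are assumptions based on standard dynamic routing; the repo's `self.bias` term is omitted, so treat this as an illustration rather than the exact code in deepcaps.py:

```python
import torch
import torch.nn.functional as F

def squash(s, dim=-1):
    # Standard capsule nonlinearity: shrinks short vectors toward 0 and
    # long vectors toward unit length, preserving direction.
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + 1e-8)

def dynamic_routing(u_hat, routing_iters=3):
    # u_hat: [batch, num_routes, num_capsules, dim] prediction vectors.
    # The routing coefficients are computed on a detached copy so the
    # intermediate iterations contribute no gradients; only the final
    # pass uses the original u_hat, as in the commit above.
    u_hat_detached = u_hat.detach()
    b_ij = u_hat.new_zeros(u_hat.shape[0], u_hat.shape[1], u_hat.shape[2], 1)

    for itr in range(routing_iters):
        c_ij = F.softmax(b_ij, dim=2)
        if itr < routing_iters - 1:
            # Intermediate iterations: detached u_hat, no gradient flow.
            s_j = (c_ij * u_hat_detached).sum(dim=1, keepdim=True)
            v_j = squash(s_j, dim=-1)
            # Agreement update: dot product between predictions and outputs.
            b_ij = b_ij + (u_hat_detached * v_j).sum(dim=-1, keepdim=True)
        else:
            # Final iteration: original u_hat so gradients reach the weights.
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)
            v_j = squash(s_j, dim=-1)
    return v_j.squeeze(1)

# Shapes mirror the comment in the diff: [batch_size, 32, 10, 32].
u_hat = torch.randn(4, 32, 10, 32, requires_grad=True)
v = dynamic_routing(u_hat)
v.norm(dim=-1).sum().backward()  # gradients flow only through the final pass
print(u_hat.grad.shape)          # torch.Size([4, 32, 10, 32])
```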
