Commit b2cd5aa
Merge pull request #2 from HopefulRational/revert-1-dynamic_routing
Revert "stop gradient flow during dynamic routing coefficients calculation"
2 parents 1609009 + 184d6ff commit b2cd5aa

File tree

1 file changed: +1 addition, −2 deletions

deepcaps.py

Lines changed: 1 addition & 2 deletions
@@ -713,14 +713,13 @@ def forward(self, x):
         x = x.unsqueeze(2).unsqueeze(dim=4)
 
         u_hat = torch.matmul(self.W, x).squeeze()  # u_hat -> [batch_size, 32, 10, 32]
-        u_hat_detached = u_hat.detach()  # detach the u_hat vector to stop the gradient flow during the calculation of the coefficients for dynamic routing.
 
         # b_ij = torch.zeros((batch_size, self.num_routes, self.num_capsules, 1))
         b_ij = x.new(x.shape[0], self.num_routes, self.num_capsules, 1).zero_()
 
         for itr in range(self.routing_iters):
             c_ij = func.softmax(b_ij, dim=2)
-            s_j = (c_ij * u_hat_detached).sum(dim=1, keepdim=True) + self.bias  # use detached u_hat during all the iterations except the final iteration.
+            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + self.bias
             v_j = squash(s_j, dim=-1)
 
             if itr < self.routing_iters-1:
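For context, here is a minimal PyTorch sketch contrasting the two routing variants this commit toggles between. The squash helper, the dynamic_routing signature, the bias argument, and the tensor shapes are illustrative assumptions inferred from the diff's comments, not code taken from the full deepcaps.py.

import torch
import torch.nn.functional as func

def squash(s, dim=-1, eps=1e-8):
    # Capsule squashing nonlinearity: shrinks short vectors toward zero
    # and scales long vectors to a norm just below one.
    sq_norm = (s ** 2).sum(dim=dim, keepdim=True)
    return (sq_norm / (1.0 + sq_norm)) * s / torch.sqrt(sq_norm + eps)

def dynamic_routing(u_hat, bias, routing_iters=3, detach_coefficients=False):
    # u_hat: [batch, num_routes, num_capsules, caps_dim] prediction vectors.
    batch, num_routes, num_capsules, _ = u_hat.shape
    b_ij = u_hat.new_zeros(batch, num_routes, num_capsules, 1)

    # detach_coefficients=True reproduces the reverted behavior: routing
    # logits are updated from a detached copy, so no gradient flows through
    # the coefficient calculation. False restores full backprop through
    # every iteration, which is what this commit reinstates.
    u_hat_routing = u_hat.detach() if detach_coefficients else u_hat

    for itr in range(routing_iters):
        c_ij = func.softmax(b_ij, dim=2)
        if itr < routing_iters - 1:
            s_j = (c_ij * u_hat_routing).sum(dim=1, keepdim=True) + bias
            v_j = squash(s_j, dim=-1)
            # Agreement: dot product between predictions and current outputs.
            b_ij = b_ij + (u_hat_routing * v_j).sum(dim=-1, keepdim=True)
        else:
            # The final iteration always uses the live u_hat so the layer
            # output stays differentiable w.r.t. the predictions.
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + bias
            v_j = squash(s_j, dim=-1)
    return v_j.squeeze(1)  # [batch, num_capsules, caps_dim]

Detaching u_hat makes the coefficients c_ij act as constants with respect to the loss, a trick some implementations use to stabilize routing; reverting it restores gradients through all routing iterations, matching the original deepcaps.py behavior.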
