import torch
import torch.nn as nn


def do_nothing(x):
    return x


def bipartite_soft_matching_x_w(x, w, r, scaling_factors):
    # We can only reduce the channel count by at most 50%.
    # x shape: [B * N, C]; w shape: [C_out, C]
    b = x.shape[0]
    t = x.shape[1]
    r = min(r, t // 2)

    if r <= 0:
        # Nothing to assemble: return a three-tuple so callers can always
        # unpack src_idx, dst_idx, scores.
        return None, None, None
    with torch.no_grad():
        # Split the channels into two alternating halves.
        # xa, xb shape: [B * N, C/2]
        xa, xb = x[..., ::2], x[..., 1::2]
        # wa, wb shape: [C_out, C/2]
        wa, wb = w[..., ::2], w[..., 1::2]
        xa_c, xb_c = xa.shape[1], xb.shape[1]

        # Weighted matching cost between every "a" and "b" channel:
        # score_ij = w_i * (y_i - y_j) / 2 + w_j * (y_j - y_i) / 2
        # Signed one-shot alternative (kept for reference):
        # xdist = (xa.t().reshape(xa_c, b, 1) - xb.reshape(1, b, xb_c)).sum(1)

        # xdist shape: [C/2, C/2]
        xdist = torch.cdist(xa.t(), xb.t(), p=2.0)
        scores_a = torch.zeros(xa_c, xb_c, device=x.device)
        scores_b = torch.zeros(xa_c, xb_c, device=x.device)
        for i in range(xa_c):  # rows index the "a" channels
            scores_a[i, :] = (wa[:, i].unsqueeze(1) * xdist[i]).sum(0)
        for j in range(xb_c):  # columns index the "b" channels
            scores_b[:, j] = (wb[:, j].unsqueeze(1) * xdist[:, j]).sum(0)
        scores = (scores_a + scores_b).pow(2)
        if scaling_factors is not None:
            # Channels whose scaling factor is not 1.0 were split earlier;
            # give them an infinite cost so they are never matched.
            split_mask = scaling_factors != 1.0
            split_index_a = split_mask[::2].nonzero().squeeze(1)
            split_index_b = split_mask[1::2].nonzero().squeeze(1)
            scores.index_fill_(
                dim=0, index=split_index_a, value=torch.finfo(scores.dtype).max
            )
            scores.index_fill_(
                dim=1, index=split_index_b, value=torch.finfo(scores.dtype).max
            )
        # Reference (slow) implementation of the same score:
        # scores = torch.zeros(t // 2, t // 2, device=x.device)
        # for i in range(t // 2):
        #     for j in range(t // 2):
        #         scores[i, j] = (
        #             (wa[..., i] * (xa[..., i] - xb[..., j]).sum())
        #             + (wb[..., j] * (xb[..., j] - xa[..., i]).sum())
        #         ).mean(0).pow(2)
        # node_min, node_idx shape: [C/2]; node_idx holds indices into "b".
        # Draw one edge from each channel in A to its lowest-cost match in B.
        node_min, node_idx = scores.min(dim=-1)
        # edge_idx shape: [C/2]; sort so the r lowest-cost edges come first.
        edge_idx = node_min.argsort(dim=-1, descending=False)

        # unm_idx = edge_idx[r:]  # unassembled channels, shape: [C/2 - r]
        src_idx = edge_idx[:r]  # assembled channels, shape: [r]
        dst_idx = node_idx[src_idx]
        return src_idx, dst_idx, scores[src_idx, dst_idx]

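# Hedged usage sketch (not part of the original commit; all sizes below are
# assumptions). It drives the matcher with random data to show the shapes:
# activations flattened to [B * N, C] plus the following layer's weight.
def _demo_matching():
    torch.manual_seed(0)
    x = torch.randn(8 * 16, 64)  # [B * N, C]
    w = torch.randn(128, 64)  # [C_out, C]
    src_idx, dst_idx, pair_scores = bipartite_soft_matching_x_w(
        x, w, r=8, scaling_factors=None
    )
    # src_idx indexes the even ("a") channels, dst_idx the odd ("b")
    # channels they merge into; pair_scores holds the r matching costs.
    print(src_idx.shape, dst_idx.shape, pair_scores.shape)  # [8], [8], [8]
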
def assembly(x, src_idx, dst_idx, r, mode="mean") -> torch.Tensor:
    # x shape: [B, N, C]
    B, N, C = x.shape

    # Recreate the even/odd channel split used during matching.
    ori_src_idx = torch.arange(0, C, 2, device=x.device)
    ori_dst_idx = torch.arange(1, C, 2, device=x.device)
    src, dst = x[..., ori_src_idx], x[..., ori_dst_idx]
    src_C = src.shape[-1]
    dst_C = dst.shape[-1]

    # channel_mask is set to 0 for channels that get assembled away.
    channel_mask = torch.ones(C, device=x.device, dtype=x.dtype)
    m_idx = ori_src_idx[src_idx]
    channel_mask[m_idx] = 0.0

    # Reduce the r selected source channels into their destination channels.
    n, t1, c = src.shape
    sub_src = src.gather(dim=-1, index=src_idx.expand(n, t1, r))
    dst = dst.scatter_reduce(-1, dst_idx.expand(n, t1, r), sub_src, reduce=mode)

    # Re-interleave the two halves, then drop the merged source channels.
    src = src.view(B, N, src_C, 1)
    dst = dst.view(B, N, dst_C, 1)
    if src_C == dst_C:
        assembled_x = torch.cat([src, dst], dim=-1).view(B, N, C)
    else:
        # Odd C: the extra trailing source channel has no partner.
        assembled_x = torch.cat([src[..., :-1, :], dst], dim=-1).view(
            B, N, src_C + dst_C - 1
        )
        assembled_x = torch.cat(
            [assembled_x, src[..., -1, :].reshape(B, N, 1)], dim=-1
        ).view(B, N, src_C + dst_C)
    assembled_x = assembled_x.index_select(
        -1, (channel_mask != 0).nonzero().squeeze(1)
    )
    return assembled_x

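# Hedged usage sketch (not part of the original commit): assemble r = 2 of
# the 4 even channels of an 8-channel tensor into their odd partners. The
# index values below are made up for illustration.
def _demo_assembly():
    torch.manual_seed(0)
    B, N, C, r = 2, 4, 8, 2
    x = torch.randn(B, N, C)
    src_idx = torch.tensor([0, 1])  # positions within the even-channel half
    dst_idx = torch.tensor([0, 1])  # positions within the odd-channel half
    out = assembly(x, src_idx, dst_idx, r, mode="mean")
    print(out.shape)  # [2, 4, 6]: the r merged source channels are dropped
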
class CAModule(nn.Module):
    def __init__(self, num_assembled_channels):
        super().__init__()
        self.num_assembled_channels = num_assembled_channels
        self.have_assembled = False
        self.src_idx = None
        self.dst_idx = None
        self.num_disassembly = None
        self.scaling_factors = None

    def find_similar_channels(self, x, fcs):
        B, N, C = x.shape
        x = x.view(B * N, C)

        # Concatenate the weights of all linear layers that consume x
        # along the output dimension.
        if not isinstance(fcs, list):
            fcs = [fcs]
        fc_weight = torch.cat([fc.weight for fc in fcs], dim=0)

        x = x.float()
        src_idx, dst_idx, scores = bipartite_soft_matching_x_w(
            x, fc_weight, self.num_assembled_channels, self.scaling_factors
        )
        # Drop the placeholder attributes so register_buffer does not
        # complain about existing names.
        del self.src_idx
        del self.dst_idx
        print("Score: {}".format(scores))
        self.register_buffer("src_idx", src_idx)
        self.register_buffer("dst_idx", dst_idx)
        self.have_assembled = True
    def forward(self, x):
        # Assembly is only performed after find_similar_channels has run.
        if self.have_assembled:
            B, N, C = x.shape
            # Size-weighted variant (kept for reference): track how many
            # original channels each assembled channel represents and
            # average by that size.
            # if size is None:
            #     size = torch.ones_like(x[0, 0]).view(1, 1, C)
            # x = assembly(x * size, self.src_idx, self.dst_idx,
            #              self.num_assembled_channels, mode="sum")
            # size = assembly(size, self.src_idx, self.dst_idx,
            #                 self.num_assembled_channels, mode="sum")
            # x = x / size
            x = assembly(
                x,
                self.src_idx,
                self.dst_idx,
                self.num_assembled_channels,
                mode="mean",
            )
        return x
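
# Hedged usage sketch (not part of the original commit): calibrate a
# CAModule on one batch of activations against the linear layer that
# consumes them, then apply it. Sizes are assumptions.
def _demo_camodule():
    torch.manual_seed(0)
    B, N, C = 2, 16, 32
    x = torch.randn(B, N, C)
    fc = nn.Linear(C, 4 * C)  # the layer whose input channels will shrink
    ca = CAModule(num_assembled_channels=4)
    ca.find_similar_channels(x, fc)  # one-shot calibration pass
    y = ca(x)
    print(y.shape)  # [2, 16, 28] == [B, N, C - 4]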