
Commit

add pos embedding
samsja committed Oct 15, 2021
1 parent 132dda0 commit e8997d2
Showing 2 changed files with 42 additions and 12 deletions.
tests/vit/test_unit.py: 11 changes (8 additions & 3 deletions)
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from vision_transformers.vit import EmbeddedPatch
+from vision_transformers.vit import EmbeddedPatch, positional_embedding
 
 
 @pytest.mark.xfail
@@ -13,8 +13,13 @@ def test_embedded_shape():
 
     input_ = torch.zeros((2, 3, 28, 28))
 
-    em_patch = EmbeddedPatch(14, (28, 28))
+    em_patch = EmbeddedPatch(14, (28, 28), dim=512)
 
     patches = em_patch(input_)
 
-    assert patches.shape == (2, 4, 3, 14, 14)
+    assert patches.shape == (2, 4, 512)
+
+
+def test_positional_em():
+
+    assert positional_embedding(512).shape == (512,)
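
For reference, the shape arithmetic behind the updated assertion, using the values from the test above: a 28x28 input with patch size 14 gives N = (28/14) * (28/14) = 4 patches, each flattened to 3 * 14 * 14 = 588 values and projected by the new linear layer to dim = 512, hence the (2, 4, 512) shape for a batch of 2. A minimal sketch of that arithmetic (illustrative only, not part of the commit):

# Shape arithmetic for the updated assertion; values taken from the test above.
P, H, W, C, dim = 14, 28, 28, 3, 512
N = (H // P) * (W // P)  # 4 patches per image
flat = C * P * P         # 588 values per flattened patch, fed to nn.Linear(588, 512)
print(N, flat, dim)      # 4 588 512 -> patches.shape == (batch, N, dim) == (2, 4, 512)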
vision_transformers/vit.py: 43 changes (34 additions & 9 deletions)
@@ -2,15 +2,21 @@
 
 import torch
 import torch.nn as nn
-from einops.layers.torch import Rearrange
+from einops import rearrange
 
 
-class PatchSizeError(ValueError):
+class VitError(ValueError):
     pass
+
+
+class PatchSizeError(VitError):
+    pass
 
 
 class EmbeddedPatch(nn.Module):
-    def __init__(self, size_of_patch: int, input_shape: Tuple[int, int]):
+    def __init__(
+        self, size_of_patch: int, input_shape: Tuple[int, int], channels=3, dim=512
+    ):
         super().__init__()
 
         if (input_shape[0] % size_of_patch != 0) or (
@@ -21,12 +21,31 @@ def __init__(self, size_of_patch: int, input_shape: Tuple[int, int]):
             )
 
         self.P = size_of_patch
-        self.input_shape = input_shape
         self.N = (input_shape[0] // self.P) * (input_shape[1] // self.P)
+        self.dim = dim
 
-        self.rearrange = Rearrange(
-            " n c (h p1) (w p2) -> n (h w) c p1 p2", p2=self.P, p1=self.P
-        )
+        self.linear = nn.Linear(channels * self.P * self.P, dim)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor():
-        return self.rearrange(x)
+        x = rearrange(
+            x, " n c (h p1) (w p2) -> n (h w) (p1 p2 c)", p2=self.P, p1=self.P
+        )
+        x = self.linear(x)
+        pos_em = positional_embedding(self.dim)
+        return x + pos_em
+
+
+class DimError(VitError):
+    pass
+
+
+def positional_embedding(dim: int):
+    if dim % 2 != 0:
+        raise DimError(f"dim {dim} should pair")
+
+    d = dim // 2
+
+    w = (10_000 * torch.ones(d)).pow(2 * torch.arange(d) / d)
+    cos = torch.cos(w)
+    sin = torch.sin(w)
+
+    return rearrange([sin, cos], " a b -> (b a) ")
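
For context, a minimal usage sketch (not part of the commit) of the two pieces this change touches, following the imports already used in the test file. The committed positional_embedding builds a single vector of length dim: with d = dim / 2 and w_i = 10000^(2i/d), it interleaves sin(w_i) and cos(w_i), and EmbeddedPatch.forward adds that vector to every patch token.

# Usage sketch (illustrative, not part of the commit).
import torch

from vision_transformers.vit import EmbeddedPatch, positional_embedding

x = torch.zeros((2, 3, 28, 28))                  # batch of 2 RGB 28x28 images
em_patch = EmbeddedPatch(14, (28, 28), dim=512)  # 4 patches of 14x14 per image
tokens = em_patch(x)                             # forward adds positional_embedding(512) to each token
print(tokens.shape)                              # expected: torch.Size([2, 4, 512])

pos = positional_embedding(512)                  # interleaved (sin, cos) vector
print(pos.shape)                                 # expected: torch.Size([512])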
