Skip to content

Commit

Permalink
use einops
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Oct 15, 2021
1 parent 483df12 commit 132dda0
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .isort.cfg
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
[settings]
known_third_party = PIL,matplotlib,numpy,pytest,pytorch_lightning,torch,torchmetrics,torchvision
known_third_party = PIL,einops,matplotlib,numpy,pytest,pytorch_lightning,torch,torchmetrics,torchvision
14 changes: 13 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ authors = ["sami jaghouar <[email protected]>"]
python = "^3.8"
torch = "^1.9.1"
torchvision = "^0.10.1"
einops = "^0.3.2"

[tool.poetry.dev-dependencies]
pytest = "^5.2"
Expand Down
17 changes: 6 additions & 11 deletions vision_transformers/vit.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from itertools import product
from typing import Tuple

import torch
import torch.nn as nn
from einops.layers.torch import Rearrange


class PatchSizeError(ValueError):
Expand All @@ -24,14 +24,9 @@ def __init__(self, size_of_patch: int, input_shape: Tuple[int, int]):
self.input_shape = input_shape
self.N = (input_shape[0] // self.P) * (input_shape[1] // self.P)

def forward(self, x: torch.Tensor) -> torch.Tensor():
return torch.stack(
[
x[:, :, i * self.P : (i + 1) * self.P, j * self.P : (j + 1) * self.P]
for i, j in product(
range(0, self.input_shape[0] // self.P),
range(0, self.input_shape[1] // self.P),
)
],
dim=1,
self.rearrange = Rearrange(
" n c (h p1) (w p2) -> n (h w) c p1 p2", p2=self.P, p1=self.P
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Split the input batch into flattened patches.

    Delegates to the einops ``Rearrange`` layer built in ``__init__``
    (stored as ``self.rearrange``), which reshapes
    ``(n, c, H, W)`` into ``(n, num_patches, c, P, P)``.

    Args:
        x: input image batch; assumed shape ``(n, c, H, W)`` with H and W
           divisible by the patch size P — TODO confirm against __init__.

    Returns:
        The patch tensor produced by ``self.rearrange``.
    """
    # Fix: the annotation was ``torch.Tensor()`` — calling the type at
    # definition time allocates a throwaway empty tensor; the annotation
    # must be the type itself, not an instance.
    return self.rearrange(x)

0 comments on commit 132dda0

Please sign in to comment.