diff --git a/run_python_examples.sh b/run_python_examples.sh
index 2d769c0ae1..ef6d0228c5 100755
--- a/run_python_examples.sh
+++ b/run_python_examples.sh
@@ -167,6 +167,10 @@ function gat() {
   uv run main.py --epochs 1 --dry-run || error "graph attention network failed"
 }
 
+function swin_transformer() {
+  uv run swin_transformer.py --epochs 1 --dry-run || error "swin transformer failed"
+}
+
 eval "base_$(declare -f stop)"
 
 function stop() {
@@ -191,8 +195,8 @@ function stop() {
     time_sequence_prediction/traindata.pt \
     word_language_model/model.pt \
     gcn/cora/ \
-    gat/cora/ || error "couldn't clean up some files"
-
+    gat/cora/ \
+    swin_transformer/swin_cifar10.pt || error "couldn't clean up some files"
 
   git checkout fast_neural_style/images/output-images/amber-candy.jpg || error "couldn't clean up fast neural style image"
 
   base_stop "$1"
@@ -220,6 +224,7 @@ function run_all() {
   run fx
   run gcn
   run gat
+  run swin_transformer
 }
 
 # by default, run all examples
diff --git a/swin_transformer/README.md b/swin_transformer/README.md
new file mode 100644
index 0000000000..37f789be37
--- /dev/null
+++ b/swin_transformer/README.md
@@ -0,0 +1,61 @@
+# Swin Transformer on CIFAR-10
+
+This project demonstrates a minimal implementation of a **Swin Transformer** for image classification on the **CIFAR-10** dataset using PyTorch.
+
+It includes:
+- Patch embedding and window-based self-attention
+- Shifted windows so successive blocks mix information across window boundaries
+- Training and testing logic using standard PyTorch utilities
+
+---
+
+## Files
+
+- `swin_transformer.py` — Full implementation of the Swin Transformer model, training loop, and evaluation on CIFAR-10.
+- `README.md` — This file.
+
+---
+
+## Requirements
+
+- Python 3.8+
+- PyTorch 2.6 or later
+- `torchvision` (for the CIFAR-10 dataset)
+
+Install dependencies:
+
+```bash
+pip install -r requirements.txt
+```
+
+---
+
+## Usage
+
+### Train & Save the model
+
+```bash
+python swin_transformer.py --epochs 10 --batch-size 64 --lr 0.001 --save-model
+```
+
+### Test the model
+
+Evaluation runs on the test set automatically after every epoch. For a quick single-epoch train/test cycle:
+
+```bash
+python swin_transformer.py --epochs 1
+```
+
+When `--save-model` is passed, the trained weights are saved as `swin_cifar10.pt`.
+
+---
+
+## Features
+
+- Uses shifted window attention for local self-attention.
+- Patch-based embedding with a lightweight network.
+- Trains on CIFAR-10 with the `Adam` optimizer and learning rate scheduling.
+- Prints loss and accuracy per epoch.
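+
+---
+
+## Loading a saved checkpoint
+
+A minimal sketch of reloading the trained weights for inference. It assumes you already ran with `--save-model` so `swin_cifar10.pt` exists, and that `swin_transformer.py` is importable from the current directory:
+
+```python
+import torch
+from swin_transformer import SwinTinyNet
+
+model = SwinTinyNet()
+model.load_state_dict(torch.load("swin_cifar10.pt"))
+model.eval()  # switch off training-time behavior before inference
+```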
+
+---
diff --git a/swin_transformer/requirements.txt b/swin_transformer/requirements.txt
new file mode 100644
index 0000000000..9a083ba390
--- /dev/null
+++ b/swin_transformer/requirements.txt
@@ -0,0 +1,2 @@
+torch>=2.6
+torchvision
diff --git a/swin_transformer/swin_transformer.py b/swin_transformer/swin_transformer.py
new file mode 100644
index 0000000000..a29fbd5fff
--- /dev/null
+++ b/swin_transformer/swin_transformer.py
@@ -0,0 +1,203 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.optim.lr_scheduler import StepLR
+from torchvision import datasets, transforms
+
+# ---------- Core Swin Components ----------
+
+class PatchEmbed(nn.Module):
+    def __init__(self, img_size=32, patch_size=4, in_chans=3, embed_dim=48):
+        super().__init__()
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+        self.norm = nn.LayerNorm(embed_dim)
+
+    def forward(self, x):
+        x = self.proj(x)                   # (B, C, H/ps, W/ps)
+        x = x.flatten(2).transpose(1, 2)   # (B, num_patches, C)
+        x = self.norm(x)
+        return x
+
+def window_partition(x, window_size):
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+def window_reverse(windows, window_size, H, W):
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+class WindowAttention(nn.Module):
+    def __init__(self, dim, window_size, num_heads):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3)
+        self.proj = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
+        q, k, v = qkv.permute(2, 0, 3, 1, 4)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+
+        out = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+        return self.proj(out)
+
+class SwinTransformerBlock(nn.Module):
+    def __init__(self, dim, input_resolution, num_heads, window_size=4, shift_size=0):
+        super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.window_size = window_size
+        self.shift_size = shift_size
+
+        self.norm1 = nn.LayerNorm(dim)
+        self.attn = WindowAttention(dim, window_size, num_heads)
+        self.norm2 = nn.LayerNorm(dim)
+
+        self.mlp = nn.Sequential(
+            nn.Linear(dim, dim * 4),
+            nn.GELU(),
+            nn.Linear(dim * 4, dim)
+        )
+
+    def forward(self, x):
+        H, W = self.input_resolution
+        B, L, C = x.shape
+        shortcut = x  # keep the pre-attention activations for the residual
+        x = x.view(B, H, W, C)
+
+        if self.shift_size > 0:
+            shifted_x = torch.roll(x, (-self.shift_size, -self.shift_size), (1, 2))
+        else:
+            shifted_x = x
+
+        windows = window_partition(shifted_x, self.window_size)
+        windows = windows.view(-1, self.window_size * self.window_size, C)
+
+        # NOTE: this minimal version omits the attention mask that the full
+        # Swin model applies to shifted windows, so wrapped-around tokens may
+        # attend to each other.
+        attn_windows = self.attn(self.norm1(windows))
+        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+
+        shifted_x = window_reverse(attn_windows, self.window_size, H, W)
+
+        if self.shift_size > 0:
+            x = torch.roll(shifted_x, (self.shift_size, self.shift_size), (1, 2))
+        else:
+            x = shifted_x
+
+        # residual connections around attention and MLP, as in the Swin paper
+        x = shortcut + x.view(B, H * W, C)
+        x = x + self.mlp(self.norm2(x))
+        return x
+
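+# Round-trip sanity check for the window helpers (illustrative only; not run
+# by this script). window_reverse exactly inverts window_partition:
+#   x = torch.randn(2, 8, 8, 48)                      # (B, H, W, C)
+#   w = window_partition(x, 4)                        # (8, 4, 4, 48) windows
+#   assert torch.equal(window_reverse(w, 4, 8, 8), x)
+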
+# ---------- Final Network ----------
+
+class SwinTinyNet(nn.Module):
+    def __init__(self, num_classes=10):
+        super().__init__()
+        self.patch_embed = PatchEmbed(img_size=32, patch_size=4, in_chans=3, embed_dim=48)
+        self.block1 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=0)
+        self.block2 = SwinTransformerBlock(dim=48, input_resolution=(8, 8), num_heads=3, window_size=4, shift_size=2)
+        self.norm = nn.LayerNorm(48)
+        self.fc = nn.Linear(48, num_classes)
+
+    def forward(self, x):
+        x = self.patch_embed(x)
+        x = self.block1(x)
+        x = self.block2(x)
+        x = self.norm(x)
+        x = x.mean(dim=1)  # global average pool over the patch tokens
+        x = self.fc(x)
+        return F.log_softmax(x, dim=1)
+
+# ---------- Training and Testing ----------
+
+def train(args, model, device, train_loader, optimizer, epoch):
+    model.train()
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(device), target.to(device)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target)
+        loss.backward()
+        optimizer.step()
+        if batch_idx % args.log_interval == 0:
+            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
+                epoch, batch_idx * len(data), len(train_loader.dataset),
+                100. * batch_idx / len(train_loader), loss.item()))
+            if args.dry_run:
+                break
+
+def test(args, model, device, test_loader):
+    model.eval()
+    test_loss = 0
+    correct = 0
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(device), target.to(device)
+            output = model(data)
+            test_loss += F.nll_loss(output, target, reduction='sum').item()
+            pred = output.argmax(dim=1, keepdim=True)
+            correct += pred.eq(target.view_as(pred)).sum().item()
+            if args.dry_run:
+                break
+
+    test_loss /= len(test_loader.dataset)
+    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
+        test_loss, correct, len(test_loader.dataset),
+        100. * correct / len(test_loader.dataset)))
+
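+# Quick forward-pass sanity check (illustrative only; not run by this script):
+# a 32x32 image split into 4x4 patches gives 8*8 = 64 tokens of dim 48, and
+# the head reduces them to 10 log-probabilities:
+#   model = SwinTinyNet()
+#   out = model(torch.randn(2, 3, 32, 32))
+#   assert out.shape == (2, 10)
+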
+# ---------- Main ----------
+
+def main():
+    parser = argparse.ArgumentParser(description='Swin Transformer CIFAR10 Example')
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--test-batch-size', type=int, default=1000)
+    parser.add_argument('--epochs', type=int, default=10)
+    parser.add_argument('--lr', type=float, default=0.01)
+    parser.add_argument('--gamma', type=float, default=0.7)
+    parser.add_argument('--dry-run', action='store_true')
+    parser.add_argument('--seed', type=int, default=42)
+    parser.add_argument('--log-interval', type=int, default=10)
+    parser.add_argument('--save-model', action='store_true')
+    args = parser.parse_args()
+
+    use_accel = torch.accelerator.is_available()
+    device = torch.accelerator.current_accelerator() if use_accel else torch.device("cpu")
+    print(f"Using device: {device}")
+
+    torch.manual_seed(args.seed)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # per-channel for CIFAR-10 RGB
+    ])
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('../data', train=True, download=True, transform=transform),
+        batch_size=args.batch_size, shuffle=True)
+
+    test_loader = torch.utils.data.DataLoader(
+        datasets.CIFAR10('../data', train=False, transform=transform),
+        batch_size=args.test_batch_size, shuffle=False)
+
+    model = SwinTinyNet().to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.lr)
+    scheduler = StepLR(optimizer, step_size=3, gamma=args.gamma)
+
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, device, train_loader, optimizer, epoch)
+        test(args, model, device, test_loader)
+        scheduler.step()
+
+    if args.save_model:
+        torch.save(model.state_dict(), "swin_cifar10.pt")
+
+if __name__ == '__main__':
+    main()