-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
707 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,9 @@ | ||
transformers/ | ||
results/ | ||
data/ | ||
logs/ | ||
*.lock | ||
*.toml | ||
*.eeg | ||
*.vhdr | ||
*.vmrk | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3.12.5 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
import torchvision | ||
import torchvision.transforms as transforms | ||
|
||
# Define the CNN architecture | ||
class SimpleCNN(nn.Module): | ||
def __init__(self): | ||
super(SimpleCNN, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # Input channel = 1 (grayscale), Output channels = 32 | ||
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # Input channel = 32, Output channels = 64 | ||
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # Input channel = 64, Output channels = 128 | ||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0) | ||
self.fc1 = nn.Linear(128 * 4 * 4, 512) # Assuming input image size is 28x28 | ||
self.fc2 = nn.Linear(512, 10) # 10 classes for MNIST | ||
|
||
def forward(self, x): | ||
x = self.pool(F.relu(self.conv1(x))) | ||
x = self.pool(F.relu(self.conv2(x))) | ||
x = self.pool(F.relu(self.conv3(x))) | ||
x = x.view(-1, 128 * 4 * 4) # Flatten the tensor | ||
x = F.relu(self.fc1(x)) | ||
x = self.fc2(x) | ||
return x | ||
|
||
# Set device | ||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
|
||
# Prepare data loaders for MNIST dataset | ||
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) | ||
|
||
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) | ||
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True) | ||
|
||
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) | ||
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False) | ||
|
||
# Initialize network, criterion, and optimizer | ||
net = SimpleCNN().to(device) | ||
criterion = nn.CrossEntropyLoss() | ||
optimizer = optim.Adam(net.parameters(), lr=0.001) | ||
|
||
# Training loop | ||
epochs = 5000 | ||
for epoch in range(epochs): | ||
running_loss = 0.0 | ||
for images, labels in trainloader: | ||
images, labels = images.to(device), labels.to(device) | ||
|
||
# Zero the parameter gradients | ||
optimizer.zero_grad() | ||
|
||
# Forward pass | ||
outputs = net(images) | ||
loss = criterion(outputs, labels[:outputs.shape[0]]) | ||
|
||
# Backward pass and optimize | ||
loss.backward() | ||
optimizer.step() | ||
|
||
running_loss += loss.item() | ||
print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}') | ||
|
||
print('Finished Training') | ||
|
||
# Testing loop | ||
correct = 0 | ||
total = 0 | ||
with torch.no_grad(): | ||
for images, labels in testloader: | ||
images, labels = images.to(device), labels.to(device) | ||
|
||
# Ensure the batch sizes are consistent | ||
if images.size(0) != labels.size(0): | ||
raise ValueError(f"Batch size mismatch: images batch size = {images.size(0)}, labels batch size = {labels.size(0)}") | ||
|
||
outputs = net(images) | ||
_, predicted = torch.max(outputs.data, 1) | ||
total += labels.size(0) | ||
correct += (predicted == labels).sum().item() | ||
|
||
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total}%') |
Oops, something went wrong.