Skip to content

Commit

Permalink
wav2vec2 trains now
Browse files Browse the repository at this point in the history
  • Loading branch information
sophie460 committed Sep 12, 2024
1 parent 53ff429 commit 4c3c85b
Show file tree
Hide file tree
Showing 8 changed files with 707 additions and 3 deletions.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
transformers/
results/
data/
logs/
*.lock
*.toml
*.eeg
*.vhdr
*.vmrk
Expand Down
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12.5
84 changes: 84 additions & 0 deletions CNN_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Define the CNN architecture
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) # Input channel = 1 (grayscale), Output channels = 32
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) # Input channel = 32, Output channels = 64
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) # Input channel = 64, Output channels = 128
self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
self.fc1 = nn.Linear(128 * 4 * 4, 512) # Assuming input image size is 28x28
self.fc2 = nn.Linear(512, 10) # 10 classes for MNIST

def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(-1, 128 * 4 * 4) # Flatten the tensor
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prepare data loaders for MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Initialize network, criterion, and optimizer
net = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Training loop
epochs = 5000
for epoch in range(epochs):
running_loss = 0.0
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)

# Zero the parameter gradients
optimizer.zero_grad()

# Forward pass
outputs = net(images)
loss = criterion(outputs, labels[:outputs.shape[0]])

# Backward pass and optimize
loss.backward()
optimizer.step()

running_loss += loss.item()
print(f'Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}')

print('Finished Training')

# Testing loop
correct = 0
total = 0
with torch.no_grad():
for images, labels in testloader:
images, labels = images.to(device), labels.to(device)

# Ensure the batch sizes are consistent
if images.size(0) != labels.size(0):
raise ValueError(f"Batch size mismatch: images batch size = {images.size(0)}, labels batch size = {labels.size(0)}")

outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total}%')
Loading

0 comments on commit 4c3c85b

Please sign in to comment.