"""
Quick test to verify CW2 setup is working.
Student: Martynas Prascevicius
Student ID: 001263199
"""
import sys
from pathlib import Path
# Add src/ to the import path so project modules (data_loader, enhanced_model, ...) resolve
sys.path.insert(0, str(Path(__file__).parent / 'src'))
print("=" * 70)
print("Testing CW2 Setup")
print("=" * 70)
# Test 1: Import modules
print("\n1. Testing imports...")
try:
    import torch
    import transformers
    import numpy as np
    import sklearn
    print(" ✓ All core libraries available")
    print(f" PyTorch: {torch.__version__}")
    print(f" Transformers: {transformers.__version__}")
except ImportError as e:
    print(f" ✗ Missing library: {e}")
    sys.exit(1)
# Test 2: Check device
print("\n2. Checking device...")
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print(f" ✓ Using device: {device}")
# Test 3: Load data
print("\n3. Loading IMDB data...")
try:
    from data_loader import load_imdb_from_directory
    project_root = Path(__file__).parent
    data_dir = project_root / 'data'
    train_texts, train_labels, test_texts, test_labels = load_imdb_from_directory(data_dir)
    print(f" ✓ Train: {len(train_texts)} samples")
    print(f" ✓ Test: {len(test_texts)} samples")
except Exception as e:
    print(f" ✗ Failed to load data: {e}")
    sys.exit(1)
# Test 4: Load tokenizer
print("\n4. Loading DistilBERT tokenizer...")
try:
    from transformers import DistilBertTokenizer
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    print(" ✓ Tokenizer loaded")
    # Test tokenization on a short slice of the first training review
    test_text = train_texts[0][:100]
    tokens = tokenizer(test_text, max_length=64, truncation=True, padding='max_length')
    print(f" ✓ Sample tokenized: {len(tokens['input_ids'])} tokens")
except Exception as e:
    print(f" ✗ Failed to load tokenizer: {e}")
    sys.exit(1)
# Test 5: Load model
print("\n5. Loading enhanced DistilBERT model...")
try:
    from enhanced_model import EnhancedDistilBERT
    model = EnhancedDistilBERT(pooling_strategy='cls')
    model.to(device)
    print(" ✓ Model loaded")
    print(f" ✓ Total params: {model.get_num_total_params():,}")
    print(f" ✓ Trainable params: {model.get_num_trainable_params():,}")
except Exception as e:
    print(f" ✗ Failed to load model: {e}")
    sys.exit(1)
# Test 6: Test forward pass
print("\n6. Testing forward pass...")
try:
    batch_size = 2
    seq_len = 64
    # Create a dummy batch of random token IDs
    input_ids = torch.randint(0, 1000, (batch_size, seq_len)).to(device)
    attention_mask = torch.ones(batch_size, seq_len).to(device)
    # Forward pass without gradient tracking
    with torch.no_grad():
        logits = model(input_ids, attention_mask)
    print(" ✓ Forward pass successful")
    print(f" ✓ Output shape: {logits.shape}")
except Exception as e:
    print(f" ✗ Failed forward pass: {e}")
    sys.exit(1)
# Test 7: Check experiment configs
print("\n7. Checking experiment configurations...")
try:
    from experiment_configs import get_experiment, EXPERIMENTS
    baseline = get_experiment('baseline_default')
    print(f" ✓ Total experiments defined: {len(EXPERIMENTS)}")
    print(f" ✓ Baseline config: LR={baseline.learning_rate}, Batch={baseline.batch_size}")
except Exception as e:
    print(f" ✗ Failed to load configs: {e}")
    sys.exit(1)
print("\n" + "=" * 70)
print("All tests passed! ✓")
print("=" * 70)
print("\nYou're ready to run experiments!")
print("\nTo run baseline:")
print(" cd /Users/m2000uk/Desktop/coding/AI")
print(" source venv/bin/activate")
print(" cd CW2")
print(" python3 src/experiment_runner.py --experiment baseline_default")