-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathconfig.py
More file actions
24 lines (22 loc) · 682 Bytes
/
config.py
File metadata and controls
24 lines (22 loc) · 682 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
from pathlib import Path
def get_config():
    """Return a new dictionary holding the default settings.

    A fresh dict is built on every call, so callers may mutate the result
    without affecting later calls.

    The checkpoint filename prefix is exposed under both ``model_basename``
    (the key ``get_weights_file_path`` reads) and ``model_filename`` (the key
    this function has always provided), so either lookup succeeds.
    """
    # Prefix placed in front of the epoch when a weights filename is built.
    checkpoint_prefix = "tmodel_"
    return {
        "batch_size": 12,
        "num_epochs": 20,
        "lr": 1e-4,
        "seq_len": 350,
        "d_model": 512,
        "lang_src": "en",
        "lang_target": "it",
        "model_folder": "weights",
        "model_basename": checkpoint_prefix,
        # Kept for backward compatibility with code that reads this key.
        "model_filename": checkpoint_prefix,
        "preload": None,
        # Contains a "{0}" placeholder; presumably formatted with a language
        # code by the caller -- TODO confirm.
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
    }
def get_weights_file_path(config, epoch: str):
    """Build the relative path of the weights file for a given epoch.

    The filename is ``<prefix><epoch>.pt`` inside ``config['model_folder']``.
    The prefix is read from ``config['model_basename']``; when that key is
    absent, ``config['model_filename']`` is used instead, because that is the
    key ``get_config()`` in this module provides.

    Args:
        config: Mapping with ``model_folder`` and a prefix key
            (``model_basename`` or ``model_filename``).
        epoch: Value interpolated into the filename after the prefix.

    Returns:
        The path as a string, relative to the current directory.

    Raises:
        KeyError: If ``model_folder`` is missing, or if neither prefix key
            is present.
    """
    model_folder = config['model_folder']
    if 'model_basename' in config:
        model_basename = config['model_basename']
    elif 'model_filename' in config:
        # Fallback so the dict returned by get_config() works unchanged.
        model_basename = config['model_filename']
    else:
        raise KeyError('model_basename')
    model_filename = f"{model_basename}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)