Further SeasonTST finetuning #8

Draft: wants to merge 6 commits into main
PatchTST_self_supervised/src/callback/tracking.py (8 changes: 6 additions & 2 deletions)
@@ -7,6 +7,7 @@
import time
import numpy as np
from pathlib import Path
+import logging


class TrackTimerCB(Callback):
@@ -121,15 +122,18 @@ def after_epoch_valid(self):
self.recorder['valid_'+name].append( values[name] )


-    def after_batch_train(self): self.accumulate()  # save batch recorder
+    def after_batch_train(self):
+        self.accumulate()  # save batch recorder
+        logging.info(f"Batch loss: {self.batch_recorder['batch_losses'][-1]}")

def after_batch_valid(self): self.accumulate()

def accumulate(self ):
xb, yb = self.batch
bs = len(xb)
self.batch_recorder['n_samples'].append(bs)
# get batch loss
loss = self.loss.detach()*bs if self.mean_reduction_ else self.loss.detach()
self.batch_recorder['batch_losses'].append(loss)

if yb is None: self.batch_recorder['with_metrics'] = False
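Note on the new per-batch logging above: the batch-loss message is emitted at INFO level, so it is captured under either the old DEBUG or the new INFO configuration used by the scripts below. A minimal sketch of a compatible logging setup follows; the filename and format string are illustrative, not taken from this PR.

import logging

# Illustrative configuration only; the actual scripts call logging.basicConfig
# with a timestamped filename under logs/ and level=logging.INFO (see the diffs below).
logging.basicConfig(
    filename="logs/finetune.log",  # hypothetical path
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,  # INFO and above, which includes the batch-loss lines
)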
PatchTST_self_supervised/src/learner.py (2 changes: 1 addition & 1 deletion)
@@ -122,7 +122,7 @@ def fit(self, n_epochs, lr=None, cbs=None, do_valid=True):
def fit_one_cycle(self, n_epochs, lr_max=None, pct_start=0.3):
self.n_epochs = n_epochs
self.lr_max = lr_max if lr_max else self.lr
-cb = OneCycleLR(lr_max=self.lr_max, pct_start=pct_start)
+cb = OneCycleLR(lr_max=self.lr_max, pct_start=pct_start, verbose=True)
self.fit(self.n_epochs, cbs=cb)

def one_epoch(self, train):
SeasonTST/SeasonTST_finetune.py (38 changes: 11 additions & 27 deletions)
@@ -39,7 +39,7 @@
datefmt="%m/%d/%Y %I:%M:%S %p",
filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I:%M")}_finetune.log',
encoding="utf-8",
-level=logging.DEBUG,
+level=logging.INFO,
)


@@ -51,10 +51,6 @@
def finetune_func(learner, save_path, args, lr=0.001):
print("end-to-end finetuning")

-if not os.path.exists(save_path):
-    os.makedirs(save_path)
-
-print(save_path)
# fit the data to the model and save
learner.fine_tune(
n_epochs=args.n_epochs_finetune, base_lr=lr, freeze_epochs=args.freeze_epochs
@@ -107,20 +103,6 @@ def save_recorders(learner, args):
)


-def test_func(weight_path, learner, args, dls):
-
-    out = learner.test(
-        dls.test, weight_path=weight_path, scores=[mse, mae]
-    )  # out: a list of [pred, targ, score]
-    print("score:", out[2])
-    # save results
-    pd.DataFrame(np.array(out[2]).reshape(1, -1), columns=["mse", "mae"]).to_csv(
-        args.save_path + args.save_finetuned_model + "_acc.csv",
-        float_format="%.6f",
-        index=False,
-    )
-    return out


def load_config():

@@ -135,13 +117,14 @@ def load_config():
"revin": 0, # reversible instance normalization
"mask_ratio": 0.4, # masking ratio for the input
"lr": 1e-3,
"batch_size": 128,
"batch_size": 64,
"drop_last": False,
"num_workers": 6,
"prefetch_factor": 3,
"n_epochs_pretrain": 1, # number of pre-training epochs,
"n_epochs_pretrain": 20, # number of pre-training epochs,
"freeze_epochs": 0,
"n_epochs_finetune": 250,
"pretrained_model_id": 2500, # id of the saved pretrained model
"n_epochs_finetune": 10,
"pretrained_model_id": 2, # id of the saved pretrained model
"save_finetuned_model": "./finetuned_d128",
"save_path": "saved_models" + "/masked_patchtst/",
}
@@ -186,17 +169,18 @@ def main():
# Create dataloader
dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

-# suggested_lr = find_lr(config_obj, dls)
-# This is what I got on a small dataset. In case one wants to skip this for testing.
-suggested_lr = 0.00017073526474706903
+suggested_lr = 0.0002  # 0.000298364724028334
+learner = get_learner(config_obj, dls, suggested_lr, model)
+suggested_lr = learner.lr_finder()
+print(suggested_lr)

learner = get_learner(config_obj, dls, suggested_lr, model)

# This function will save the model weights to config_obj.save_finetuned_model, i.e. it will not overwrite the pretrained model.
# However, there is currently no set-up to do finetuning from the result of a previous finetuning.
# To continue training from a previous fine-tuning checkpoint, the path needs to be explicitly fed to the get_model function
finetune_func(learner, pretrained_model_path, config_obj, suggested_lr)


if __name__ == "__main__":
# PYTHONPATH=$(pwd) python SeasonTST/SeasonTST_finetune.py
main()
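The comment in main() notes that resuming from an earlier fine-tuning run is not wired up and that the checkpoint path would have to be fed to get_model directly. A hypothetical sketch of how that could look, based on get_model's signature in SeasonTST/utils.py; the checkpoint path and keyword values are assumptions, not part of this PR.

# Hypothetical resume-from-finetune sketch; not part of this PR.
previous_ckpt = "saved_models/masked_patchtst/finetuned_d128.pth"  # assumed location

model = get_model(
    config_obj,
    headtype="pretrain",         # assumption: use whatever head the checkpoint was saved with
    weights_path=previous_ckpt,  # load the earlier fine-tuned weights
    exclude_head=False,          # keep the fine-tuned head rather than re-initialising it
)
learner = get_learner(config_obj, dls, suggested_lr, model)
finetune_func(learner, previous_ckpt, config_obj, suggested_lr)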
SeasonTST/SeasonTST_pretrain.py (52 changes: 29 additions & 23 deletions)
@@ -29,7 +29,7 @@
datefmt="%m/%d/%Y %I:%M:%S %p",
filename=f'logs/{datetime.datetime.now().strftime("%Y_%m_%d_%I_%M")}_train.log',
encoding="utf-8",
-level=logging.DEBUG,
+level=logging.INFO,
)


@@ -95,10 +95,11 @@ def load_config():
"mask_value": -99, # Value to assign to masked elements of data input
"lr": 1e-3,
"batch_size": 128,
"drop_last":True,
"prefetch_factor": 3,
"num_workers": 6,
"n_epochs_pretrain": 1, # number of pre-training epochs
"pretrained_model_id": 2500, # id of the saved pretrained model
"n_epochs_pretrain": 20, # number of pre-training epochs
"pretrained_model_id": 2, # id of the saved pretrained model
}

config_obj = SimpleNamespace(**config)
@@ -109,37 +110,42 @@ def main():
data, mask = load_data()
config_obj = load_config()

save_path = "saved_models" + "/masked_patchtst/"
pretrained_model = (
"patchtst_pretrained_cw"
+ str(config_obj.sequence_length)
+ "_patch"
+ str(config_obj.patch_len)
+ "_stride"
+ str(config_obj.stride)
+ "_epochs-pretrain"
+ str(config_obj.n_epochs_pretrain)
+ "_mask"
+ str(config_obj.mask_ratio)
+ "_model"
+ str(config_obj.pretrained_model_id)
)
pretrained_model_path = save_path + pretrained_model + ".pth"

# Creates train valid and test datasets for one epoch. Notice that they are in different locations!
dls = get_dls(config_obj, SeasonTST_Dataset, data, mask)

-model = get_model(config_obj)
+model = get_model(
+    config_obj, headtype="pretrain", weights_path=pretrained_model_path, exclude_head=False
+)

# suggested_lr = find_lr(config_obj, dls)
# This is what I got on a small dataset. In case one wants to skip this for testing.
suggested_lr = 0.00020565123083486514

-save_pretrained_model = (
-    "patchtst_pretrained_cw"
-    + str(config_obj.sequence_length)
-    + "_patch"
-    + str(config_obj.patch_len)
-    + "_stride"
-    + str(config_obj.stride)
-    + "_epochs-pretrain"
-    + str(config_obj.n_epochs_pretrain)
-    + "_mask"
-    + str(config_obj.mask_ratio)
-    + "_model"
-    + str(config_obj.pretrained_model_id)
-)
-save_path = "saved_models" + "/masked_patchtst/"


pretrain_func(
-save_pretrained_model, save_path, config_obj, model, dls, suggested_lr
+pretrained_model, save_path, config_obj, model, dls, suggested_lr
)

-pretrained_model_name = save_path + save_pretrained_model + ".pth"

-model = transfer_weights(pretrained_model_name, model)
+model = transfer_weights(pretrained_model_path, model)


if __name__ == "__main__":
SeasonTST/dataset.py (2 changes: 1 addition & 1 deletion)
@@ -125,7 +125,7 @@ def scale(self, batch):
for var, data_var in batch.data_vars.items():
batch[var] = (
data_var - self.scaling_factors["mean"][var]
) / self.scaling_factors["mean"][var]
) / self.scaling_factors["std"][var]
return batch

def __len__(self):
SeasonTST/utils.py (6 changes: 2 additions & 4 deletions)
@@ -30,6 +30,7 @@ def get_dls(
batch_size=config_obj.batch_size,
workers=config_obj.num_workers,
prefetch_factor=config_obj.prefetch_factor,
+drop_last=config_obj.drop_last
)

dls.vars, dls.len = dls.train.dataset[0][0].shape[1], config_obj.sequence_length
@@ -74,17 +75,14 @@ def get_model(config, headtype="pretrain", weights_path=None, exclude_head=True):
return model


-def find_lr(config_obj, dls):
+def find_lr(model, config_obj, dls):
"""
This method typically involves training the model for a few epochs with a range of learning rates and recording
the loss at each step. The learning rate that gives the fastest decrease in loss is considered optimal or
near-optimal for the training process.

:param model:
:param config_obj:
:param dls:
:return: suggested learning rate
"""

-model = get_model(config_obj)
# get loss
loss_func = torch.nn.MSELoss(reduction="mean")
# get callbacks
Expand Down