Skip to content

Commit

Permalink
Merge pull request #3350 from flairNLP/device_check
Browse files Browse the repository at this point in the history
Add a convenience conversion for flair.device
  • Loading branch information
alanakbik authored Oct 24, 2023
2 parents ed53c42 + c9f7f70 commit a0e5444
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions flair/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,10 @@ def train_custom(
if epoch == 0:
self.check_for_and_delete_previous_best_models(base_path)

# Sanity conversion: if flair.device was set as a string, convert to torch.device
if isinstance(flair.device, str):
flair.device = torch.device(flair.device)

# -- AmpPlugin -> wraps with AMP
# -- AnnealingPlugin -> initialize schedulers (requires instantiated optimizer)
with contextlib.ExitStack() as context_stack:
Expand Down

0 comments on commit a0e5444

Please sign in to comment.