From 9d64e32f152deaca12e759fbeaa88fe55b1c10f6 Mon Sep 17 00:00:00 2001 From: SkalskiP Date: Tue, 24 Sep 2024 14:57:39 +0200 Subject: [PATCH] refactor: change literal type to str due to a known bug in Typer --- maestro/trainer/models/florence_2/entrypoint.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/maestro/trainer/models/florence_2/entrypoint.py b/maestro/trainer/models/florence_2/entrypoint.py index aecefd9..f8ae79b 100644 --- a/maestro/trainer/models/florence_2/entrypoint.py +++ b/maestro/trainer/models/florence_2/entrypoint.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Annotated, Literal, Optional, Union +from typing import Annotated, Optional import rich import torch @@ -16,7 +16,7 @@ DEFAULT_FLORENCE2_MODEL_REVISION, DEVICE, ) -from maestro.trainer.models.florence_2.core import Configuration, LoraInitLiteral +from maestro.trainer.models.florence_2.core import Configuration from maestro.trainer.models.florence_2.core import evaluate as florence2_evaluate from maestro.trainer.models.florence_2.core import train as florence2_train @@ -70,7 +70,7 @@ def train( typer.Option("--epochs", help="Number of training epochs"), ] = 10, optimizer: Annotated[ - Literal["sgd", "adamw", "adam"], + str, typer.Option("--optimizer", help="Optimizer to use for training"), ] = "adamw", lr: Annotated[ @@ -78,7 +78,7 @@ def train( typer.Option("--lr", help="Learning rate for the optimizer"), ] = 1e-5, lr_scheduler: Annotated[ - Literal["linear", "cosine", "polynomial"], + str, typer.Option("--lr_scheduler", help="Learning rate scheduler"), ] = "linear", batch_size: Annotated[ @@ -110,7 +110,7 @@ def train( typer.Option("--lora_dropout", help="Dropout probability for LoRA layers"), ] = 0.05, bias: Annotated[ - Literal["none", "all", "lora_only"], + str, typer.Option("--bias", help="Which bias to train"), ] = "none", use_rslora: Annotated[ @@ -118,7 +118,7 @@ def train( typer.Option("--use_rslora/--no_use_rslora", help="Whether to use RSLoRA"), ] = True, init_lora_weights: Annotated[ - Union[bool, LoraInitLiteral], + str, typer.Option("--init_lora_weights", help="How to initialize LoRA weights"), ] = "gaussian", output_dir: Annotated[