Skip to content

Commit

Permalink
refactor: change literal type to str due to a known bug in Typer
Browse files Browse the repository at this point in the history
  • Loading branch information
SkalskiP authored and onuralpszr committed Sep 24, 2024
1 parent 379b658 commit 9d64e32
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions maestro/trainer/models/florence_2/entrypoint.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import dataclasses
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Optional

import rich
import torch
Expand All @@ -16,7 +16,7 @@
DEFAULT_FLORENCE2_MODEL_REVISION,
DEVICE,
)
from maestro.trainer.models.florence_2.core import Configuration, LoraInitLiteral
from maestro.trainer.models.florence_2.core import Configuration
from maestro.trainer.models.florence_2.core import evaluate as florence2_evaluate
from maestro.trainer.models.florence_2.core import train as florence2_train

Expand Down Expand Up @@ -70,15 +70,15 @@ def train(
typer.Option("--epochs", help="Number of training epochs"),
] = 10,
optimizer: Annotated[
Literal["sgd", "adamw", "adam"],
str,
typer.Option("--optimizer", help="Optimizer to use for training"),
] = "adamw",
lr: Annotated[
float,
typer.Option("--lr", help="Learning rate for the optimizer"),
] = 1e-5,
lr_scheduler: Annotated[
Literal["linear", "cosine", "polynomial"],
str,
typer.Option("--lr_scheduler", help="Learning rate scheduler"),
] = "linear",
batch_size: Annotated[
Expand Down Expand Up @@ -110,15 +110,15 @@ def train(
typer.Option("--lora_dropout", help="Dropout probability for LoRA layers"),
] = 0.05,
bias: Annotated[
Literal["none", "all", "lora_only"],
str,
typer.Option("--bias", help="Which bias to train"),
] = "none",
use_rslora: Annotated[
bool,
typer.Option("--use_rslora/--no_use_rslora", help="Whether to use RSLoRA"),
] = True,
init_lora_weights: Annotated[
Union[bool, LoraInitLiteral],
str,
typer.Option("--init_lora_weights", help="How to initialize LoRA weights"),
] = "gaussian",
output_dir: Annotated[
Expand Down

0 comments on commit 9d64e32

Please sign in to comment.