
support arbitrary metric logging from torchmetrics #677

Open
sichu2023 wants to merge 21 commits into main from sichu/metric-config

Conversation

@sichu2023 (Collaborator) commented Feb 5, 2025

Description

Training and validation torchmetrics.Metric objects are organized by TorchmetricsConfig, which encapsulates metric class instantiation and naming through get_metric_name instead of using field_factory.

Currently model parallelism is not supported and will raise NotImplementedError.
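
A minimal sketch of the intended usage, pieced together from the discussion in this PR (get_instance, get_metric_name, and the field names come from the thread below; the concrete values are placeholders):

train_metric = TorchmetricsConfig(
    class_path="Accuracy",
    task="classification",
    kwargs={"task": "multiclass", "num_classes": 3},  # placeholder values
    metric_name="train_acc",
)
# The Lightning module builds the metric via train_metric.get_instance()
# and logs it under the name derived from get_metric_name.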

Type of changes

  • New feature (non-breaking change which adds functionality)

@sichu2023 (Collaborator Author)

/build-ci

@sichu2023 self-assigned this Feb 5, 2025
@sichu2023 force-pushed the sichu/metric-config branch 2 times, most recently from 44e7217 to a79fc4e on February 5, 2025 at 08:49
@codecov-commenter commented Feb 5, 2025

Codecov Report

Attention: Patch coverage is 90.16393% with 6 lines in your changes missing coverage. Please review.

Please upload report for BASE (main@9cec09f). Learn more about missing BASE report.

✅ All tests successful. No failed tests found.

Files with missing lines                                 Patch %   Lines
...-packages/bionemo-llm/src/bionemo/llm/lightning.py    84.00%    4 Missing ⚠️
...emo-esm2/src/bionemo/esm2/scripts/finetune_esm2.py    85.71%    1 Missing ⚠️
...ionemo-esm2/src/bionemo/esm2/scripts/train_esm2.py    85.71%    1 Missing ⚠️
Additional details and impacted files
@@           Coverage Diff           @@
##             main     #677   +/-   ##
=======================================
  Coverage        ?   86.27%           
=======================================
  Files           ?      119           
  Lines           ?     7249           
  Branches        ?        0           
=======================================
  Hits            ?     6254           
  Misses          ?      995           
  Partials        ?        0           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@farhadrgh (Collaborator) left a comment


I tentatively approve to unblock, but the tests are failing, and I wasn’t able to experiment with it.

# configure metrics
self.train_metric = self.config.train_metric.get_instance() if self.config.train_metric else None
self.valid_metric = self.config.valid_metric.get_instance() if self.config.valid_metric else None
if (self.train_metric or self.valid_metric) and not self.is_on_logging_device:
Collaborator

OK, so the logic is that at least one rank will throw this error if model parallelism is initialized? Does this apply to any metric logging, or only torchmetrics?

Collaborator Author

Yes, and this is an imperfect solution which might lock up the other devices.

We could condition on tensor/pipeline model parallelism if parallel_state had the functionality to inspect the tensor/pipeline model parallel size even when tp=pp=1; get_tensor_model_parallel_world_size and get_pipeline_model_parallel_world_size will throw an error in that case.
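
A rough sketch of such a guarded check, assuming Megatron-Core's parallel_state API (model_parallel_is_initialized plus the two world-size getters); this is not the code in the PR:

from megatron.core import parallel_state

def uses_model_parallelism() -> bool:
    # Avoid the error mentioned above: only query world sizes once model parallel state exists.
    if not parallel_state.model_parallel_is_initialized():
        return False
    return (
        parallel_state.get_tensor_model_parallel_world_size() > 1
        or parallel_state.get_pipeline_model_parallel_world_size() > 1
    )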

Collaborator Author

It would be great to have suggestions on alternatives.

Collaborator

IMHO, this error should be thrown when the metric config is initialized. In pydantic it could be:

from typing import Any, Literal, Optional

from pydantic import BaseModel, Field

class TorchmetricsConfig(BaseModel):
    class_path: str
    task: Literal["lm", "classification", "regression"]
    metric_name: str
    kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
    model_parallelism: bool = False
....
    def validate_parallel_logging(self) -> None:
        if self.model_parallelism:
            raise NotImplementedError(f"{self.__class__.__name__} logging is not implemented with model parallelism yet.")

and then you would need to initialize the metrics with:

valid_metric = TorchmetricsConfig(
            class_path="Accuracy",
            task="classification",
            kwargs={
                "task": "multiclass",
                "threshold": 0.5,
                "num_classes": data_module.train_dataset.label_tokenizer.vocab_size,
            },
            metric_name="val_acc",
            model_parallelism=any(p > 1 for p in [tensor_model_parallel, pipeline_model_parallel]),
        )

Alternatively, you could add the parallelism check to the config and pass the pp and tp settings, as sketched below.
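
A quick sketch of that alternative, with the validator taking the parallel sizes explicitly (hypothetical signature, not code from this PR):

    def validate_parallel_logging(self, tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1) -> None:
        # Reject torchmetrics logging whenever tensor or pipeline model parallelism is enabled.
        if tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1:
            raise NotImplementedError(f"{self.__class__.__name__} logging is not implemented with model parallelism yet.")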

Collaborator Author

Since model parallelism is supported at the BionemoLightningModule level, I prefer to keep the model parallelism check outside of TorchmetricsConfig. I have moved it into the training and finetuning scripts.
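
A rough sketch of what such a script-level guard might look like (variable names here are hypothetical, not the exact code in train_esm2.py / finetune_esm2.py):

if (train_metric is not None or valid_metric is not None) and (
    tensor_model_parallel_size > 1 or pipeline_model_parallel_size > 1
):
    raise NotImplementedError("Torchmetrics logging is not supported with model parallelism yet.")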

@pstjohn (Collaborator) left a comment

Could we get some tests for these new classes? Would be great to have unit tests around MetricConfig that show how it's used, as well as a short training run with some simple model that ensures it gets serialized correctly, produces the right results, etc.
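
For example, a unit test along these lines (a sketch only, assuming the config fields used elsewhere in this PR):

def test_torchmetrics_config_get_instance():
    config = TorchmetricsConfig(
        class_path="Accuracy",
        task="classification",
        kwargs={"task": "multiclass", "num_classes": 3},
        metric_name="val_acc",
    )
    metric = config.get_instance()
    assert isinstance(metric, torchmetrics.Metric)
    assert config.metric_name == "val_acc"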

@@ -84,6 +84,7 @@ def test_esm2_finetune_token_classifier(
assert weights_ckpt.is_dir()
assert io.is_distributed_ckpt(weights_ckpt)
assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
# assert trainer.logged_metrics["val_acc"].item() <= 0.5 # TODO @farhad for a reasonable value
Collaborator Author

@farhadrgh I will have to pick your brain on this.

Collaborator

The tests for fine-tuning run on a dummy dataset which overfits. I can replace that with a more reasonable dataset and uncomment these lines later.

Collaborator Author

Thanks! Feel free to push your commits directly.

@@ -136,6 +137,7 @@ def test_esm2_finetune_regressor(
assert weights_ckpt.is_dir()
assert io.is_distributed_ckpt(weights_ckpt)
assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
# assert trainer.logged_metrics["val_mse"].item() <= 0.5 # TODO @farhad for a reasonable value
Collaborator Author

@farhadrgh I will have to pick your brain on this.

@@ -189,6 +191,7 @@ def test_esm2_finetune_classifier(
assert weights_ckpt.is_dir()
assert io.is_distributed_ckpt(weights_ckpt)
assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
# assert trainer.logged_metrics["val_acc"].item() <= 0.5 # TODO @farhad for a reasonable value
Collaborator Author

@farhadrgh I will have to pick your brain on this.

@sichu2023 requested a review from farhadrgh on February 5, 2025 at 22:06


@dataclass
class TorchmetricsConfig:
Collaborator

Why not use pydantic?

Collaborator

I would reimplement it in pydantic and add some validations, e.g.:

import importlib
from typing import Any, Literal, Optional

import torchmetrics
from pydantic import BaseModel, Field, model_validator

class TorchmetricsConfig(BaseModel):
    class_path: str
    task: Literal["lm", "classification", "regression"]
    metric_name: str
    kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
    
    @model_validator(mode='after')
    def validate_class_path(self) -> 'TorchmetricsConfig':
        try:
            self.get_instance()
        except (ImportError, AttributeError) as e:
            raise ValueError(f"Invalid class_path: {self.class_path}. Error: {str(e)}")
        return self
    
    def get_instance(self) -> torchmetrics.Metric:
        """Dynamically imports and instantiates the metric class."""
        if "." in self.class_path:
            module_path, class_name = self.class_path.rsplit(".", 1)
            module = importlib.import_module(f"torchmetrics.{module_path}")
        else:
            class_name = self.class_path
            module = importlib.import_module("torchmetrics")
            
        cls = getattr(module, class_name)
        return cls(**self.kwargs)
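
Under this sketch, usage with a nested class path would look something like the following (hypothetical values; torchmetrics.text.Perplexity is just an example):

config = TorchmetricsConfig(
    class_path="text.Perplexity",
    task="lm",
    metric_name="val_ppl",
    kwargs={"ignore_index": -100},
)
metric = config.get_instance()  # imports torchmetrics.text and returns Perplexity(ignore_index=-100)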

Collaborator Author

Because we still have implementation beyond pydantic, e.g. in train_esm2.py and finetune_esm2.py, where most of the unit testing and dev work happens. That alone will be another discussion.

sub-packages/bionemo-llm/src/bionemo/llm/model/config.py (review thread outdated and resolved)

@sichu2023 requested a review from dorotat-nv on February 6, 2025 at 17:36
@sichu2023 force-pushed the sichu/metric-config branch from ecc9f67 to 4d61dc6 on February 6, 2025 at 21:19
@sichu2023 enabled auto-merge on February 6, 2025 at 23:10