support arbitrary metric logging from torchmetrics #677
Conversation
/build-ci
Codecov Report
✅ All tests successful. No failed tests found.
@@ Coverage Diff @@
##          main     #677   +/- ##
=======================================
Coverage        ?   86.27%
=======================================
Files           ?      119
Lines           ?     7249
Branches        ?        0
=======================================
Hits            ?     6254
Misses          ?      995
Partials        ?        0
I tentatively approve to unblock, but the tests are failing, and I wasn’t able to experiment with it.
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/model.py
# configure metrics
self.train_metric = self.config.train_metric.get_instance() if self.config.train_metric else None
self.valid_metric = self.config.valid_metric.get_instance() if self.config.valid_metric else None
if (self.train_metric or self.valid_metric) and not self.is_on_logging_device:
OK, so the logic is that at least one rank will throw this error if model parallelism is initialized? Does this apply to any metric logging, or only to torchmetrics?
Yes, and this is an imperfect solution which might lock up the other devices. We could condition on tensor/pipeline model parallelism if parallel_state had the functionality to inspect the tensor/pipeline model-parallel size even when tp=pp=1; as it stands, get_tensor_model_parallel_world_size and get_pipeline_model_parallel_world_size will throw an error in that case.
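As a rough sketch of the kind of guarded check this would need (assuming Megatron's parallel_state API; the exception types raised by the uninitialized getters are an assumption here):

from megatron.core import parallel_state


def uses_model_parallelism() -> bool:
    """Best-effort check for tensor/pipeline model parallelism."""
    try:
        tp = parallel_state.get_tensor_model_parallel_world_size()
        pp = parallel_state.get_pipeline_model_parallel_world_size()
    except (AssertionError, RuntimeError):
        # parallel_state is not initialized (e.g. tp=pp=1 without model-parallel init),
        # so treat it as no model parallelism
        return False
    return tp > 1 or pp > 1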
It would be great to have suggestions on alternatives.
IMHO, this error should be thrown when the metric config is initialized. In pydantic it could be:

from typing import Any, Literal, Optional
from pydantic import BaseModel, Field


class TorchmetricsConfig(BaseModel):
    class_path: str
    task: Literal["lm", "classification", "regression"]
    metric_name: str
    kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
    model_parallelism: bool = False
    ...

    def validate_parallel_logging(self) -> None:
        if self.model_parallelism:
            raise NotImplementedError(f"{self.__class__.__name__} logging is not implemented with model parallelism yet.")

and then you would need to initialize metrics with

valid_metric = TorchmetricsConfig(
    class_path="Accuracy",
    task="classification",
    kwargs={
        "task": "multiclass",
        "threshold": 0.5,
        "num_classes": data_module.train_dataset.label_tokenizer.vocab_size,
    },
    metric_name="val_acc",
    model_parallelism=any(p > 1 for p in [tensor_model_parallel, pipeline_model_parallel]),
)

Alternatively, you can add the parallelism check to the config and pass the pp and tp settings.
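A hedged sketch of that alternative (pydantic v2 assumed; the tp/pp field names here are illustrative, not an existing API):

from pydantic import BaseModel, model_validator


class TorchmetricsConfig(BaseModel):
    class_path: str
    metric_name: str
    tensor_model_parallel_size: int = 1
    pipeline_model_parallel_size: int = 1

    @model_validator(mode="after")
    def validate_parallel_logging(self) -> "TorchmetricsConfig":
        # Fail at config construction time instead of inside the LightningModule.
        if self.tensor_model_parallel_size > 1 or self.pipeline_model_parallel_size > 1:
            raise NotImplementedError(
                f"{self.__class__.__name__} logging is not implemented with model parallelism yet."
            )
        return self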
Since model parallelism is supported at the BionemoLightningModule level, I prefer to move the model parallelism check outside of TorchmetricsConfig. I have moved that into the training and finetuning scripts.
Could we get some tests for these new classes? It would be great to have unit tests around MetricConfig that show how it's used, as well as a short training run with some simple model that ensures it gets serialized correctly, produces the right results, etc.
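To make the request concrete, a sketch of what such a unit test might look like, assuming the TorchmetricsConfig fields and get_instance() helper shown in this PR (its import path is omitted here):

import pytest
import torchmetrics


def test_torchmetrics_config_builds_accuracy_metric():
    config = TorchmetricsConfig(
        class_path="Accuracy",
        task="classification",
        metric_name="val_acc",
        kwargs={"task": "multiclass", "num_classes": 3},
    )
    metric = config.get_instance()
    assert isinstance(metric, torchmetrics.Metric)


def test_torchmetrics_config_rejects_unknown_metric():
    # Depending on where validation happens, this may raise at construction
    # or when get_instance() is called.
    with pytest.raises((AttributeError, ValueError)):
        TorchmetricsConfig(
            class_path="NotARealMetric",
            task="classification",
            metric_name="val_acc",
            kwargs={},
        ).get_instance()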
@@ -84,6 +84,7 @@ def test_esm2_finetune_token_classifier(
    assert weights_ckpt.is_dir()
    assert io.is_distributed_ckpt(weights_ckpt)
    assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
    # assert trainer.logged_metrics["val_acc"].item() <= 0.5  # TODO @farhad for a reasonable value
@farhadrgh will have to pick your brain.
The tests for fine-tuning are on a dummy dataset which overfits. I can replace that with a more reasonable dataset and uncomment these lines later.
Thanks! Feel free to push your commits directly.
@@ -136,6 +137,7 @@ def test_esm2_finetune_regressor(
    assert weights_ckpt.is_dir()
    assert io.is_distributed_ckpt(weights_ckpt)
    assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
    # assert trainer.logged_metrics["val_mse"].item() <= 0.5  # TODO @farhad for a reasonable value
@farhadrgh will have to pick your brain.
@@ -189,6 +191,7 @@ def test_esm2_finetune_classifier(
    assert weights_ckpt.is_dir()
    assert io.is_distributed_ckpt(weights_ckpt)
    assert simple_ft_metrics.collection_train["loss"][0] > simple_ft_metrics.collection_train["loss"][-1]
    # assert trainer.logged_metrics["val_acc"].item() <= 0.5  # TODO @farhad for a reasonable value
@farhadrgh will have to pick your brain.
@dataclass
class TorchmetricsConfig:
Why not use pydantic?
I would reimplement it with pydantic and add some validation, e.g.:

import importlib
from typing import Any, Literal, Optional

import torchmetrics
from pydantic import BaseModel, Field, model_validator


class TorchmetricsConfig(BaseModel):
    class_path: str
    task: Literal["lm", "classification", "regression"]
    metric_name: str
    kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)

    @model_validator(mode="after")
    def validate_class_path(self) -> "TorchmetricsConfig":
        try:
            self.get_instance()
        except (ImportError, AttributeError) as e:
            raise ValueError(f"Invalid class_path: {self.class_path}. Error: {str(e)}")
        return self

    def get_instance(self) -> torchmetrics.Metric:
        """Dynamically imports and instantiates the metric class."""
        if "." in self.class_path:
            module_path, class_name = self.class_path.rsplit(".", 1)
            module = importlib.import_module(f"torchmetrics.{module_path}")
        else:
            class_name = self.class_path
            module = importlib.import_module("torchmetrics")
        cls = getattr(module, class_name)
        return cls(**self.kwargs)
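A short usage sketch of the class_path resolution above (the metric classes and kwargs are standard torchmetrics; the field values are illustrative):

# Bare class names resolve from the torchmetrics top-level namespace.
acc_config = TorchmetricsConfig(
    class_path="Accuracy",
    task="classification",
    metric_name="val_acc",
    kwargs={"task": "multiclass", "num_classes": 2},
)
val_acc = acc_config.get_instance()  # -> torchmetrics.Accuracy

# Dotted paths resolve from a torchmetrics submodule, e.g. torchmetrics.text.
ppl_config = TorchmetricsConfig(
    class_path="text.Perplexity",
    task="lm",
    metric_name="val_ppl",
    kwargs={"ignore_index": -100},
)
val_ppl = ppl_config.get_instance()  # -> torchmetrics.text.Perplexity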
Because we still have implementations beyond pydantic, e.g. train_esm.py and finetune_esm.py, where most unit testing and development happen. That alone will be another discussion.
Description
Training and validation torchmetrics.Metric instances can now be organized by TorchmetricsConfig, which encapsulates metric class instantiation and naming through get_metric_name instead of using field_factory. Currently, model parallelism is not supported and will raise NotImplementedError.
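For illustration, a hedged sketch of how the configured metrics could be consumed inside a LightningModule-style training step; the forward_step helper, the batch["labels"] key, and the logging flags are assumptions rather than the actual BionemoLightningModule implementation:

def training_step(self, batch, batch_idx):
    loss, logits = self.forward_step(batch)  # assumed helper returning loss and logits
    if self.train_metric is not None and self.is_on_logging_device:
        # Update the torchmetrics.Metric and let Lightning handle compute/reset.
        self.train_metric(logits, batch["labels"])
        self.log(
            self.config.train_metric.metric_name,  # e.g. "train_acc"
            self.train_metric,
            on_step=True,
            on_epoch=False,
        )
    return loss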