Skip to content

Commit 0215512

Browse files
chad119 and Chad Chiang authored
feat: Add support for MetricDefinitions in ModelTrainer (#5202)
* feat: Add support for MetricDefinitions in ModelTrainer
* style fix
* Update model_trainer.py to generate the doc
* resolve unit test failed
* solve another unit test error

---------

Co-authored-by: Chad Chiang <[email protected]>
1 parent 70b2f9a commit 0215512

File tree

3 files changed

+63
-0
lines changed

3 files changed

+63
-0
lines changed

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
RemoteDebugConfig,
4343
SessionChainingConfig,
4444
InstanceGroup,
45+
MetricDefinition,
4546
)
4647

4748
from sagemaker.modules.utils import convert_unassigned_to_none
@@ -68,6 +69,7 @@
6869
"Compute",
6970
"Networking",
7071
"InputData",
72+
"MetricDefinition",
7173
]
7274

7375

src/sagemaker/modules/train/model_trainer.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
RemoteDebugConfig,
6767
SessionChainingConfig,
6868
InputData,
69+
MetricDefinition,
6970
)
7071

7172
from sagemaker.modules.local_core.local_container import _LocalContainer
@@ -239,6 +240,7 @@ class ModelTrainer(BaseModel):
239240
_infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
240241
_session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
241242
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
243+
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
242244

243245
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)
244246

@@ -696,6 +698,7 @@ def train(
696698
training_image_config=self.training_image_config,
697699
container_entrypoint=container_entrypoint,
698700
container_arguments=container_arguments,
701+
metric_definitions=self._metric_definitions,
699702
)
700703

701704
resource_config = self.compute._to_resource_config()
@@ -1290,3 +1293,33 @@ def with_checkpoint_config(
12901293
"""
12911294
self.checkpoint_config = checkpoint_config or configs.CheckpointConfig()
12921295
return self
1296+
1297+
def with_metric_definitions(
1298+
self, metric_definitions: List[MetricDefinition]
1299+
) -> "ModelTrainer": # noqa: D412
1300+
"""Set the metric definitions for the training job.
1301+
1302+
Example:
1303+
1304+
.. code:: python
1305+
1306+
from sagemaker.modules.train import ModelTrainer
1307+
from sagemaker.modules.configs import MetricDefinition
1308+
1309+
metric_definitions = [
1310+
MetricDefinition(
1311+
name="loss",
1312+
regex="Loss: (.*?)",
1313+
)
1314+
]
1315+
1316+
model_trainer = ModelTrainer(
1317+
...
1318+
).with_metric_definitions(metric_definitions)
1319+
1320+
Args:
1321+
metric_definitions (List[MetricDefinition]):
1322+
The metric definitions for the training job.
1323+
"""
1324+
self._metric_definitions = metric_definitions
1325+
return self

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
FileSystemDataSource,
6565
Channel,
6666
DataSource,
67+
MetricDefinition,
6768
)
6869
from sagemaker.modules.distributed import Torchrun, SMP, MPI
6970
from sagemaker.modules.train.sm_recipes.utils import _load_recipes_cfg
@@ -705,6 +706,32 @@ def test_remote_debug_config(mock_training_job, modules_session):
705706
)
706707

707708

709+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
710+
def test_metric_definitions(mock_training_job, modules_session):
711+
image_uri = DEFAULT_IMAGE
712+
role = DEFAULT_ROLE
713+
metric_definitions = [
714+
MetricDefinition(
715+
name="loss",
716+
regex="Loss: (.*?);",
717+
)
718+
]
719+
720+
model_trainer = ModelTrainer(
721+
training_image=image_uri, sagemaker_session=modules_session, role=role
722+
).with_metric_definitions(metric_definitions)
723+
724+
with patch("sagemaker.modules.train.model_trainer.Session.upload_data") as mock_upload_data:
725+
mock_upload_data.return_value = "s3://dummy-bucket/dummy-prefix"
726+
model_trainer.train()
727+
728+
mock_training_job.create.assert_called_once()
729+
assert (
730+
mock_training_job.create.call_args.kwargs["algorithm_specification"].metric_definitions
731+
== metric_definitions
732+
)
733+
734+
708735
@patch("sagemaker.modules.train.model_trainer._get_unique_name")
709736
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
710737
def test_model_trainer_full_init(mock_training_job, mock_unique_name, modules_session):
@@ -822,6 +849,7 @@ def mock_upload_data(path, bucket, key_prefix):
822849
training_input_mode=training_input_mode,
823850
training_image=training_image,
824851
algorithm_name=None,
852+
metric_definitions=None,
825853
container_entrypoint=DEFAULT_ENTRYPOINT,
826854
container_arguments=DEFAULT_ARGUMENTS,
827855
training_image_config=training_image_config,

0 commit comments

Comments
 (0)