66
66
RemoteDebugConfig ,
67
67
SessionChainingConfig ,
68
68
InputData ,
69
+ MetricDefinition ,
69
70
)
70
71
71
72
from sagemaker .modules .local_core .local_container import _LocalContainer
@@ -119,7 +120,8 @@ class ModelTrainer(BaseModel):
119
120
from sagemaker.modules.train import ModelTrainer
120
121
from sagemaker.modules.configs import SourceCode, Compute, InputData
121
122
122
- source_code = SourceCode(source_dir="source", entry_script="train.py")
123
+ ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
124
+ source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123
125
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124
126
model_trainer = ModelTrainer(
125
127
training_image=training_image,
@@ -238,6 +240,7 @@ class ModelTrainer(BaseModel):
238
240
_infra_check_config : Optional [InfraCheckConfig ] = PrivateAttr (default = None )
239
241
_session_chaining_config : Optional [SessionChainingConfig ] = PrivateAttr (default = None )
240
242
_remote_debug_config : Optional [RemoteDebugConfig ] = PrivateAttr (default = None )
243
+ _metric_definitions : Optional [List [MetricDefinition ]] = PrivateAttr (default = None )
241
244
242
245
_temp_recipe_train_dir : Optional [TemporaryDirectory ] = PrivateAttr (default = None )
243
246
@@ -654,6 +657,7 @@ def train(
654
657
channel_name = SM_CODE ,
655
658
data_source = self .source_code .source_dir ,
656
659
key_prefix = input_data_key_prefix ,
660
+ ignore_patterns = self .source_code .ignore_patterns ,
657
661
)
658
662
final_input_data_config .append (source_code_channel )
659
663
@@ -675,6 +679,7 @@ def train(
675
679
channel_name = SM_DRIVERS ,
676
680
data_source = tmp_dir .name ,
677
681
key_prefix = input_data_key_prefix ,
682
+ ignore_patterns = self .source_code .ignore_patterns ,
678
683
)
679
684
final_input_data_config .append (sm_drivers_channel )
680
685
@@ -693,6 +698,7 @@ def train(
693
698
training_image_config = self .training_image_config ,
694
699
container_entrypoint = container_entrypoint ,
695
700
container_arguments = container_arguments ,
701
+ metric_definitions = self ._metric_definitions ,
696
702
)
697
703
698
704
resource_config = self .compute ._to_resource_config ()
@@ -755,7 +761,11 @@ def train(
755
761
local_container .train (wait )
756
762
757
763
def create_input_data_channel (
758
- self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
764
+ self ,
765
+ channel_name : str ,
766
+ data_source : DataSourceType ,
767
+ key_prefix : Optional [str ] = None ,
768
+ ignore_patterns : Optional [List [str ]] = None ,
759
769
) -> Channel :
760
770
"""Create an input data channel for the training job.
761
771
@@ -771,6 +781,10 @@ def create_input_data_channel(
771
781
772
782
If specified, local data will be uploaded to:
773
783
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
784
+ ignore_patterns: (Optional[List[str]]) :
785
+ The ignore patterns to ignore specific files/folders when uploading to S3.
786
+ If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store',
787
+ '.cache', '.ipynb_checkpoints'].
774
788
"""
775
789
channel = None
776
790
if isinstance (data_source , str ):
@@ -810,11 +824,28 @@ def create_input_data_channel(
810
824
)
811
825
if self .sagemaker_session .default_bucket_prefix :
812
826
key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
813
- s3_uri = self .sagemaker_session .upload_data (
814
- path = data_source ,
815
- bucket = self .sagemaker_session .default_bucket (),
816
- key_prefix = key_prefix ,
817
- )
827
+ if ignore_patterns and _is_valid_path (data_source , path_type = "Directory" ):
828
+ tmp_dir = TemporaryDirectory ()
829
+ copied_path = os .path .join (
830
+ tmp_dir .name , os .path .basename (os .path .normpath (data_source ))
831
+ )
832
+ shutil .copytree (
833
+ data_source ,
834
+ copied_path ,
835
+ dirs_exist_ok = True ,
836
+ ignore = shutil .ignore_patterns (* ignore_patterns ),
837
+ )
838
+ s3_uri = self .sagemaker_session .upload_data (
839
+ path = copied_path ,
840
+ bucket = self .sagemaker_session .default_bucket (),
841
+ key_prefix = key_prefix ,
842
+ )
843
+ else :
844
+ s3_uri = self .sagemaker_session .upload_data (
845
+ path = data_source ,
846
+ bucket = self .sagemaker_session .default_bucket (),
847
+ key_prefix = key_prefix ,
848
+ )
818
849
channel = Channel (
819
850
channel_name = channel_name ,
820
851
data_source = DataSource (
@@ -861,7 +892,9 @@ def _get_input_data_config(
861
892
channels .append (input_data )
862
893
elif isinstance (input_data , InputData ):
863
894
channel = self .create_input_data_channel (
864
- input_data .channel_name , input_data .data_source , key_prefix = key_prefix
895
+ input_data .channel_name ,
896
+ input_data .data_source ,
897
+ key_prefix = key_prefix ,
865
898
)
866
899
channels .append (channel )
867
900
else :
@@ -1260,3 +1293,33 @@ def with_checkpoint_config(
1260
1293
"""
1261
1294
self .checkpoint_config = checkpoint_config or configs .CheckpointConfig ()
1262
1295
return self
1296
+
1297
+ def with_metric_definitions (
1298
+ self , metric_definitions : List [MetricDefinition ]
1299
+ ) -> "ModelTrainer" : # noqa: D412
1300
+ """Set the metric definitions for the training job.
1301
+
1302
+ Example:
1303
+
1304
+ .. code:: python
1305
+
1306
+ from sagemaker.modules.train import ModelTrainer
1307
+ from sagemaker.modules.configs import MetricDefinition
1308
+
1309
+ metric_definitions = [
1310
+ MetricDefinition(
1311
+ name="loss",
1312
+ regex="Loss: (.*?)",
1313
+ )
1314
+ ]
1315
+
1316
+ model_trainer = ModelTrainer(
1317
+ ...
1318
+ ).with_metric_definitions(metric_definitions)
1319
+
1320
+ Args:
1321
+ metric_definitions (List[MetricDefinition]):
1322
+ The metric definitions for the training job.
1323
+ """
1324
+ self ._metric_definitions = metric_definitions
1325
+ return self
0 commit comments