Skip to content

Commit

Permalink
Hotfix: Invalid instance type check for HPO/Training (#205)
Browse files Browse the repository at this point in the history
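Training/HPO jobs cannot use KMS keys for volume encryption when launched on instance types that provide their own volume encryption. The previous check for this condition was incorrect because it compared the full SageMaker instance type, including its "ml." prefix, against the list of KMS-unsupported instances. The check now strips the "ml." prefix from the instance type before the comparison.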
  • Loading branch information
dustins authored Jun 28, 2024
1 parent 7e2e531 commit e23eaad
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
3 changes: 2 additions & 1 deletion backend/src/ml_space_lambda/hpo_job/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
def _normalize_job_definition(definition, iam_role, param_file):
definition["RoleArn"] = iam_role
definition["OutputDataConfig"]["KmsKeyId"] = param_file["pSMSKMSKeyId"]
instance_type = definition["ResourceConfig"]["InstanceType"].removeprefix("ml.")
definition["ResourceConfig"]["VolumeKmsKeyId"] = (
"" if definition["ResourceConfig"]["InstanceType"] in kms_unsupported_instances() else param_file["pSMSKMSKeyId"]
"" if instance_type in kms_unsupported_instances() else param_file["pSMSKMSKeyId"]
)
definition["VpcConfig"] = {
"SecurityGroupIds": param_file["pSMSSecurityGroupId"],
Expand Down
5 changes: 2 additions & 3 deletions backend/src/ml_space_lambda/training_job/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def create(event, context):
if "AlgorithmName" in algorithm_specs and "TrainingImage" in algorithm_specs:
del algorithm_specs["AlgorithmName"]

instance_type = resource_config["InstanceType"].removeprefix("ml.")
training_job_definition = dict(
TrainingJobName=training_job_name,
HyperParameters=hyper_parameters,
Expand All @@ -87,9 +88,7 @@ def create(event, context):
"InstanceType": resource_config["InstanceType"],
"InstanceCount": int(resource_config["InstanceCount"]),
"VolumeSizeInGB": int(resource_config["VolumeSizeInGB"]),
"VolumeKmsKeyId": (
"" if resource_config["InstanceType"] in kms_unsupported_instances() else param_file["pSMSKMSKeyId"]
),
"VolumeKmsKeyId": ("" if instance_type in kms_unsupported_instances() else param_file["pSMSKMSKeyId"]),
},
VpcConfig={
"SecurityGroupIds": param_file["pSMSSecurityGroupId"],
Expand Down

0 comments on commit e23eaad

Please sign in to comment.