Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def to_flyte_idl(self):
    """
    Serialize the maximum runtime and maximum wait time of this stopping
    condition into a protobuf message.

    :rtype: _training_job.StoppingCondition
    """
    runtime_limit = self.max_runtime_in_seconds
    wait_limit = self.max_wait_time_in_seconds
    return _training_job.StoppingCondition(
        max_wait_time_in_seconds=wait_limit,
        max_runtime_in_seconds=runtime_limit,
    )
def to_flyte_idl(self):
    """
    Serialize the instance count, instance type and volume size of this
    training job configuration into a protobuf message.

    :rtype: _training_job.TrainingJobConfig
    """
    config_fields = dict(
        instance_count=self.instance_count,
        instance_type=self.instance_type,
        volume_size_in_gb=self.volume_size_in_gb,
    )
    return _training_job.TrainingJobConfig(**config_fields)
def from_flyte_idl(cls, pb2_object):
    """
    Build an instance of ``cls`` from its protobuf representation.

    Only PIPE is recognized explicitly for the input mode; any other value
    falls back to FILE. Likewise only XGBOOST is recognized for the algorithm
    name; any other value falls back to CUSTOM.

    :param pb2_object: protobuf message exposing ``input_mode``,
        ``algorithm_name``, ``algorithm_version`` and ``metric_definitions``
        (presumably a ``_training_job.AlgorithmSpecification`` -- confirm
        against callers)
    :return: a new instance of ``cls``
    """
    sdk_input_modes = _sdk_sagemaker_types.InputMode
    sdk_algorithms = _sdk_sagemaker_types.AlgorithmName

    # Anything that is not explicitly PIPE is treated as FILE.
    input_mode = (
        sdk_input_modes.PIPE
        if pb2_object.input_mode == _training_job.InputMode.PIPE
        else sdk_input_modes.FILE
    )
    # Anything that is not explicitly XGBOOST is treated as CUSTOM.
    algorithm_name = (
        sdk_algorithms.XGBOOST
        if pb2_object.algorithm_name == _training_job.AlgorithmName.XGBOOST
        else sdk_algorithms.CUSTOM
    )

    version = pb2_object.algorithm_version
    metrics = list(map(MetricDefinition.from_flyte_idl, pb2_object.metric_definitions))
    return cls(
        input_mode=input_mode,
        algorithm_name=algorithm_name,
        algorithm_version=version,
        metric_definitions=metrics,
    )
def to_flyte_idl(self):
    """
    Serialize this training job by converting its algorithm specification
    and its training job config, each through its own ``to_flyte_idl``.

    :return: _training_job.TrainingJob
    """
    algorithm_spec_pb = self.algorithm_specification.to_flyte_idl()
    job_config_pb = self.training_job_config.to_flyte_idl()
    return _training_job.TrainingJob(
        training_job_config=job_config_pb,
        algorithm_specification=algorithm_spec_pb,
    )
input_mode = _training_job.InputMode.FILE
elif self.input_mode == _sdk_sagemaker_types.InputMode.PIPE:
input_mode = _training_job.InputMode.PIPE
else:
raise _user_exceptions.FlyteValidationException(
"Invalid SageMaker Input Mode Specified: [{}]".format(self.input_mode))
if self.algorithm_name == _sdk_sagemaker_types.AlgorithmName.CUSTOM:
alg_name = _training_job.AlgorithmName.CUSTOM
elif self.algorithm_name == _sdk_sagemaker_types.AlgorithmName.XGBOOST:
alg_name = _training_job.AlgorithmName.XGBOOST
else:
raise _user_exceptions.FlyteValidationException(
"Invalid SageMaker Algorithm Name Specified: [{}]".format(self.algorithm_name))
return _training_job.AlgorithmSpecification(
input_mode=input_mode,
algorithm_name=alg_name,
algorithm_version=self.algorithm_version,
metric_definitions=[m.to_flyte_idl() for m in self.metric_definitions],
)
def to_flyte_idl(self):
    """
    Serialize this algorithm specification into a protobuf message.

    The SDK-level input mode (FILE or PIPE) and algorithm name (CUSTOM or
    XGBOOST) are translated to their ``_training_job`` enum counterparts;
    the version is copied as-is and each metric definition is converted
    through its own ``to_flyte_idl``.

    :rtype: _training_job.AlgorithmSpecification
    :raises _user_exceptions.FlyteValidationException: if ``input_mode`` or
        ``algorithm_name`` matches none of the supported values
    """

    def _to_idl(value, candidates, error_template):
        # Compare with == in the listed order, exactly like an if/elif chain.
        for sdk_value, idl_value in candidates:
            if value == sdk_value:
                return idl_value
        raise _user_exceptions.FlyteValidationException(error_template.format(value))

    input_mode = _to_idl(
        self.input_mode,
        (
            (_sdk_sagemaker_types.InputMode.FILE, _training_job.InputMode.FILE),
            (_sdk_sagemaker_types.InputMode.PIPE, _training_job.InputMode.PIPE),
        ),
        "Invalid SageMaker Input Mode Specified: [{}]",
    )
    alg_name = _to_idl(
        self.algorithm_name,
        (
            (_sdk_sagemaker_types.AlgorithmName.CUSTOM, _training_job.AlgorithmName.CUSTOM),
            (_sdk_sagemaker_types.AlgorithmName.XGBOOST, _training_job.AlgorithmName.XGBOOST),
        ),
        "Invalid SageMaker Algorithm Name Specified: [{}]",
    )

    idl_metrics = [metric.to_flyte_idl() for metric in self.metric_definitions]
    return _training_job.AlgorithmSpecification(
        input_mode=input_mode,
        algorithm_name=alg_name,
        algorithm_version=self.algorithm_version,
        metric_definitions=idl_metrics,
    )
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
"validation": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="csv",
dimensionality=_core_types.BlobType.BlobDimensionality.MULTIPART
),
),
description="",
),
"stopping_condition": _interface_model.Variable(
_sdk_types.Types.Proto(_training_job_pb2.StoppingCondition).to_flyte_literal_type(), ""
)
},
outputs={
"model": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
),
custom=MessageToDict(self._training_job_model.to_flyte_idl()),
)