Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
runtime=_task_models.RuntimeMetadata(
type=_task_models.RuntimeMetadata.RuntimeType.FLYTE_SDK,
version=__version__,
flavor='sagemaker'
),
discoverable=cacheable,
timeout=timeout,
retries=_literal_models.RetryStrategy(retries=retries),
interruptible=interruptible,
discovery_version=cache_version,
deprecated_error_message="",
),
interface=_interface.TypedInterface(
inputs={
"hpo_job_config": _interface_model.Variable(
_sdk_types.Types.Proto(_hpo_job_pb2.HPOJobConfig).to_flyte_literal_type(), ""
),
},
outputs={
"model": _interface_model.Variable(
type=_idl_types.LiteralType(
blob=_core_types.BlobType(
format="",
dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
),
description=""
)
}
),
custom=MessageToDict(hpo_job),
)
def to_flyte_idl(self):
    """Build the ``_idl_hpo_job.HPOJobConfig`` protobuf for this HPO job configuration.

    The SDK tuning strategy and training-job early-stopping type are translated to
    their IDL enum counterparts. The hyperparameter ranges and the tuning objective
    are serialized through their own ``to_flyte_idl()`` methods.

    :raises _user_exceptions.FlyteValidationException: if the tuning strategy or the
        early-stopping type is not one of the recognized SDK values.
    :return: the populated ``_idl_hpo_job.HPOJobConfig`` message.
    """

    def _translate(value, pairs, message_template):
        # Return the IDL value paired with the first SDK value that compares equal,
        # otherwise reject the input with a validation error naming the bad value.
        for sdk_value, idl_value in pairs:
            if value == sdk_value:
                return idl_value
        raise _user_exceptions.FlyteValidationException(message_template.format(value))

    sdk_strategy = _sdk_sagemaker_types.HyperparameterTuningStrategy
    idl_strategy_enum = _idl_hpo_job.HPOJobConfig.HyperparameterTuningStrategy
    strategy = _translate(
        self._tuning_strategy,
        (
            (sdk_strategy.BAYESIAN, idl_strategy_enum.BAYESIAN),
            (sdk_strategy.RANDOM, idl_strategy_enum.RANDOM),
        ),
        "Invalid Hyperparameter Tuning Strategy: {}",
    )

    sdk_stopping = _sdk_sagemaker_types.TrainingJobEarlyStoppingType
    idl_stopping_enum = _idl_hpo_job.HPOJobConfig.TrainingJobEarlyStoppingType
    early_stopping = _translate(
        self._training_job_early_stopping_type,
        (
            (sdk_stopping.OFF, idl_stopping_enum.OFF),
            (sdk_stopping.AUTO, idl_stopping_enum.AUTO),
        ),
        "Invalid Training Job Early Stopping Type (in HPO Config): {}",
    )

    return _idl_hpo_job.HPOJobConfig(
        hyperparameter_ranges=self._hyperparameter_ranges.to_flyte_idl(),
        tuning_strategy=strategy,
        tuning_objective=self._tuning_objective.to_flyte_idl(),
        training_job_early_stopping_type=early_stopping,
    )
def to_flyte_idl(self):
    """Serialize this HPO job into an ``_idl_hpo_job.HPOJob`` protobuf message.

    Copies the maximum total and maximum parallel training-job counts, and embeds
    the training job. If the stored training job exposes ``to_flyte_idl()`` (i.e. it
    is a model wrapper rather than a protobuf message), it is converted first. A
    value without that method is passed through unchanged.

    :return: the populated ``_idl_hpo_job.HPOJob`` message.
    """
    training_job = self._training_job
    # A protobuf message field expects a protobuf message, not a model wrapper.
    # Convert wrappers the same way other nested models are serialized; leave
    # already-serialized messages (which have no ``to_flyte_idl``) untouched.
    to_idl = getattr(training_job, "to_flyte_idl", None)
    if callable(to_idl):
        training_job = to_idl()
    return _idl_hpo_job.HPOJob(
        max_number_of_training_jobs=self._max_number_of_training_jobs,
        max_parallel_training_jobs=self._max_parallel_training_jobs,
        training_job=training_job,
    )
def to_flyte_idl(self):
    """Serialize this objective into a ``_idl_hpo_job.HyperparameterTuningObjective`` message.

    The SDK objective type (``HyperparameterTuningObjectiveType.MINIMIZE`` or
    ``.MAXIMIZE``) is mapped to the IDL enum value of the same name. The metric
    name is copied across unchanged.

    :raises _user_exceptions.FlyteValidationException: if ``objective_type`` is
        neither MINIMIZE nor MAXIMIZE. The message includes the rejected value.
    :return: the populated ``HyperparameterTuningObjective`` message.
    """
    sdk_objective_type = _sdk_sagemaker_types.HyperparameterTuningObjectiveType
    idl_objective = _idl_hpo_job.HyperparameterTuningObjective
    if self.objective_type == sdk_objective_type.MINIMIZE:
        objective_type = idl_objective.MINIMIZE
    elif self.objective_type == sdk_objective_type.MAXIMIZE:
        objective_type = idl_objective.MAXIMIZE
    else:
        # Report the offending value so a misconfigured objective is easy to diagnose.
        raise _user_exceptions.FlyteValidationException(
            "Invalid SageMaker Hyperparameter Tuning Objective Type Specified: {}".format(self.objective_type)
        )
    return idl_objective(
        objective_type=objective_type,
        metric_name=self._metric_name,
    )
def to_flyte_idl(self):
    """Serialize this objective into a ``_idl_hpo_job.HyperparameterTuningObjective`` message.

    The SDK objective type (``HyperparameterTuningObjectiveType.MINIMIZE`` or
    ``.MAXIMIZE``) is mapped to the IDL enum value of the same name. The metric
    name is copied across unchanged.

    :raises _user_exceptions.FlyteValidationException: if ``objective_type`` is
        neither MINIMIZE nor MAXIMIZE. The message includes the rejected value.
    :return: the populated ``HyperparameterTuningObjective`` message.
    """
    sdk_objective_type = _sdk_sagemaker_types.HyperparameterTuningObjectiveType
    idl_objective = _idl_hpo_job.HyperparameterTuningObjective
    if self.objective_type == sdk_objective_type.MINIMIZE:
        objective_type = idl_objective.MINIMIZE
    elif self.objective_type == sdk_objective_type.MAXIMIZE:
        objective_type = idl_objective.MAXIMIZE
    else:
        # Report the offending value so a misconfigured objective is easy to diagnose.
        raise _user_exceptions.FlyteValidationException(
            "Invalid SageMaker Hyperparameter Tuning Objective Type Specified: {}".format(self.objective_type)
        )
    return idl_objective(
        objective_type=objective_type,
        metric_name=self._metric_name,
    )