Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _set_trial_intermediate_value_without_commit(self, session, trial_id, step,
                                                 intermediate_value):
    # type: (orm.Session, int, int, float) -> bool
    """Stage an intermediate value for ``step`` of trial ``trial_id`` in ``session``.

    Nothing is committed here; a new ``TrialValueModel`` row is only added
    to ``session``. The trial lookup and updatability check are delegated to
    ``models.TrialModel.find_or_raise_by_id`` and
    ``self.check_trial_is_updatable`` (presumably these raise when the trial
    is missing or can no longer be modified — confirm in those helpers).

    Returns:
        ``True`` if a new value row was added to the session, ``False`` if a
        value for this trial and step already exists (the existing row is
        left untouched).
    """
    trial_row = models.TrialModel.find_or_raise_by_id(trial_id, session)
    self.check_trial_is_updatable(trial_id, trial_row.state)

    existing_value = models.TrialValueModel.find_by_trial_and_step(trial_row, step, session)
    if existing_value is None:
        new_value = models.TrialValueModel(trial_id=trial_id, step=step,
                                           value=intermediate_value)
        session.add(new_value)
        return True

    # A value for this step was already reported; do not overwrite it.
    return False
def get_trial_param(self, trial_id, param_name):
    # type: (int, str) -> float
    """Return the internal value of parameter ``param_name`` of trial ``trial_id``.

    Both lookups go through the ``find_or_raise_*`` model helpers (presumably
    raising when the trial or parameter does not exist — confirm in ``models``).
    The transaction is ended with ``self._commit(session)`` before returning.

    Returns:
        The stored ``param_value`` of the matching ``TrialParamModel`` row.
    """
    session = self.scoped_session()
    trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
    trial_param = models.TrialParamModel.find_or_raise_by_trial_and_param_name(
        trial, param_name, session)
    # Read the value while the transaction is still open. With SQLAlchemy's
    # default ``expire_on_commit=True``, touching ``trial_param`` after the
    # commit would issue a fresh SELECT and silently begin a new transaction,
    # defeating the explicit commit below.
    param_value = trial_param.param_value
    # Terminate transaction explicitly to avoid connection timeout during transaction.
    self._commit(session)
    return param_value
def _set_trial_param_without_commit(self, session, trial_id, param_name, param_value_internal,
distribution):
# type: (orm.Session, int, str, float, distributions.BaseDistribution) -> bool
# Intended (per its name) to stage parameter ``param_name`` on trial ``trial_id``
# in the caller-supplied ``session``.
# Visible contract: if the parameter already exists, its stored distribution is
# checked for compatibility with ``distribution`` and False is returned.
# NOTE(review): the body is truncated in this view after ``param = models.TrialParamModel(``;
# the insert path and the meaning of a True return are not visible here.
trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
# Presumably raises when the trial can no longer be modified — confirm in check_trial_is_updatable.
self.check_trial_is_updatable(trial_id, trial.state)
trial_param = \
models.TrialParamModel.find_by_trial_and_param_name(trial, param_name, session)
if trial_param is not None:
# Raise error in case distribution is incompatible.
distributions.check_distribution_compatibility(
distributions.json_to_distribution(trial_param.distribution_json), distribution)
# Terminate transaction explicitly to avoid connection timeout during transaction.
# NOTE(review): this path commits despite the ``_without_commit`` name; confirm
# callers do not expect the session to be left uncommitted.
self._commit(session)
# Return False when distribution is compatible but parameter has already been set.
return False
param = models.TrialParamModel(
def set_trial_state(self, trial_id, state):
    # type: (int, structs.TrialState) -> None
    """Move trial ``trial_id`` to ``state`` and commit the change.

    The trial must pass ``self.check_trial_is_updatable``. When
    ``state.is_finished()`` is true, ``datetime_complete`` is also set to
    the current local time from ``datetime.now()``.
    """
    session = self.scoped_session()

    trial_row = models.TrialModel.find_or_raise_by_id(trial_id, session)
    current_state = trial_row.state
    self.check_trial_is_updatable(trial_id, current_state)

    trial_row.state = state
    finished = state.is_finished()
    if finished:
        # Only states reporting is_finished() get a completion timestamp.
        trial_row.datetime_complete = datetime.now()

    self._commit(session)
def get_study_id_from_trial_id(self, trial_id):
    # type: (int) -> int
    """Return the ``study_id`` of trial ``trial_id``.

    The session is closed whether or not the lookup succeeds. Previously a
    raising ``find_or_raise_by_id`` (presumably for an unknown trial) left
    the session open with its transaction and connection still held.
    """
    session = self.scoped_session()
    try:
        trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
        # Evaluated before the ``finally`` runs, so the attribute is read
        # while the session is still open — same order as before.
        return trial.study_id
    finally:
        session.close()
def set_trial_param(self, trial_id, param_name, param_value_internal, distribution):
# type: (int, str, float, distributions.BaseDistribution) -> bool
# Intended (per its name) to store parameter ``param_name`` on trial ``trial_id``,
# using a session obtained from ``self.scoped_session()``.
# Visible contract: if the parameter already exists, its stored distribution is
# checked for compatibility with ``distribution``, the session is closed, and
# False is returned.
# NOTE(review): the body is truncated in this view after ``param = models.TrialParamModel(``;
# the insert path and its return value are not visible here.
session = self.scoped_session()
trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
self.check_trial_is_updatable(trial_id, trial.state)
trial_param = \
models.TrialParamModel.find_by_trial_and_param_name(trial, param_name, session)
if trial_param is not None:
# Raise error in case distribution is incompatible.
# NOTE(review): if this (or either lookup above) raises, session.close() is
# skipped and the session stays open; consider a try/finally.
distributions.check_distribution_compatibility(
distributions.json_to_distribution(trial_param.distribution_json), distribution)
session.close()
# Return False when distribution is compatible but parameter has already been set.
return False
param = models.TrialParamModel(
def get_trial_param(self, trial_id, param_name):
    # type: (int, str) -> float
    """Return the internal value of parameter ``param_name`` of trial ``trial_id``.

    Both lookups go through the ``find_or_raise_*`` model helpers (presumably
    raising when the trial or parameter does not exist — confirm in
    ``models``). The session is closed on both the success and the error
    path, so a failed lookup no longer leaves it open.

    Returns:
        The stored ``param_value`` of the matching ``TrialParamModel`` row.
    """
    session = self.scoped_session()
    try:
        trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
        trial_param = models.TrialParamModel.find_or_raise_by_trial_and_param_name(
            trial, param_name, session)
        # The return expression is evaluated before ``finally`` closes the
        # session, so the value is read while the session is still open.
        return trial_param.param_value
    finally:
        session.close()
def _create_new_trial_number(self, trial_id):
    # type: (int) -> int
    """Compute, record, and return the number of trial ``trial_id``.

    The number is ``trial.count_past_trials(session)`` (presumably the count
    of earlier trials in the same study — confirm in ``TrialModel``). It is
    stored via ``self.set_trial_system_attr`` under the key ``'_number'``.
    The session is closed even if a lookup or the attribute write raises.

    Returns:
        The computed trial number.
    """
    session = self.scoped_session()
    try:
        trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
        trial_number = trial.count_past_trials(session)
        self.set_trial_system_attr(trial.trial_id, '_number', trial_number)
    finally:
        session.close()
    return trial_number