Secure your code as it's written. Use Snyk Code to scan source code in minutes—no build needed—and fix issues immediately.
# Distributions of different classes must be rejected.
with pytest.raises(ValueError):
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['u'], EXAMPLE_DISTRIBUTIONS['l'])

# CategoricalDistribution: a categorical with different choices must be rejected.
with pytest.raises(ValueError):
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['c2'],
        distributions.CategoricalDistribution(choices=('Roppongi', 'Akasaka')))

# Every other distribution class: same class with different parameters is
# expected to be accepted, so none of the calls below may raise.
distributions.check_distribution_compatibility(
    EXAMPLE_DISTRIBUTIONS['u'],
    distributions.UniformDistribution(low=-3.0, high=-2.0),
)
distributions.check_distribution_compatibility(
    EXAMPLE_DISTRIBUTIONS['l'],
    distributions.LogUniformDistribution(low=-0.1, high=1.0),
)
distributions.check_distribution_compatibility(
    EXAMPLE_DISTRIBUTIONS['du'],
    distributions.DiscreteUniformDistribution(low=-1.0, high=10.0, q=3.),
)
distributions.check_distribution_compatibility(
    EXAMPLE_DISTRIBUTIONS['iu'],
    distributions.IntUniformDistribution(low=-1, high=1),
)
# Mixing two different distribution classes is an error.
with pytest.raises(ValueError):
    distributions.check_distribution_compatibility(EXAMPLE_DISTRIBUTIONS['u'],
                                                   EXAMPLE_DISTRIBUTIONS['l'])
# A categorical whose choices differ from the stored one is an error as well.
with pytest.raises(ValueError):
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['c2'],
        distributions.CategoricalDistribution(choices=('Roppongi', 'Akasaka')),
    )
# Non-categorical classes: only the range/step differs here, and each pair is
# expected to pass the check without raising.
distributions.check_distribution_compatibility(EXAMPLE_DISTRIBUTIONS['u'],
                                               distributions.UniformDistribution(
                                                   low=-3.0, high=-2.0))
distributions.check_distribution_compatibility(EXAMPLE_DISTRIBUTIONS['l'],
                                               distributions.LogUniformDistribution(
                                                   low=-0.1, high=1.0))
distributions.check_distribution_compatibility(EXAMPLE_DISTRIBUTIONS['du'],
                                               distributions.DiscreteUniformDistribution(
                                                   low=-1.0, high=10.0, q=3.))
distributions.check_distribution_compatibility(EXAMPLE_DISTRIBUTIONS['iu'],
                                               distributions.IntUniformDistribution(
                                                   low=-1, high=1))
def _is_relative_param(self, name, distribution):
# type: (str, BaseDistribution) -> bool
if name not in self.relative_params:
return False
if name not in self.relative_search_space:
raise ValueError("The parameter '{}' was sampled by `sample_relative` method "
"but it is not contained in the relative search space.".format(name))
relative_distribution = self.relative_search_space[name]
distributions.check_distribution_compatibility(relative_distribution, distribution)
param_value = self.relative_params[name]
param_value_in_internal_repr = distribution.to_internal_repr(param_value)
return distribution._contains(param_value_in_internal_repr)
def _suggest(self, name, distribution):
# type: (str, BaseDistribution) -> Any
if name not in self._params:
raise ValueError('The value of the parameter \'{}\' is not found. Please set it at '
'the construction of the FixedTrial object.'.format(name))
value = self._params[name]
param_value_in_internal_repr = distribution.to_internal_repr(value)
if not distribution._contains(param_value_in_internal_repr):
raise ValueError("The value {} of the parameter '{}' is out of "
"the range of the distribution {}.".format(value, name, distribution))
if name in self._distributions:
distributions.check_distribution_compatibility(self._distributions[name], distribution)
self._suggested_params[name] = value
self._distributions[name] = distribution
return value
def _suggest(self, name, distribution):
# type: (str, BaseDistribution) -> Any
if name not in self._params:
raise ValueError('The value of the parameter \'{}\' is not found. Please set it at '
'the construction of the FixedTrial object.'.format(name))
value = self._params[name]
param_value_in_internal_repr = distribution.to_internal_repr(value)
if not distribution._contains(param_value_in_internal_repr):
raise ValueError("The value {} of the parameter '{}' is out of "
"the range of the distribution {}.".format(value, name, distribution))
if name in self._distributions:
distributions.check_distribution_compatibility(self._distributions[name], distribution)
self._suggested_params[name] = value
self._distributions[name] = distribution
return value
def set_trial_param(self, trial_id, param_name, param_value_internal, distribution):
    # type: (int, str, float, distributions.BaseDistribution) -> bool
    """Persist ``param_value_internal`` for ``param_name`` on trial ``trial_id``.

    Returns:
        ``False`` if the trial already has ``param_name`` with a compatible
        distribution; nothing is written in that case. Otherwise returns the value
        of ``self._commit_with_integrity_check(session)`` after adding the new
        ``TrialParamModel`` row (presumably whether the commit succeeded; confirm
        against that helper).

    Raises:
        ValueError: Propagated from ``check_distribution_compatibility`` when
            ``param_name`` is already set with an incompatible distribution.
        Any error raised by ``find_or_raise_by_id`` or ``check_trial_is_updatable``
            also propagates.
    """
    session = self.scoped_session()
    trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
    self.check_trial_is_updatable(trial_id, trial.state)

    trial_param = \
        models.TrialParamModel.find_by_trial_and_param_name(trial, param_name, session)

    if trial_param is not None:
        try:
            # Raise error in case distribution is incompatible.
            distributions.check_distribution_compatibility(
                distributions.json_to_distribution(trial_param.distribution_json), distribution)
        finally:
            # Close the session on the incompatible-distribution error path too,
            # not only when returning False below.
            session.close()
        # Return False when distribution is compatible but parameter has already been set.
        return False

    param = models.TrialParamModel(
        trial_id=trial_id,
        param_name=param_name,
        param_value=param_value_internal,
        distribution_json=distributions.distribution_to_json(distribution))

    param.check_and_add(session)
    commit_success = self._commit_with_integrity_check(session)

    # The signature promises a bool; previously this result was computed and
    # dropped, so callers always received None.
    return commit_success
def _check_compatibility_with_previous_trial_param_distributions(self, session):
    # type: (orm.Session) -> None
    """Verify this param's distribution against an earlier record with the same name.

    The trial identified by ``self.trial_id`` is loaded to find its study. The
    first ``TrialParamModel`` in that study with the same ``param_name`` is then
    fetched. If one exists, its stored distribution and this record's
    distribution are decoded from JSON and passed to
    ``check_distribution_compatibility``, whose error (if any) propagates.
    """
    trial = TrialModel.find_or_raise_by_id(self.trial_id, session)

    query = session.query(TrialParamModel).join(TrialModel)
    query = query.filter(TrialModel.study_id == trial.study_id)
    query = query.filter(TrialParamModel.param_name == self.param_name)
    previous_record = query.first()

    if previous_record is None:
        return

    stored = distributions.json_to_distribution(previous_record.distribution_json)
    current = distributions.json_to_distribution(self.distribution_json)
    distributions.check_distribution_compatibility(stored, current)
def _is_compatible(self, trial):
# type: (FrozenTrial) -> bool
# Thanks to `intersection_search_space()` function, in sequential optimization,
# the parameters of complete trials are always compatible with the search space.
#
# However, in distributed optimization, incompatible trials may complete on a worker
# just after a product search space is calculated on another worker.
for name, distribution in self._search_space.items():
if name not in trial.params:
return False
distributions.check_distribution_compatibility(distribution, trial.distributions[name])
param_value = trial.params[name]
param_internal_value = distribution.to_internal_repr(param_value)
if not distribution._contains(param_internal_value):
return False
return True
def _check_compatibility_with_previous_trial_param_distributions(self, session):
    # type: (orm.Session) -> None
    """Reject this record if an earlier same-named param in the study disagrees.

    Uses ``self.trial_id`` to locate the study. The first existing
    ``TrialParamModel`` of that study named ``self.param_name`` is compared
    (after JSON decoding) with this record's distribution via
    ``check_distribution_compatibility``. Nothing happens when no such record
    exists.
    """
    trial = TrialModel.find_or_raise_by_id(self.trial_id, session)
    previous_record = (
        session.query(TrialParamModel)
        .join(TrialModel)
        .filter(TrialModel.study_id == trial.study_id)
        .filter(TrialParamModel.param_name == self.param_name)
        .first()
    )
    if previous_record is not None:
        # Argument order matters: the stored distribution comes first.
        distributions.check_distribution_compatibility(
            distributions.json_to_distribution(previous_record.distribution_json),
            distributions.json_to_distribution(self.distribution_json),
        )