How to use the `optuna.distributions.check_distribution_compatibility` function in Optuna

To help you get started, we’ve selected a few Optuna examples that show popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: optuna/optuna — tests/test_distributions.py (View on GitHub)
# Excerpt from a test of distributions.check_distribution_compatibility.
# Distributions of different classes ('u' vs. 'l') must be rejected with ValueError.
# NOTE(review): `pytest.raises(exc, func)` is the legacy call form; the idiomatic
# equivalent is `with pytest.raises(ValueError): ...`.
pytest.raises(
        ValueError, lambda: distributions.check_distribution_compatibility(
            EXAMPLE_DISTRIBUTIONS['u'], EXAMPLE_DISTRIBUTIONS['l']))

    # Categorical distributions: changing the set of choices must raise ValueError
    # (presumably EXAMPLE_DISTRIBUTIONS['c2'], defined outside this excerpt, has
    # choices other than ('Roppongi', 'Akasaka') — confirm against its definition).
    pytest.raises(
        ValueError, lambda: distributions.check_distribution_compatibility(
            EXAMPLE_DISTRIBUTIONS['c2'],
            distributions.CategoricalDistribution(choices=('Roppongi', 'Akasaka'))))

    # Non-categorical distributions: a same-class distribution with another range
    # (or step) is still compatible, so none of the calls below may raise.
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['u'], distributions.UniformDistribution(low=-3.0, high=-2.0))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['l'], distributions.LogUniformDistribution(low=-0.1, high=1.0))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['du'],
        distributions.DiscreteUniformDistribution(low=-1.0, high=10.0, q=3.))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['iu'], distributions.IntUniformDistribution(low=-1, high=1))
GitHub: optuna/optuna — tests/test_distributions.py (View on GitHub)
# test different distribution classes: 'u' and 'l' are of different classes, so
# check_distribution_compatibility must raise ValueError.
    pytest.raises(
        ValueError, lambda: distributions.check_distribution_compatibility(
            EXAMPLE_DISTRIBUTIONS['u'], EXAMPLE_DISTRIBUTIONS['l']))

    # test dynamic value range (CategoricalDistribution): a categorical distribution
    # with other choices must raise ValueError (assumes EXAMPLE_DISTRIBUTIONS['c2'],
    # defined outside this excerpt, differs from ('Roppongi', 'Akasaka') — confirm).
    pytest.raises(
        ValueError, lambda: distributions.check_distribution_compatibility(
            EXAMPLE_DISTRIBUTIONS['c2'],
            distributions.CategoricalDistribution(choices=('Roppongi', 'Akasaka'))))

    # test dynamic value range (except CategoricalDistribution): same-class
    # distributions with other bounds/steps remain compatible; these calls must not raise.
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['u'], distributions.UniformDistribution(low=-3.0, high=-2.0))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['l'], distributions.LogUniformDistribution(low=-0.1, high=1.0))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['du'],
        distributions.DiscreteUniformDistribution(low=-1.0, high=10.0, q=3.))
    distributions.check_distribution_compatibility(
        EXAMPLE_DISTRIBUTIONS['iu'], distributions.IntUniformDistribution(low=-1, high=1))
GitHub: optuna/optuna — optuna/trial.py (View on GitHub)
def _is_relative_param(self, name, distribution):
        # type: (str, BaseDistribution) -> bool

        if name not in self.relative_params:
            return False

        if name not in self.relative_search_space:
            raise ValueError("The parameter '{}' was sampled by `sample_relative` method "
                             "but it is not contained in the relative search space.".format(name))

        relative_distribution = self.relative_search_space[name]
        distributions.check_distribution_compatibility(relative_distribution, distribution)

        param_value = self.relative_params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(param_value)
        return distribution._contains(param_value_in_internal_repr)
GitHub: optuna/optuna — optuna/trial.py (View on GitHub)
def _suggest(self, name, distribution):
        # type: (str, BaseDistribution) -> Any

        if name not in self._params:
            raise ValueError('The value of the parameter \'{}\' is not found. Please set it at '
                             'the construction of the FixedTrial object.'.format(name))

        value = self._params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(value)
        if not distribution._contains(param_value_in_internal_repr):
            raise ValueError("The value {} of the parameter '{}' is out of "
                             "the range of the distribution {}.".format(value, name, distribution))

        if name in self._distributions:
            distributions.check_distribution_compatibility(self._distributions[name], distribution)

        self._suggested_params[name] = value
        self._distributions[name] = distribution

        return value
GitHub: optuna/optuna — optuna/trial.py (View on GitHub)
def _suggest(self, name, distribution):
        # type: (str, BaseDistribution) -> Any

        if name not in self._params:
            raise ValueError('The value of the parameter \'{}\' is not found. Please set it at '
                             'the construction of the FixedTrial object.'.format(name))

        value = self._params[name]
        param_value_in_internal_repr = distribution.to_internal_repr(value)
        if not distribution._contains(param_value_in_internal_repr):
            raise ValueError("The value {} of the parameter '{}' is out of "
                             "the range of the distribution {}.".format(value, name, distribution))

        if name in self._distributions:
            distributions.check_distribution_compatibility(self._distributions[name], distribution)

        self._suggested_params[name] = value
        self._distributions[name] = distribution

        return value
GitHub: optuna/optuna — optuna/storages/rdb/storage.py (View on GitHub)
def set_trial_param(self, trial_id, param_name, param_value_internal, distribution):
        # type: (int, str, float, distributions.BaseDistribution) -> bool
        """Store a parameter value and its distribution for a trial.

        Args:
            trial_id: ID of the trial to update.
            param_name: Name of the parameter.
            param_value_internal: Parameter value in its internal representation.
            distribution: Distribution the value was sampled from.

        Returns:
            ``False`` (without writing anything) when ``param_name`` is already set on the
            trial and its stored distribution is compatible with ``distribution``.
            Otherwise, a new ``TrialParamModel`` row is added and committed through
            ``_commit_with_integrity_check``.

        Raises:
            ValueError: From ``distributions.check_distribution_compatibility`` when the
                stored distribution is incompatible with ``distribution``. Errors raised by
                ``find_or_raise_by_id`` and ``check_trial_is_updatable`` also propagate.
                In every case the session is closed before the error is re-raised.
        """
        session = self.scoped_session()

        try:
            trial = models.TrialModel.find_or_raise_by_id(trial_id, session)
            self.check_trial_is_updatable(trial_id, trial.state)

            trial_param = models.TrialParamModel.find_by_trial_and_param_name(
                trial, param_name, session)

            if trial_param is not None:
                # Raise error in case distribution is incompatible.
                distributions.check_distribution_compatibility(
                    distributions.json_to_distribution(trial_param.distribution_json),
                    distribution)
        except Exception:
            # Previously these error paths returned without closing the session, leaving it
            # (and its transaction/connection) open. Release it before propagating.
            session.close()
            raise

        if trial_param is not None:
            session.close()

            # Return False when distribution is compatible but parameter has already been set.
            return False

        param = models.TrialParamModel(
            trial_id=trial_id,
            param_name=param_name,
            param_value=param_value_internal,
            distribution_json=distributions.distribution_to_json(distribution))

        param.check_and_add(session)
        commit_success = self._commit_with_integrity_check(session)
GitHub: optuna/optuna — optuna/storages/rdb/models.py (View on GitHub)
def _check_compatibility_with_previous_trial_param_distributions(self, session):
        # type: (orm.Session) -> None
        """Ensure this parameter's distribution agrees with an earlier one in the same study.

        Fetches the first ``TrialParamModel`` row that has the same ``param_name`` and belongs
        to a trial of this trial's study. If such a row exists, its stored distribution is
        compared with ``self.distribution_json``.

        Args:
            session: ORM session used for the lookups.

        Raises:
            ValueError: From ``distributions.check_distribution_compatibility`` when the
                two distributions are incompatible.
        """
        owner = TrialModel.find_or_raise_by_id(self.trial_id, session)

        same_name_in_study = (
            session.query(TrialParamModel)
            .join(TrialModel)
            .filter(TrialModel.study_id == owner.study_id)
            .filter(TrialParamModel.param_name == self.param_name)
        )
        earlier = same_name_in_study.first()
        if earlier is None:
            return

        distributions.check_distribution_compatibility(
            distributions.json_to_distribution(earlier.distribution_json),
            distributions.json_to_distribution(self.distribution_json))
GitHub: optuna/optuna — optuna/integration/cma.py (View on GitHub)
def _is_compatible(self, trial):
        # type: (FrozenTrial) -> bool

        # Thanks to `intersection_search_space()` function, in sequential optimization,
        # the parameters of complete trials are always compatible with the search space.
        #
        # However, in distributed optimization, incompatible trials may complete on a worker
        # just after a product search space is calculated on another worker.

        for name, distribution in self._search_space.items():
            if name not in trial.params:
                return False

            distributions.check_distribution_compatibility(distribution, trial.distributions[name])
            param_value = trial.params[name]
            param_internal_value = distribution.to_internal_repr(param_value)
            if not distribution._contains(param_internal_value):
                return False

        return True
GitHub: optuna/optuna — optuna/storages/rdb/models.py (View on GitHub)
def _check_compatibility_with_previous_trial_param_distributions(self, session):
        # type: (orm.Session) -> None
        """Validate this row's distribution against an existing same-named parameter.

        The search is restricted to trials of the same study as ``self.trial_id``. Only the
        first matching ``TrialParamModel`` row is used. When there is no such row, nothing
        is checked.

        Raises:
            ValueError: From ``distributions.check_distribution_compatibility`` on an
                incompatible pair of distributions.
        """
        trial = TrialModel.find_or_raise_by_id(self.trial_id, session)

        query = session.query(TrialParamModel).join(TrialModel)
        query = query.filter(TrialModel.study_id == trial.study_id)
        query = query.filter(TrialParamModel.param_name == self.param_name)
        previous_record = query.first()

        if previous_record is not None:
            # Deserialize the stored distribution first, then this row's, as before.
            previous_distribution = distributions.json_to_distribution(
                previous_record.distribution_json)
            current_distribution = distributions.json_to_distribution(self.distribution_json)
            distributions.check_distribution_compatibility(
                previous_distribution, current_distribution)