Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import itertools
import numpy as np
import pytest
from optuna.samplers.tpe.parzen_estimator import _ParzenEstimator
from optuna.samplers.tpe.sampler import default_weights
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Dict # NOQA
from typing import List # NOQA
class TestParzenEstimator(object):
@staticmethod
@pytest.mark.parametrize(
'mus, prior, magic_clip, endpoints',
itertools.product(
([], [0.4], [-0.4, 0.4]), # mus
(True, False), # prior
(True, False), # magic_clip
(True, False), # endpoints
))
def test_calculate_shape_check(mus, prior, magic_clip, endpoints):
# type: (List[float], bool, bool, bool) -> None
import copy
import json
import pytest
from optuna import distributions
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import List # NOQA
# One sample instance of each distribution class exercised by these tests,
# keyed by a short tag:
#   'u'  -> UniformDistribution
#   'l'  -> LogUniformDistribution
#   'du' -> DiscreteUniformDistribution
#   'iu' -> IntUniformDistribution
#   'c1' -> CategoricalDistribution with float choices (including -inf,
#           presumably to exercise non-finite values; confirm with the tests)
#   'c2' -> CategoricalDistribution with string choices
EXAMPLE_DISTRIBUTIONS = dict(
    u=distributions.UniformDistribution(low=1., high=2.),
    l=distributions.LogUniformDistribution(low=0.001, high=100),
    du=distributions.DiscreteUniformDistribution(low=1., high=10., q=2.),
    iu=distributions.IntUniformDistribution(low=1, high=10),
    c1=distributions.CategoricalDistribution(choices=(2.71, -float('inf'))),
    c2=distributions.CategoricalDistribution(choices=('Roppongi', 'Azabu')),
)  # type: Dict[str, Any]
EXAMPLE_JSONS = {
'u': '{"name": "UniformDistribution", "attributes": {"low": 1.0, "high": 2.0}}',
'l': '{"name": "LogUniformDistribution", "attributes": {"low": 0.001, "high": 100}}',
from optuna.study import create_study
from optuna.testing.visualization import prepare_study_with_trials
from optuna import type_checking
from optuna.visualization.intermediate_values import plot_intermediate_values
if type_checking.TYPE_CHECKING:
from optuna.trial import Trial # NOQA
def test_plot_intermediate_values():
# type: () -> None
# Test with no trials.
study = prepare_study_with_trials(no_trials=True)
figure = plot_intermediate_values(study)
assert not figure.data
def objective(trial, report_intermediate_values):
# type: (Trial, bool) -> float
if report_intermediate_values:
trial.report(1.0, step=0)
from ignite.engine import Engine
from mock import Mock
from mock import patch
import pytest
import optuna
from optuna.testing.integration import create_running_trial
from optuna.testing.integration import DeterministicPruner
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Iterable # NOQA
def test_pytorch_ignite_pruning_handler():
# type: () -> None
def update(engine, batch):
# type: (Engine, Iterable) -> None
pass
trainer = Engine(update)
# The pruner is activated.
study = optuna.create_study(pruner=DeterministicPruner(True))
trial = create_running_trial(study, 1.0)
from mock import patch
import numpy as np
import pytest
import warnings
from optuna import distributions
from optuna import samplers
from optuna import storages
from optuna.study import create_study
from optuna.testing.integration import DeterministicPruner
from optuna.testing.sampler import DeterministicRelativeSampler
from optuna.trial import FixedTrial
from optuna.trial import Trial
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from datetime import datetime # NOQA
import typing # NOQA
# Reusable pytest mark that runs a decorated test once per storage backend.
# Each parameter value is a zero-argument factory bound to `storage_init_func`:
#   - the `storages.InMemoryStorage` class itself, and
#   - a lambda that builds an `RDBStorage` on an in-memory SQLite database.
# Factories rather than instances are passed, presumably so each test run
# builds its own fresh storage (confirm against the tests using this mark).
parametrize_storage = pytest.mark.parametrize(
'storage_init_func',
[storages.InMemoryStorage, lambda: storages.RDBStorage('sqlite:///:memory:')])
@parametrize_storage
def test_suggest_uniform(storage_init_func):
# type: (typing.Callable[[], storages.BaseStorage]) -> None
mock = Mock()
mock.side_effect = [1., 2., 3.]
sampler = samplers.RandomSampler()
import pytest
from optuna.distributions import LogUniformDistribution
from optuna.structs import StudyDirection
from optuna.study import create_study
from optuna.testing.visualization import prepare_study_with_trials
from optuna import type_checking
from optuna.visualization.contour import _generate_contour_subplot
from optuna.visualization.contour import plot_contour
if type_checking.TYPE_CHECKING:
from typing import List, Optional # NOQA
from optuna.trial import Trial # NOQA
@pytest.mark.parametrize(
'params', [
[],
['param_a'],
['param_a', 'param_b'],
['param_a', 'param_b', 'param_c'],
['param_a', 'param_b', 'param_c', 'param_d'],
None,
]
)
def test_plot_contour(params):
from optuna.logging import get_logger
from optuna.structs import StudyDirection
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import is_available
if type_checking.TYPE_CHECKING:
from optuna.study import Study # NOQA
if is_available():
from optuna.visualization.plotly_imports import go
# Logger for this module, obtained from optuna.logging.get_logger and named by __name__.
logger = get_logger(__name__)
def plot_optimization_history(study):
# type: (Study) -> go.Figure
"""Plot optimization history of all trials in a study.
Example:
The following code snippet shows how to plot optimization history.
import abc
from optuna import structs
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import List # NOQA
from typing import Optional # NOQA
from optuna import distributions # NOQA
# String prefix 'no-name-'; presumably prepended when generating a name for a
# study created without an explicit name (confirm against storage implementations).
DEFAULT_STUDY_NAME_PREFIX = 'no-name-'
class BaseStorage(object, metaclass=abc.ABCMeta):
"""Base class for storages.
This class is not supposed to be directly accessed by library users.
Storage classes abstract a backend database and provide library internal interfaces to
from optuna.logging import get_logger
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import _is_log_scale
from optuna.visualization.utils import is_available
if type_checking.TYPE_CHECKING:
from typing import List # NOQA
from typing import Optional # NOQA
from optuna.structs import FrozenTrial # NOQA
from optuna.study import Study # NOQA
from optuna.visualization.plotly_imports import Scatter # NOQA
if is_available():
from optuna.visualization.plotly_imports import go
from optuna.visualization.plotly_imports import make_subplots
# Logger for this module, obtained from optuna.logging.get_logger and named by __name__.
logger = get_logger(__name__)
def plot_slice(study, params=None):
# type: (Study, Optional[List[str]]) -> go.Figure
from optuna.logging import get_logger
from optuna.structs import TrialState
from optuna import type_checking
from optuna.visualization.utils import _check_plotly_availability
from optuna.visualization.utils import is_available
if type_checking.TYPE_CHECKING:
from optuna.study import Study # NOQA
if is_available():
from optuna.visualization.plotly_imports import go
# Logger for this module, obtained from optuna.logging.get_logger and named by __name__.
logger = get_logger(__name__)
def plot_intermediate_values(study):
# type: (Study) -> go.Figure
"""Plot intermediate values of all trials in a study.
Example:
The following code snippet shows how to plot intermediate values.