Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import pytest
import re
import subprocess
from subprocess import CalledProcessError
import tempfile
import optuna
from optuna.cli import Studies
from optuna.storages.base import DEFAULT_STUDY_NAME_PREFIX
from optuna.storages import RDBStorage
from optuna.structs import CLIUsageError
from optuna.testing.storage import StorageSupplier
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import List # NOQA
from optuna.trial import Trial # NOQA
def test_create_study_command():
    # type: () -> None
    """Run ``optuna create-study`` against the storage yielded by ``StorageSupplier('new')``.

    ``subprocess.check_call`` raises ``CalledProcessError`` when the command
    exits with a non-zero status, which fails this test.
    """
    with StorageSupplier('new') as rdb_storage:
        # The CLI is pointed at this storage via its engine URL, so it must be an RDBStorage.
        assert isinstance(rdb_storage, RDBStorage)
        url = str(rdb_storage.engine.url)

        # Create study.
        create_cmd = ['optuna', 'create-study', '--storage', url]
        subprocess.check_call(create_cmd)
'pytorch_ignite': ["PyTorchIgnitePruningHandler"],
'pytorch_lightning': ['PyTorchLightningPruningCallback'],
'sklearn': ['OptunaSearchCV'],
'mxnet': ['MXNetPruningCallback'],
'skopt': ['SkoptSampler'],
'tensorflow': ['TensorFlowPruningHook'],
'tfkeras': ['TFKerasPruningCallback'],
'xgboost': ['XGBoostPruningCallback'],
'fastai': ['FastAIPruningCallback'],
}
# Public API: every key of ``_import_structure`` followed by every name listed
# under those keys, in dict order. A comprehension is used instead of
# ``sum(..., [])``, which re-copies the accumulated list on every addition.
__all__ = list(_import_structure) + [
    name for names in _import_structure.values() for name in names
]
if TYPE_CHECKING:
from optuna.integration.chainer import ChainerPruningExtension # NOQA
from optuna.integration.chainermn import ChainerMNStudy # NOQA
from optuna.integration.cma import CmaEsSampler # NOQA
from optuna.integration.fastai import FastAIPruningCallback # NOQA
from optuna.integration.keras import KerasPruningCallback # NOQA
from optuna.integration.lightgbm import LightGBMPruningCallback # NOQA
from optuna.integration.lightgbm import LightGBMTuner # NOQA
from optuna.integration.mxnet import MXNetPruningCallback # NOQA
from optuna.integration.pytorch_ignite import PyTorchIgnitePruningHandler # NOQA
from optuna.integration.pytorch_lightning import PyTorchLightningPruningCallback # NOQA
from optuna.integration.sklearn import OptunaSearchCV # NOQA
from optuna.integration.skopt import SkoptSampler # NOQA
from optuna.integration.tensorflow import TensorFlowPruningHook # NOQA
from optuna.integration.tfkeras import TFKerasPruningCallback # NOQA
from optuna.integration.xgboost import XGBoostPruningCallback # NOQA
else:
import abc
import json
import six
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import Tuple # NOQA
from typing import Union # NOQA
@six.add_metaclass(abc.ABCMeta)
class BaseDistribution(object):
"""Base class for distributions.
Note that distribution classes are not supposed to be called by library users.
They are used by :class:`~optuna.trial.Trial` and :class:`~optuna.samplers` internally.
"""
def to_external_repr(self, param_value_in_internal_repr):
# type: (float) -> Any
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import List # NOQA
ALIAS_GROUP_LIST = [
{
'param_name': 'bagging_fraction',
'alias_names': ['sub_row', 'subsample', 'bagging'],
'default_value': None,
},
{
'param_name': 'learning_rate',
'alias_names': ['shrinkage_rate', 'eta'],
'default_value': 0.1, # Start from large `learning_rate` value.
},
import optuna
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from optuna.trial import Trial # NOQA
# Probe for pytorch-ignite. ``_available`` records whether the import worked;
# on failure the ImportError is kept in ``_import_error`` (presumably so a later
# use of the handler can report why it is unavailable -- confirm with callers).
_available = False
try:
    from ignite.engine import Engine  # NOQA
except ImportError as e:
    # PyTorchIgnitePruningHandler is disabled because pytorch-ignite is not available.
    _import_error = e
else:
    _available = True
class PyTorchIgnitePruningHandler(object):
"""PyTorch Ignite handler to prune unpromising trials.
Example:
import copy
from datetime import datetime
import threading
from optuna import distributions # NOQA
from optuna.storages import base
from optuna.storages.base import DEFAULT_STUDY_NAME_PREFIX
from optuna import structs
from optuna import type_checking
if type_checking.TYPE_CHECKING:
from typing import Any # NOQA
from typing import Dict # NOQA
from typing import List # NOQA
from typing import Optional # NOQA
# Fixed identifiers, presumably used by InMemoryStorage for the single study it
# holds -- NOTE(review): confirm against InMemoryStorage, whose body is not visible here.
IN_MEMORY_STORAGE_STUDY_ID = 0  # type: int
# The all-zero (nil) UUID string.
IN_MEMORY_STORAGE_STUDY_UUID = '00000000-0000-0000-0000-000000000000'  # type: str
class InMemoryStorage(base.BaseStorage):
"""Storage class that stores data in memory of the Python process.
This class is not supposed to be directly accessed by library users.
"""
def __init__(self):
from optuna.integration.lightgbm_tuner.sklearn import LGBMClassifier, LGBMModel, LGBMRegressor # NOQA
from optuna.integration.lightgbm_tuner.optimize import LightGBMTuner
from optuna import type_checking
if type_checking.TYPE_CHECKING:
    # Typing-only imports. These names come from the standard ``typing`` module;
    # ``type_checking`` is optuna's helper module (imported from ``optuna`` above)
    # and does not provide them, so ``from type_checking import ...`` would break
    # when a type checker evaluates this branch.
    from typing import Any  # NOQA
    from typing import Dict  # NOQA
    from typing import List  # NOQA
    from typing import Optional  # NOQA
def train(*args, **kwargs):
    # type: (*Any, **Any) -> Any
    """Wrapper function of LightGBM API: train()

    Arguments and keyword arguments for `lightgbm.train()` can be passed.
    They are forwarded unchanged to :class:`LightGBMTuner`, whose ``run()``
    result is returned.

    Returns:
        Whatever ``LightGBMTuner.run()`` returns (presumably the trained
        booster -- confirm against ``LightGBMTuner``).
    """
    # The type comment uses the PEP 484 star form: each positional and keyword
    # argument is ``Any``. The previous ``List[Any]`` typed every positional
    # argument as a list.
    auto_booster = LightGBMTuner(*args, **kwargs)
    booster = auto_booster.run()
    return booster