How to use the optuna.pruners.MedianPruner class in Optuna

To help you get started, we’ve selected a few examples of MedianPruner, based on popular ways it is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

github optuna / optuna / tests / pruners_tests / test_median.py View on Github external
def test_median_pruner_n_warmup_steps():
    # type: () -> None
    """Check that MedianPruner stays inactive during warm-up and is active afterwards.

    A completed trial reports the value 1 at steps 1 and 2; a second trial then
    reports 2. At step 1 (a warm-up step for ``MedianPruner(0, 1)``, as the test
    name indicates) the pruner must not fire. The final assertion checks the
    behaviour after warm-up.

    NOTE(review): this snippet mixes the public ``trial.trial_id`` with the
    private ``trial._trial_id``. Also, the scraped source is truncated at the last
    ``assert`` (its arguments are missing), so the block is not valid Python as
    shown.
    """

    pruner = optuna.pruners.MedianPruner(0, 1)
    study = optuna.study.create_study()

    # Reference trial: report 1 at steps 1 and 2, then mark it finished.
    trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
    trial.report(1, 1)
    trial.report(1, 2)
    study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)

    # Trial under test: reports the larger value 2.
    trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
    trial.report(2, 1)
    # A pruner is not activated during warm-up steps.
    assert not pruner.prune(
        study=study, trial=study._storage.get_trial(trial.trial_id), step=1)

    trial.report(2, 2)
    # A pruner is activated after warm-up steps.
    assert pruner.prune(
github optuna / optuna / tests / pruners_tests / test_median.py View on Github external
def test_median_pruner_intermediate_values(direction_value):
    # type: (Tuple[str, float]) -> None
    """Verify that pruning waits until the running trial has an intermediate value.

    ``direction_value`` is a ``(direction, intermediate_value)`` pair, presumably
    supplied by a pytest parametrization not shown here. A finished trial that
    reported ``1`` at step 1 serves as the reference. A second trial must not be
    pruned before it reports anything, and must be pruned once it reports
    ``intermediate_value`` at step 1.
    """
    direction, intermediate_value = direction_value
    pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
    study = optuna.study.create_study(direction=direction)
    storage = study._storage

    # Reference trial: a single value at step 1, then marked finished.
    reference = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))
    reference.report(1, 1)
    storage.set_trial_state(reference._trial_id, TrialState.COMPLETE)

    running = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))

    # No intermediate values yet, so the pruner must stay inactive.
    assert not pruner.prune(study=study, trial=storage.get_trial(running._trial_id))

    # Once a value exists at step 1, the pruner must fire.
    running.report(intermediate_value, 1)
    assert pruner.prune(study=study, trial=storage.get_trial(running._trial_id))
github optuna / optuna / tests / pruners_tests / test_median.py View on Github external
def test_median_pruner_interval_steps(
        n_warmup_steps, interval_steps, report_steps, expected_prune_steps):
    # type: (int, int, int, List[int]) -> None
    """Verify that ``interval_steps`` restricts the steps at which pruning can fire.

    A completed reference trial reports ``1`` at every step from 1 to
    ``max(expected_prune_steps)``. A second trial then reports ``2`` every
    ``report_steps`` steps. At each step, the pruner's verdict must equal
    "past the warm-up steps and listed in ``expected_prune_steps``".
    """
    pruner = optuna.pruners.MedianPruner(
        n_startup_trials=0, n_warmup_steps=n_warmup_steps, interval_steps=interval_steps)
    study = optuna.study.create_study()
    storage = study._storage

    first_step = 1
    last_step = max(expected_prune_steps)  # raises ValueError if the list is empty
    steps = range(first_step, last_step + 1)
    prune_steps = set(expected_prune_steps)

    # Reference trial: the constant value 1 at every step, then marked finished.
    reference = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))
    for step in steps:
        reference.report(1, step)
    storage.set_trial_state(reference._trial_id, TrialState.COMPLETE)

    candidate = optuna.trial.Trial(study, storage.create_new_trial(study._study_id))
    for step in steps:
        # Report only every ``report_steps`` steps, counting from the first step.
        if (step - first_step) % report_steps == 0:
            candidate.report(2, step)
        expected = step > n_warmup_steps and step in prune_steps
        actual = pruner.prune(study=study, trial=storage.get_trial(candidate._trial_id))
        assert actual == expected
github optuna / optuna / tests / pruners_tests / test_median.py View on Github external
def test_median_pruner_intermediate_values_nan():
    # type: () -> None
    """Check how ``MedianPruner`` handles NaN intermediate values.

    A first trial reporting NaN must not be pruned, because the study has no
    previous trials. After it is marked COMPLETE, a second trial that also
    reports NaN at step 1 must be pruned.

    Trials are always addressed through the private ``_trial_id``, which the
    ``set_trial_state`` calls already use. The original mixed it with the public
    ``trial_id`` accessor.
    """

    pruner = optuna.pruners.MedianPruner(n_startup_trials=0, n_warmup_steps=0)
    study = optuna.study.create_study()

    trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
    trial.report(float('nan'), 1)
    # A pruner is not activated if the study does not have any previous trials.
    assert not pruner.prune(
        study=study, trial=study._storage.get_trial(trial._trial_id), step=1)
    study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)

    trial = optuna.trial.Trial(study, study._storage.create_new_trial_id(study.study_id))
    trial.report(float('nan'), 1)
    # A pruner is activated if the best intermediate value of this trial is NaN.
    assert pruner.prune(
        study=study, trial=study._storage.get_trial(trial._trial_id), step=1)
    study._storage.set_trial_state(trial._trial_id, TrialState.COMPLETE)
github abides-sim / abides / calibration / calibrate.py View on Github external
# Fragment of a calibration routine. Not part of this snippet: the enclosing
# function's header, its earlier statements (e.g. where ``study_name`` and
# ``start_time`` are set), and the definitions of ``objective``/``SEED``/``log``.
log.info(f'Study : {study_name}')

    n_trials = 100
    # NOTE(review): psutil documents that cpu_count() may return None when the
    # CPU count cannot be determined. Confirm optimize() tolerates that for n_jobs.
    n_jobs = psutil.cpu_count()

    log.info(f'Number of Trials : {n_trials}')
    log.info(f'Number of Parallel Jobs : {n_jobs}')

    # sampler = RandomSampler(seed=seed)
    # NOTE(review): seeding the sampler may not make results reproducible while
    # several jobs run in parallel — confirm.
    sampler = TPESampler(seed=SEED)  # Make the sampler behave in a deterministic way.

    # study: A study corresponds to an optimization task, i.e., a set of trials.
    # Results are persisted to a local SQLite file named after the study.
    # load_if_exists=True presumably resumes an existing study of the same name.
    # NOTE(review): SQLite shared by several parallel jobs may contend for the
    # database lock — confirm this is acceptable.
    study = optuna.create_study(study_name=study_name,
                                direction='maximize',
                                sampler=sampler,
                                pruner=optuna.pruners.MedianPruner(),
                                storage=f'sqlite:///{study_name}.db',
                                load_if_exists=True)
    study.optimize(objective,
                   n_trials=n_trials,
                   n_jobs=n_jobs,
                   show_progress_bar=True)

    log.info(f'Best Parameters: {study.best_params}')
    log.info(f'Best Value: {study.best_value}')

    # Export all trials as a DataFrame and pickle it. By default, pandas infers
    # bz2 compression from the file suffix.
    df = study.trials_dataframe()

    df.to_pickle(f'{study_name}_df.bz2')

    end_time = dt.datetime.now()
    log.info(f'Total time taken for the study: {end_time - start_time}')
github optuna / optuna / examples / pruning / xgboost_integration.py View on Github external
# Fragment: tail of the xgboost ``objective``. Not part of this snippet: the
# function header and the condition guarding the suggestions below (presumably
# the DART booster branch, given rate_drop/skip_drop).
param['sample_type'] = trial.suggest_categorical('sample_type', ['uniform', 'weighted'])
        param['normalize_type'] = trial.suggest_categorical('normalize_type', ['tree', 'forest'])
        # NOTE(review): suggest_loguniform is deprecated in newer Optuna releases
        # in favour of suggest_float(..., log=True).
        param['rate_drop'] = trial.suggest_loguniform('rate_drop', 1e-8, 1.0)
        param['skip_drop'] = trial.suggest_loguniform('skip_drop', 1e-8, 1.0)

    # Add a callback for pruning.
    # NOTE(review): the observation key 'validation-auc' pairs the evals name
    # 'validation' with an 'auc' metric. This assumes ``param`` sets eval_metric
    # to 'auc' earlier in the function — TODO confirm.
    pruning_callback = optuna.integration.XGBoostPruningCallback(trial, 'validation-auc')
    bst = xgb.train(param, dtrain, evals=[(dtest, 'validation')], callbacks=[pruning_callback])
    preds = bst.predict(dtest)
    # Round predictions to the nearest integer to obtain class labels. This
    # assumes a binary objective that outputs probabilities — TODO confirm.
    pred_labels = np.rint(preds)
    accuracy = sklearn.metrics.accuracy_score(test_y, pred_labels)
    return accuracy


if __name__ == '__main__':
    # Median pruning, inactive during the warm-up steps (n_warmup_steps=5);
    # the objective's accuracy is maximized.
    median_pruner = optuna.pruners.MedianPruner(n_warmup_steps=5)
    study = optuna.create_study(direction='maximize', pruner=median_pruner)
    study.optimize(objective, n_trials=100)
    print(study.best_trial)
github optuna / optuna / examples / pytorch_ignite_simple.py View on Github external
.format(engine.state.epoch, train_acc, validation_acc)
        )

    # Fragment: tail of the training objective. The handler that formats the
    # epoch/accuracy message above begins before this snippet.
    # Run training for EPOCHS epochs.
    trainer.run(train_loader, max_epochs=EPOCHS)

    # Evaluate on the validation loader and return its 'accuracy' metric
    # (presumably the value the study maximizes).
    evaluator.run(val_loader)
    return evaluator.state.metrics['accuracy']


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='PyTorch Ignite example.')
    parser.add_argument('--pruning', '-p', action='store_true',
                        help='Activate the pruning feature. `MedianPruner` stops unpromising '
                             'trials at the early stages of training.')
    args = parser.parse_args()

    # Median pruning only when --pruning is given; otherwise fall back to NopPruner.
    if args.pruning:
        pruner = optuna.pruners.MedianPruner()
    else:
        pruner = optuna.pruners.NopPruner()

    study = optuna.create_study(direction='maximize', pruner=pruner)
    # Up to 100 trials, with a 600-second timeout.
    study.optimize(objective, n_trials=100, timeout=600)

    print('Number of finished trials: ', len(study.trials))

    print('Best trial:')
    best = study.best_trial

    print('  Value: ', best.value)

    print('  Params: ')
    for name, val in best.params.items():
        print('    {}: {}'.format(name, val))
github optuna / optuna / optuna / study.py View on Github external
def __init__(
            self,
            study_name,  # type: str
            storage,  # type: Union[str, storages.BaseStorage]
            sampler=None,  # type: samplers.BaseSampler
            pruner=None  # type: pruners.BasePruner
    ):
        # type: (...) -> None
        """Attach to the study named ``study_name`` in ``storage`` and configure it.

        Args:
            study_name: Name used to look up the study ID in the storage.
            storage: A storage object, or a string that ``storages.get_storage``
                resolves into one (presumably a storage URL).
            sampler: Sampler to use. Falls back to ``samplers.TPESampler()``
                when the argument is falsy (e.g. ``None``).
            pruner: Pruner to use. Falls back to ``pruners.MedianPruner()``
                when the argument is falsy (e.g. ``None``).

        NOTE(review): what happens when no study with this name exists depends
        on ``storage.get_study_id_from_name``, which is not visible here.
        """

        self.study_name = study_name
        # Normalize the ``storage`` argument into a storage object before use.
        storage = storages.get_storage(storage)
        study_id = storage.get_study_id_from_name(study_name)
        super(Study, self).__init__(study_id, storage)

        # NOTE(review): ``or`` replaces any falsy argument, not only None.
        # ``is None`` checks would be stricter if that ever matters.
        self.sampler = sampler or samplers.TPESampler()
        self.pruner = pruner or pruners.MedianPruner()

        self.logger = logging.get_logger(__name__)

        # Presumably guards concurrent optimize() calls on this instance —
        # confirm where the lock is acquired.
        self._optimize_lock = threading.Lock()