Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def get_trial(self, index):
    """Return the trial stored at position ``index`` of ``self.trials``.

    ``self.trials`` holds mappings of :class:`Trial` keyword arguments; a new
    :class:`Trial` instance is built from the selected entry on every call.
    """
    trial_config = self.trials[index]
    return Trial(**trial_config)
def test_register_duplicate_trial(self, storage):
    """Registering a trial that is already in the database must be rejected.

    The state is seeded with ``base_trial``; registering it a second time is
    expected to raise :class:`DuplicateKeyError`.
    """
    seeded_state = OrionState(
        experiments=[base_experiment], trials=[base_trial], database=storage)
    with seeded_state as cfg:
        backend = cfg.storage()
        with pytest.raises(DuplicateKeyError):
            backend.register_trial(Trial(**base_trial))
def dummy_trial():
    """Build a bare :class:`Trial` carrying one parameter of each basic type.

    The trial gets a real ``a`` (0.0), an integer ``b`` (1) and a categorical
    ``c`` ('Some'); no other field is set here.
    """
    trial = Trial()
    param_specs = (
        ('a', 'real', 0.0),
        ('b', 'integer', 1),
        ('c', 'categorical', 'Some'),
    )
    trial.params = [
        Trial.Param(name=param_name, type=param_type, value=param_value)
        for param_name, param_type, param_value in param_specs
    ]
    return trial
def _set_tables(self):
    """Write the configured experiments, trials and lies into the storage.

    ``self.trials`` and ``self.lies`` are reset first. They are then filled, in
    configuration order, with the ``to_dict()`` form of whatever the storage's
    ``register_trial`` / ``register_lie`` call returns.
    """
    self.trials, self.lies = [], []

    for experiment_config in self._experiments:
        get_storage().create_experiment(experiment_config)

    for trial_config in self._trials:
        registered_trial = get_storage().register_trial(Trial(**trial_config))
        self.trials.append(registered_trial.to_dict())

    for lie_config in self._lies:
        registered_lie = get_storage().register_lie(Trial(**lie_config))
        self.lies.append(registered_lie.to_dict())
def load_experience_configuration(self):
    """Load an example database.

    Steps, all applied in place:

    * every entry of ``self._trials`` and ``self._lies`` is replaced by its
      ``Trial(**entry).to_dict()`` round-trip;
    * ``self._trials`` is sorted by ``_id`` read as a hexadecimal integer,
      largest first;
    * each experiment whose ``metadata`` has a ``user_script`` gets that path
      joined onto this module's directory (``os.path.join`` keeps an absolute
      path as-is);
    * each experiment's ``_id`` is set to its position in ``self._experiments``.

    Finally ``self._set_tables()`` is called.
    """
    for configs in (self._trials, self._lies):
        for position, config in enumerate(configs):
            configs[position] = Trial(**config).to_dict()

    def _hex_id(record):
        return int(record['_id'], 16)

    self._trials.sort(key=_hex_id, reverse=True)

    for position, experiment in enumerate(self._experiments):
        metadata = experiment['metadata']
        if 'user_script' in metadata:
            metadata["user_script"] = os.path.join(
                os.path.dirname(__file__), metadata["user_script"])
        experiment['_id'] = position

    self._set_tables()
def test_init_full(self, exp_config):
    """Initialize with a dictionary with complete specification.

    Every field of ``exp_config[1][1]`` must be reflected on the built
    :class:`Trial`, and ``working_dir`` must stay unset.
    """
    expected = exp_config[1][1]
    trial = Trial(**expected)

    for field in ('experiment', 'status', 'worker',
                  'submit_time', 'start_time', 'end_time'):
        assert getattr(trial, field) == expected[field], field

    assert [result.to_dict() for result in trial.results] == expected['results']

    first_result = trial.results[0]
    expected_first_result = expected['results'][0]
    for field in ('name', 'type', 'value'):
        assert getattr(first_result, field) == expected_first_result[field], field

    assert [param.to_dict() for param in trial.params] == expected['params']
    assert trial.working_dir is None
def test_push_trial_results(self, storage):
    """Successfully push a completed trial into database.

    After pushing, reading the trial back must yield the same results list.
    """
    with OrionState(experiments=[], trials=[base_trial], database=storage) as cfg:
        backend = cfg.storage()
        trial = backend.get_trial(Trial(**base_trial))

        expected_results = [Trial.Result(name='loss', type='objective', value=2)]
        trial.results = expected_results
        assert backend.push_trial_results(trial), 'should update successfully'

        reloaded = backend.get_trial(trial)
        assert reloaded.results == expected_results
def test_format_commandline_and_config(parser, commandline, json_config, tmpdir, json_converter):
    """Format the commandline and a configuration file.

    The parser is fed the config-file arguments followed by the commandline
    arguments. Formatting a trial must then produce the ``--config`` flag
    pointing at the output file, followed by the expected commandline tokens.

    ``json_converter`` is requested but not referenced in the body; presumably
    it is needed for its fixture side effects (TODO confirm).
    """
    # Copy instead of aliasing: extending ``json_config`` directly would mutate
    # the fixture's list in place.
    cmd_args = list(json_config)
    cmd_args.extend(commandline)
    parser.parse(cmd_args)

    trial = Trial(params=[
        {'name': '/lr', 'type': 'real', 'value': -2.4},
        {'name': '/prior', 'type': 'categorical', 'value': 'sgd'},
        {'name': '/layers/1/width', 'type': 'integer', 'value': 100},
        {'name': '/layers/1/type', 'type': 'categorical', 'value': 'relu'},
        {'name': '/layers/2/type', 'type': 'categorical', 'value': 'sigmoid'},
        {'name': '/training/lr0', 'type': 'real', 'value': 0.032},
        {'name': '/training/mbs', 'type': 'integer', 'value': 64},
        {'name': '/something-same', 'type': 'categorical', 'value': '3'}])

    output_file = str(tmpdir.join("output.json"))
    cmd_inst = parser.format(output_file, trial)

    assert cmd_inst == ['--config', output_file, "--seed", "555", "--lr", "-2.4",
                        "--non-prior", "choices({'sgd': 0.2, 'adam': 0.8})", "--prior", "sgd"]
def test_format_with_properties(parser, cmd_with_properties, hacked_exp):
    """Test if format correctly puts the value of `trial` and `exp` when used as properties.

    The formatted command line must contain the trial's ``hash_name`` and the
    string ``'supernaedo2-dendi'`` as tokens.
    """
    parser.parse(cmd_with_properties)

    trial_params = [
        {'name': '/lr', 'type': 'real', 'value': -2.4},
        {'name': '/prior', 'type': 'categorical', 'value': 'sgd'},
    ]
    trial = Trial(experiment='trial_test', params=trial_params)

    cmd_line = parser.format(None, trial=trial, experiment=hacked_exp)
    print(cmd_line)

    for expected_token in (trial.hash_name, 'supernaedo2-dendi'):
        assert expected_token in cmd_line
def _fetch_trials(self, query, selection=None):
    """See :func:`~orion.storage.BaseStorageProtocol.fetch_trials`

    Reads the ``'trials'`` collection with ``query``/``selection`` and builds
    :class:`Trial` objects from the result. The trials are returned sorted by
    ``submit_time`` (ascending), with trials whose ``submit_time`` is ``None``
    placed first.
    """
    def sort_key(item):
        # Use a (has_time, time) tuple so ``None`` sorts first without ever
        # comparing a placeholder against a real timestamp. The previous key
        # returned the int ``0`` for ``None``; mixing that with non-numeric
        # submit times (e.g. datetimes) made ``list.sort`` raise ``TypeError``.
        # Tuples compare element-wise, so ``None`` is only ever checked for
        # equality against ``None``, never ordered.
        submit_time = item.submit_time
        return (submit_time is not None, submit_time)

    trials = Trial.build(self._db.read('trials', query=query, selection=selection))
    trials.sort(key=sort_key)
    return trials