Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_all(session):
    # type: (Session) -> None
    """``TrialSystemAttributeModel.all`` returns every stored system attribute.

    The parent study and trial rows are added to the session before the attribute,
    so the attribute's ``trial_id`` refers to a persisted trial.
    NOTE(review): presumably ``trial_id`` is a foreign key. Without the parent rows
    the insert only succeeds on backends that do not enforce foreign keys (e.g.
    SQLite with its default settings).
    """
    study = StudyModel(study_id=1, study_name='test-study', direction=StudyDirection.MINIMIZE)
    trial = TrialModel(trial_id=1, study_id=study.study_id, state=TrialState.COMPLETE)
    session.add(study)
    session.add(trial)
    session.add(
        TrialSystemAttributeModel(trial_id=trial.trial_id, key='sample-key', value_json='1'))
    session.commit()

    system_attributes = TrialSystemAttributeModel.all(session)
    assert 1 == len(system_attributes)
    assert 'sample-key' == system_attributes[0].key
    assert '1' == system_attributes[0].value_json
def test_cascade_delete_on_study(session):
    # type: (Session) -> None
    """Deleting a study leaves ``TrialModel.where_study`` with no trials for it."""
    study_id = 1
    study = StudyModel(study_id=study_id, study_name='test-study',
                       direction=StudyDirection.MINIMIZE)
    # Attach one finished and one running trial through the relationship.
    for trial_state in (TrialState.COMPLETE, TrialState.RUNNING):
        study.trials.append(TrialModel(study_id=study.study_id, state=trial_state))
    session.add(study)
    session.commit()
    assert 2 == len(TrialModel.where_study(study, session))

    # Removing the parent study must take its trials with it.
    session.delete(study)
    session.commit()
    assert 0 == len(TrialModel.where_study(study, session))
def test_where_study(session):
    # type: (Session) -> None
    """``TrialUserAttributeModel.where_study`` finds the attributes of a study's trial."""
    study = StudyModel(study_id=1, study_name='test-study', direction=StudyDirection.MINIMIZE)
    trial = TrialModel(trial_id=1, study_id=study.study_id, state=TrialState.COMPLETE)
    # Persist parents first, then the attribute that references the trial.
    for parent in (study, trial):
        session.add(parent)
    session.add(
        TrialUserAttributeModel(trial_id=trial.trial_id, key='sample-key', value_json='1'))
    session.commit()

    found = TrialUserAttributeModel.where_study(study, session)
    assert 1 == len(found)
    only_attribute = found[0]
    assert 'sample-key' == only_attribute.key
    assert '1' == only_attribute.value_json
def test_count_past_trials(session):
    # type: (Session) -> None
    """``TrialModel.count_past_trials`` counts only earlier trials of the same study.

    NOTE(review): this test had lost its ``def`` header; its statements used
    ``session`` with no enclosing function. The name above is reconstructed from
    what the body exercises. Confirm it against version history.
    """
    study_1 = StudyModel(study_id=1, study_name='test-study-1')
    study_2 = StudyModel(study_id=2, study_name='test-study-2')

    # The first trial of study 1 has no predecessors.
    trial_1_1 = TrialModel(study_id=study_1.study_id, state=TrialState.COMPLETE)
    session.add(trial_1_1)
    session.commit()
    assert 0 == trial_1_1.count_past_trials(session)

    # The second trial of study 1 sees the first one.
    trial_1_2 = TrialModel(study_id=study_1.study_id, state=TrialState.RUNNING)
    session.add(trial_1_2)
    session.commit()
    assert 1 == trial_1_2.count_past_trials(session)

    # The first trial of study 2 does not count study 1's trials.
    trial_2_1 = TrialModel(study_id=study_2.study_id, state=TrialState.RUNNING)
    session.add(trial_2_1)
    session.commit()
    assert 0 == trial_2_1.count_past_trials(session)
def test_cascade_delete_on_trial(session):
    # type: (Session) -> None
    """Deleting a trial leaves ``where_trial_id`` with no user attributes for it."""
    trial_id = 1
    study = StudyModel(study_id=1, study_name='test-study', direction=StudyDirection.MINIMIZE)
    trial = TrialModel(trial_id=trial_id, study_id=study.study_id, state=TrialState.COMPLETE)
    # Two user attributes hang off the trial via its relationship collection.
    for attr_key, attr_json in (('sample-key1', '1'), ('sample-key2', '2')):
        trial.user_attributes.append(TrialUserAttributeModel(
            trial_id=trial_id, key=attr_key, value_json=attr_json))
    study.trials.append(trial)
    session.add(study)
    session.commit()
    assert 2 == len(TrialUserAttributeModel.where_trial_id(trial_id, session))

    # Removing the trial must remove its attributes as well.
    session.delete(trial)
    session.commit()
    assert 0 == len(TrialUserAttributeModel.where_trial_id(trial_id, session))
def test_where_trial(session):
    # type: (Session) -> None
    """``TrialUserAttributeModel.where_trial`` returns the attributes stored for ``trial``.

    The study and trial rows are added to the session before the attribute, so its
    ``trial_id`` refers to a persisted trial.
    NOTE(review): presumably ``trial_id`` is a foreign key. Without the parent rows
    the insert only succeeds on backends that do not enforce foreign keys (e.g.
    SQLite with its default settings).
    """
    study = StudyModel(study_id=1, study_name='test-study', direction=StudyDirection.MINIMIZE)
    trial = TrialModel(trial_id=1, study_id=study.study_id, state=TrialState.COMPLETE)
    session.add(study)
    session.add(trial)
    session.add(
        TrialUserAttributeModel(trial_id=trial.trial_id, key='sample-key', value_json='1'))
    session.commit()

    user_attributes = TrialUserAttributeModel.where_trial(trial, session)
    assert 1 == len(user_attributes)
    assert 'sample-key' == user_attributes[0].key
    assert '1' == user_attributes[0].value_json
def _get_all_trial_ids(self, study_id):
    # type: (int) -> List[int]
    """Look up study ``study_id`` and return the trial IDs stored for it.

    The IDs come from ``TrialModel.get_all_trial_ids_where_study``. The session's
    transaction is committed via ``self._commit`` before the IDs are returned.
    """
    session = self.scoped_session()
    trial_ids = models.TrialModel.get_all_trial_ids_where_study(
        models.StudyModel.find_or_raise_by_id(study_id, session), session)
    # End the read transaction right away so the connection is not left inside
    # an open transaction, which could otherwise hit a connection timeout.
    self._commit(session)
    return trial_ids
def _get_all_trial_ids(self, study_id):
    # type: (int) -> List[int]
    """Look up study ``study_id`` and return the trial IDs stored for it.

    The IDs come from ``TrialModel.get_all_trial_ids_where_study``. Any exception
    from ``StudyModel.find_or_raise_by_id`` (e.g. when the study cannot be found)
    or from the query propagates to the caller. The session is closed in every case.
    """
    session = self.scoped_session()
    try:
        study = models.StudyModel.find_or_raise_by_id(study_id, session)
        return models.TrialModel.get_all_trial_ids_where_study(study, session)
    finally:
        # Release the session even when a lookup raises; before this fix,
        # an exception skipped close().
        session.close()
def _get_all_trials_without_cache(self, study_id):
    # type: (int) -> List[structs.FrozenTrial]
    """Load every trial of study ``study_id`` directly from the database.

    The method fetches the study's trials, params, values, user attributes and
    system attributes with one query each. It then combines them with
    ``self._merge_trials_orm``. Exceptions from any lookup (including
    ``StudyModel.find_or_raise_by_id``) propagate to the caller. The session is
    closed in every case.
    """
    session = self.scoped_session()
    try:
        study = models.StudyModel.find_or_raise_by_id(study_id, session)
        trials = models.TrialModel.where_study(study, session)
        params = models.TrialParamModel.where_study(study, session)
        values = models.TrialValueModel.where_study(study, session)
        user_attributes = models.TrialUserAttributeModel.where_study(study, session)
        system_attributes = models.TrialSystemAttributeModel.where_study(study, session)
    finally:
        # Release the session even if one of the five queries raises.
        session.close()
    # As before, merging happens after the session is closed.
    return self._merge_trials_orm(trials, params, values, user_attributes, system_attributes)
def where_study(cls, study, session):
    # type: (StudyModel, orm.Session) -> List[TrialParamModel]
    """Return every ``cls`` row whose associated trial belongs to ``study``.

    ``cls`` is joined with ``TrialModel``, and the result keeps the rows whose
    trial's ``study_id`` equals ``study.study_id``.
    """
    joined = session.query(cls).join(TrialModel)
    return joined.filter(TrialModel.study_id == study.study_id).all()