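These snippets rely on module-level imports and fixtures that are not shown. The sketch below fills that gap; the file names and directory layout are assumptions inferred from the calls in the snippets, not taken from the source.

# Assumed module-level setup (paths and fixture names are hypothetical).
import json
import logging
import os
from shutil import copyfile

import numpy as np

# get_reviewer, ASReviewData, Logger, open_state and SQLiteLock come from the
# asreview package; read_label_history, read_data and the get_*_path helpers
# come from its webapp utilities.

state_dir = "tests/state_files"                          # assumed
log_dir = "tests/log_files"                              # assumed
data_fp = "tests/demo_data/generic_labels.csv"           # assumed
embedding_fp = "tests/demo_data/generic.vec"             # assumed
h5_state_file = os.path.join(state_dir, "test.h5")       # assumed
json_state_file = os.path.join(state_dir, "test.json")   # assumed
h5_log_file = os.path.join(log_dir, "test.h5")           # assumed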
def test_state_continue_h5():
    inter_file = os.path.join(state_dir, "test_1_inst.h5")
    if not os.path.isfile(inter_file):
        reviewer = get_reviewer(
            data_fp, mode="simulate", model="nb", embedding_fp=embedding_fp,
            prior_idx=[1, 2, 3, 4], state_file=inter_file,
            n_instances=1, n_queries=1)
        reviewer.review()

    copyfile(inter_file, h5_state_file)
    check_model(mode="simulate", model="nb", state_file=h5_state_file,
                continue_from_state=True, n_instances=1, n_queries=2)
def test_log_continue_h5(monkeypatch):
    inter_file = os.path.join(log_dir, "test_1_inst.h5")
    if not os.path.isfile(inter_file):
        reviewer = get_reviewer(
            data_fp, mode="simulate", model="nb", embedding_fp=embedding_fp,
            prior_included=[1, 3], prior_excluded=[2, 4], log_file=inter_file,
            n_instances=1, n_queries=1)
        reviewer.review()

    copyfile(inter_file, h5_log_file)
    check_model(monkeypatch, model="nb", log_file=h5_log_file,
                continue_from_log=True, n_instances=1, n_queries=2)
def check_model(monkeypatch=None, use_granular=False, log_file=h5_log_file,
                continue_from_log=False, mode="oracle", **kwargs):
    # Remove any leftover log file unless we explicitly continue from it.
    if not continue_from_log:
        try:
            if log_file is not None:
                os.unlink(log_file)
        except OSError:
            pass

    # In oracle mode the reviewer asks for input; answer "0" automatically.
    if monkeypatch is not None:
        monkeypatch.setattr('builtins.input', lambda _: "0")

    # start the review process.
    reviewer = get_reviewer(data_fp, mode=mode, embedding_fp=embedding_fp,
                            prior_included=[1, 3], prior_excluded=[2, 4],
                            log_file=log_file,
                            **kwargs)
    if use_granular:
        with Logger.from_file(log_file) as logger:
            # Two loops of training and classification.
            reviewer.train()
            reviewer.log_probabilities(logger)
            query_idx = reviewer.query(1)
            inclusions = reviewer._get_labels(query_idx)
            reviewer.classify(query_idx, inclusions, logger)

            reviewer.train()
            reviewer.log_probabilities(logger)
            query_idx = reviewer.query(1)
            inclusions = reviewer._get_labels(query_idx)
            reviewer.classify(query_idx, inclusions, logger)
    else:
        # Run the full (non-granular) review loop.
        reviewer.review()
def test_state_continue_json():
    inter_file = os.path.join(state_dir, "test_1_inst.json")
    if not os.path.isfile(inter_file):
        reviewer = get_reviewer(
            data_fp, mode="simulate", model="nb", embedding_fp=embedding_fp,
            prior_idx=[1, 2, 3, 4], state_file=inter_file,
            n_instances=1, n_queries=1)
        reviewer.review()

    copyfile(inter_file, json_state_file)
    check_model(mode="simulate", model="nb", state_file=json_state_file,
                continue_from_state=True, n_instances=1, n_queries=2)
def test_no_seed():
    n_test_max = 100
    as_data = ASReviewData.from_file(data_fp)
    n_priored = np.zeros(len(as_data), dtype=int)

    for _ in range(n_test_max):
        reviewer = get_reviewer(
            data_fp, mode="simulate", model="nb", state_file=None,
            init_seed=None, n_prior_excluded=1, n_prior_included=1)
        assert len(reviewer.start_idx) == 2
        n_priored[reviewer.start_idx] += 1
        if np.all(n_priored > 0):
            return
    raise ValueError(f"Error getting all priors in {n_test_max} iterations.")
def check_model(monkeypatch=None, use_granular=False, state_file=h5_state_file,
                continue_from_state=False, mode="oracle", data_fp=data_fp,
                **kwargs):
    # Remove any leftover state file unless we explicitly continue from it.
    if not continue_from_state:
        try:
            if state_file is not None:
                os.unlink(state_file)
        except OSError:
            pass

    # In oracle mode the reviewer asks for input; answer "0" automatically.
    if monkeypatch is not None:
        monkeypatch.setattr('builtins.input', lambda _: "0")

    # start the review process.
    reviewer = get_reviewer(data_fp, mode=mode, embedding_fp=embedding_fp,
                            prior_idx=[1, 2, 3, 4],
                            state_file=state_file,
                            **kwargs)
    if use_granular:
        with open_state(state_file) as state:
            # Two loops of training and classification.
            reviewer.train()
            reviewer.log_probabilities(state)
            query_idx = reviewer.query(1)
            inclusions = reviewer._get_labels(query_idx)
            reviewer.classify(query_idx, inclusions, state)

            reviewer.train()
            reviewer.log_probabilities(state)
            query_idx = reviewer.query(1)
            inclusions = reviewer._get_labels(query_idx)
            reviewer.classify(query_idx, inclusions, state)
    else:
        # Run the full (non-granular) review loop.
        reviewer.review()
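For reference, the granular branch above is presumably exercised by a test along these lines; the test name and argument values are an assumption, not part of the snippet.

# Hypothetical invocation of the granular branch (name and values assumed).
def test_granular_h5():
    check_model(use_granular=True, model="nb", state_file=h5_state_file,
                n_instances=1, n_queries=2)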
# Lock the current state. We want to have a consistent active state.
# This does communicate with the flask backend; it prevents writing to and
# reading from the same files at the same time.
with SQLiteLock(lock_file, blocking=True, lock_name="active") as lock:
    # Get all labels added since the last run. If there are no new labels, quit.
    new_label_history = read_label_history(project_id)

data_fp = str(get_data_file_path(project_id))
as_data = read_data(project_id)
state_file = get_state_path(project_id)

# Collect command line arguments and pass them to the reviewer.
with open(asr_kwargs_file, "r") as fp:
    asr_kwargs = json.load(fp)
asr_kwargs['state_file'] = str(state_file)

reviewer = get_reviewer(dataset=data_fp,
                        mode="minimal",
                        **asr_kwargs)

with open_state(state_file) as state:
    old_label_history = get_label_train_history(state)

diff_history = get_diff_history(new_label_history, old_label_history)

if len(diff_history) == 0:
    logging.info("No new labels since last run.")
    return

query_idx = np.array([x[0] for x in diff_history], dtype=int)
inclusions = np.array([x[1] for x in diff_history], dtype=int)

# Classify the new labels, train and store the results.
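The snippet ends at this comment. A minimal sketch of what the classify/train/store step could look like, reusing the reviewer calls shown in the granular test helper above; the original runner may store additional results, such as the next query and the predicted probabilities.

# Sketch only, not copied from the original: classify the newly labelled
# records, retrain the model, and write the updated probabilities to the state.
with open_state(state_file) as state:
    reviewer.classify(query_idx, inclusions, state)
    reviewer.train()
    reviewer.log_probabilities(state)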