# Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def create_tornasole_hook(out_dir, train_data=None, validation_data=None, frequency=1):
    """Build a ``Hook`` writing to ``out_dir`` with ``SaveConfig(save_interval=frequency)``.

    Constructed identically to ``create_hook`` in this file; only the name differs.

    Args:
        out_dir: Output directory handed to the hook.
        train_data: Optional training data forwarded to the hook (the tests pass
            an ``xgboost.DMatrix``).
        validation_data: Optional validation data forwarded to the hook.
        frequency: Value used as the ``save_interval`` of the save config.

    Returns:
        The newly constructed ``Hook``.
    """
    return Hook(
        out_dir=out_dir,
        save_config=SaveConfig(save_interval=frequency),
        train_data=train_data,
        validation_data=validation_data,
    )
def test_hook(tmpdir):
    """Smoke test: a hook saving steps 0-3 survives a full XGBoost training run.

    Before training starts, ``has_training_ended`` must report ``False`` for the
    brand-new (uuid-named) output directory.
    """
    run_dir = os.path.join(tmpdir, str(uuid.uuid4()))
    hook = Hook(out_dir=run_dir, save_config=SaveConfig(save_steps=[0, 1, 2, 3]))
    assert has_training_ended(run_dir) is False
    run_xgboost_model(hook=hook)
def test_hook_save_every_step(tmpdir):
    """With ``save_interval=1`` the trial must contain steps 0 through 9.

    Presumes ``run_xgboost_model`` trains for 10 steps — TODO confirm against
    the helper if that changes.
    """
    every_step = SaveConfig(save_interval=1)
    run_dir = os.path.join(tmpdir, str(uuid.uuid4()))
    run_xgboost_model(hook=Hook(out_dir=run_dir, save_config=every_step))
    expected_steps = list(range(10))
    assert create_trial(run_dir).steps() == expected_steps
def test_hook_shap(tmpdir):
    """Check that SHAP collections are saved when the hook receives training data.

    Trains with ``include_collections=["average_shap", "full_shap"]`` on a
    seeded 10x10 random dataset, then verifies:
      * both collections exist and each has at least one per-feature tensor,
      * no ``.../bias`` tensor is saved,
      * the shape of one ``average_shap`` tensor and one ``full_shap`` tensor.
    """
    np.random.seed(42)  # deterministic features and labels
    train_data = np.random.rand(10, 10)
    train_label = np.random.randint(2, size=10)
    dtrain = xgboost.DMatrix(train_data, label=train_label)
    out_dir = os.path.join(tmpdir, str(uuid.uuid4()))
    hook = Hook(
        out_dir=out_dir, include_collections=["average_shap", "full_shap"], train_data=dtrain
    )
    run_xgboost_model(hook=hook)

    trial = create_trial(out_dir)
    tensors = trial.tensor_names()
    assert len(tensors) > 0
    assert "average_shap" in trial.collections()
    assert "full_shap" in trial.collections()
    assert any(t.startswith("average_shap/") for t in tensors)
    assert any(t.startswith("full_shap/") for t in tensors)
    assert not any(t.endswith("/bias") for t in tensors)

    # The any(...) checks above guarantee both lists are non-empty, so pop() is safe.
    average_shap_tensors = [t for t in tensors if t.startswith("average_shap/")]
    average_shap_tensor_name = average_shap_tensors.pop()
    assert trial.tensor(average_shap_tensor_name).value(0).shape == (1,)

    full_shap_tensors = [t for t in tensors if t.startswith("full_shap/")]
    full_shap_tensor_name = full_shap_tensors.pop()
    # Expected (rows, 1): each full_shap/<feature> tensor is assumed to hold one
    # SHAP value per training row for that single feature — TODO confirm if the
    # hook's layout changes.
    assert trial.tensor(full_shap_tensor_name).value(0).shape == (train_data.shape[0], 1)
def helper_xgboost_tests(collection, save_config):
    """Train ``simple_xg_model`` with an ``XG_Hook`` restricted to one collection, then verify output.

    Args:
        collection: Two-item sequence unpacked as ``(name, regex)``; only the
            name is used here (the second item is presumably a regex — unused).
        save_config: Save configuration passed to the hook and to ``verify_files``.

    The run directory is ``SMDEBUG_XG_HOOK_TESTS_DIR/trial_<name>-<timestamp>``,
    with a microsecond-resolution timestamp to keep repeated runs apart.
    """
    coll_name, _coll_regex = collection
    run_id = "trial_" + coll_name + "-" + datetime.now().strftime("%Y%m%d-%H%M%S%f")
    trial_dir = os.path.join(SMDEBUG_XG_HOOK_TESTS_DIR, run_id)
    hook = XG_Hook(
        out_dir=trial_dir,
        include_collections=[coll_name],
        save_config=save_config,
        export_tensorboard=True,
    )
    try:
        simple_xg_model(hook)
    finally:
        # Close the hook even when training raises, so it is never left open.
        hook.close()
    saved_scalars = ["scalar/xg_num_steps", "scalar/xg_before_train", "scalar/xg_after_train"]
    verify_files(trial_dir, save_config, saved_scalars)
def test_hook_validation(tmpdir):
    """Labels and predictions are saved when both train and validation data are given.

    Builds two seeded random 5x10 datasets with binary labels (train first, then
    validation), runs training with the ``labels`` and ``predictions``
    collections, and checks that both appear as collections and as tensors.
    """
    np.random.seed(42)
    # Generation order (features, labels) x (train, valid) matches the seeded stream.
    matrices = []
    for _ in range(2):
        features = np.random.rand(5, 10)
        targets = np.random.randint(2, size=5)
        matrices.append(xgboost.DMatrix(features, label=targets))
    dtrain, dvalid = matrices

    run_dir = os.path.join(tmpdir, str(uuid.uuid4()))
    hook = Hook(
        out_dir=run_dir,
        include_collections=["labels", "predictions"],
        train_data=dtrain,
        validation_data=dvalid,
    )
    run_xgboost_model(hook=hook)

    trial = create_trial(run_dir)
    names = trial.tensor_names()
    assert len(names) > 0
    for wanted in ("labels", "predictions"):
        assert wanted in trial.collections()
    for wanted in ("labels", "predictions"):
        assert wanted in names
def create_hook(out_dir, train_data=None, validation_data=None, frequency=1):
    """Return a ``Hook`` for ``out_dir`` saving with ``save_interval=frequency``.

    ``train_data`` and ``validation_data`` are forwarded to the hook unchanged.

    NOTE(review): a second, identical ``create_hook`` is defined right after
    this one and rebinds the name at import time, so this copy is never the one
    callers get. One of the two should be deleted.
    """
    config = SaveConfig(save_interval=frequency)
    return Hook(
        out_dir=out_dir,
        save_config=config,
        train_data=train_data,
        validation_data=validation_data,
    )
def create_hook(out_dir, train_data=None, validation_data=None, frequency=1):
    """Construct a ``Hook`` that writes to ``out_dir`` with ``SaveConfig(save_interval=frequency)``.

    Args:
        out_dir: Output directory for the hook.
        train_data: Optional training data forwarded to the hook.
        validation_data: Optional validation data forwarded to the hook.
        frequency: ``save_interval`` for the save config.

    NOTE(review): verbatim duplicate of the ``create_hook`` defined immediately
    above; this later definition replaces it at import (flake8 F811). Keep only one.
    """
    hook_kwargs = {
        "out_dir": out_dir,
        "save_config": SaveConfig(save_interval=frequency),
        "train_data": train_data,
        "validation_data": validation_data,
    }
    return Hook(**hook_kwargs)