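# Imports this excerpt needs. The surrounding test module is assumed to define
# `params`, `X_test`, `y_test`, `tname`, `pred_from_matr`, and the trained
# booster that produced "model.txt".
import os

import numpy as np
from scipy import sparse
from sklearn.datasets import dump_svmlight_file

import lightgbm as lgb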
bst = lgb.Booster(params, model_file="model.txt")
os.remove("model.txt")
pred_from_model_file = bst.predict(X_test)
# we need to check the consistency of the model file here, so test for exact equality
np.testing.assert_array_equal(pred_from_matr, pred_from_model_file)
# check that prediction early stopping works; make it stop very early, so the scores should be very close to zero
pred_parameter = {"pred_early_stop": True, "pred_early_stop_freq": 5, "pred_early_stop_margin": 1.5}
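# pred_early_stop_freq: how often (in iterations) the early-stopping check runs;
# pred_early_stop_margin: the margin threshold beyond which prediction stops early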
pred_early_stopping = bst.predict(X_test, **pred_parameter)
# raw scores will likely differ, but the predicted signs should still match
np.testing.assert_array_equal(np.sign(pred_from_matr), np.sign(pred_early_stopping))
# test that shape is checked during prediction
bad_X_test = X_test[:, 1:]
bad_shape_error_msg = "The number of features in data*"
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
                               bst.predict, bad_X_test)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
                               bst.predict, sparse.csr_matrix(bad_X_test))
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
                               bst.predict, sparse.csc_matrix(bad_X_test))
with open(tname, "w+b") as f:
    dump_svmlight_file(bad_X_test, y_test, f)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
                               bst.predict, tname)
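# zero_based=False writes one-based feature indices, so the parsed file's feature
# count no longer matches the model and the same shape check should fail here too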
with open(tname, "w+b") as f:
    dump_svmlight_file(X_test, y_test, f, zero_based=False)
np.testing.assert_raises_regex(lgb.basic.LightGBMError, bad_shape_error_msg,
                               bst.predict, tname)
os.remove(tname)
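# hist_idx/bins_idx and hist_name/bins_name are assumed to come from earlier
# get_split_value_histogram calls keyed by feature index and by feature name;
# both ways of addressing the same feature should give identical histograms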
np.testing.assert_array_equal(hist_idx, hist_name)
np.testing.assert_allclose(bins_idx, bins_name)
# test bins string type
if np.__version__ > '1.11.0':
    hist_vals, bin_edges = gbm.get_split_value_histogram(0, bins='auto')
    hist = gbm.get_split_value_histogram(0, bins='auto', xgboost_style=True)
    if lgb.compat.PANDAS_INSTALLED:
        mask = hist_vals > 0
        np.testing.assert_array_equal(hist_vals[mask], hist['Count'].values)
        np.testing.assert_allclose(bin_edges[1:][mask], hist['SplitValue'].values)
    else:
        mask = hist_vals > 0
        np.testing.assert_array_equal(hist_vals[mask], hist[:, 1])
        np.testing.assert_allclose(bin_edges[1:][mask], hist[:, 0])
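# with xgboost_style=True the histogram is a pandas DataFrame with 'SplitValue'
# and 'Count' columns when pandas is installed, otherwise an ndarray with those
# two columns, which is why the branches above index the result differently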
# test histogram is disabled for categorical features
self.assertRaises(lgb.basic.LightGBMError, gbm.get_split_value_histogram, 2)
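# A minimal sketch (an assumption, not the original fixtures) of the helpers used
# below: `dummy_obj` is a constant-gradient custom objective whose output is
# irrelevant to these checks, and `get_cv_result` wraps lgb.cv over a prepared
# training Dataset, assumed here to be named `lgb_train`.
def dummy_obj(preds, train_data):
    # constant gradient and hessian: only the metric bookkeeping is under test
    return np.ones(preds.shape), np.ones(preds.shape)

def get_cv_result(params, **kwargs):
    # two boosting rounds are enough to exercise the metric handling
    return lgb.cv(params, lgb_train, num_boost_round=2, **kwargs)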
self.assertRaises(lgb.basic.LightGBMError, get_cv_result,
params_class_3_verbose)
# no metric with non-default num_class for custom objective
res = get_cv_result(params_class_3_verbose, fobj=dummy_obj)
self.assertEqual(len(res), 0)
for metric_multi_alias in obj_multi_aliases + ['multi_logloss']:
    # multiclass metric alias for custom objective
    res = get_cv_result(params_class_3_verbose, metrics=metric_multi_alias, fobj=dummy_obj)
    self.assertEqual(len(res), 2)
    self.assertIn('multi_logloss-mean', res)
# multiclass metric for custom objective
res = get_cv_result(params_class_3_verbose, metrics='multi_error', fobj=dummy_obj)
self.assertEqual(len(res), 2)
self.assertIn('multi_error-mean', res)
# binary metric with non-default num_class for custom objective
self.assertRaises(lgb.basic.LightGBMError, get_cv_result,
params_class_3_verbose, metrics='binary_error', fobj=dummy_obj)