Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_effective_n_jobs():
    """Check how ``effective_n_jobs`` resolves the requested job count.

    With ``_get_max_num_concurrent_tasks`` mocked to return 8:
    ``None`` resolves to 1, ``-1`` resolves to 8, and an explicit 4 is
    returned unchanged. Asking for 16 returns 16 and emits exactly one
    warning.
    """
    slot_limit = 8
    backend = SparkDistributedBackend()
    backend._get_max_num_concurrent_tasks = MagicMock(return_value=slot_limit)

    cases = ((None, 1), (-1, 8), (4, 4))
    for requested, expected in cases:
        assert backend.effective_n_jobs(n_jobs=requested) == expected

    # Record every warning, even repeats, so the count below is reliable.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert backend.effective_n_jobs(n_jobs=16) == 16
        assert len(caught) == 1
def __init__(self, **backend_args):
    """Initialise the backend and bind it to a Spark session.

    All keyword arguments are forwarded unchanged to the parent backend
    constructor. The Spark session is obtained with ``getOrCreate()`` from a
    builder named "JoblibSparkBackend". A job-group id unique to this
    instance is derived from ``uuid.uuid4()``.
    """
    super(SparkDistributedBackend, self).__init__(**backend_args)
    # Start out unset; presumably filled in by other methods of this class
    # (not visible here) — TODO confirm.
    self._pool = None
    self._n_jobs = None
    session_builder = SparkSession.builder.appName("JoblibSparkBackend")
    self._spark = session_builder.getOrCreate()
    group_suffix = str(uuid.uuid4())
    self._job_group = "joblib-spark-job-group-" + group_suffix
def register():
    """
    Register joblib spark backend.

    Registers ``SparkDistributedBackend`` with joblib under the name
    ``'spark'``. If scikit-learn is importable and its version is older than
    0.21, a warning is emitted first. A missing scikit-learn is not an
    error; registration proceeds either way.
    """
    import re  # pylint: disable=C0415

    try:
        import sklearn  # pylint: disable=C0415
    except ImportError:
        # sklearn is optional: skip the version check, but still register.
        sklearn = None

    if sklearn is not None:
        # Compare only the leading numeric (major, minor) components instead
        # of using distutils.version.LooseVersion, which is deprecated
        # (PEP 632) and removed in Python 3.12. Pre-release tags such as
        # "0.21.dev0" or "0.20rc1" reduce to (0, 21) / (0, 20) here.
        numeric_parts = re.findall(r"\d+", sklearn.__version__)[:2]
        if tuple(int(part) for part in numeric_parts) < (0, 21):
            warnings.warn("Your sklearn version is < 0.21, but joblib-spark only support "
                          "sklearn >=0.21 . You can upgrade sklearn to version >= 0.21 to "
                          "make sklearn use spark backend.")

    register_parallel_backend('spark', SparkDistributedBackend)