from pathlib import Path

import pytest
from pyspark.sql.functions import col

from kedro.contrib.io.pyspark import SparkDataSet
from kedro.io import CSVLocalDataSet, DataSetError, Version

# Module-level constants used by the tests; the values here are placeholders
# (assumed) -- only the path shapes they produce matter to the assertions.
FOLDER_NAME = "fake_folder"
FILENAME = "test.parquet"
BUCKET_NAME = "test_bucket"
HDFS_PREFIX = "{}/{}".format(FOLDER_NAME, FILENAME)
AWS_CREDENTIALS = {
    "aws_access_key_id": "FAKE_ACCESS_KEY",
    "aws_secret_access_key": "FAKE_SECRET_KEY",
}
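# The tests below also rely on `sample_pandas_df`, `sample_spark_df` and
# `version` fixtures defined elsewhere in the suite. Minimal sketches of what
# they plausibly look like, inferred from how the tests use them (the exact
# rows and the `generate_timestamp` import location are assumptions):

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType, StructField, StructType

from kedro.io.core import generate_timestamp


@pytest.fixture
def sample_pandas_df():
    # `test_load_options_csv` filters on Name == "Alex", so at least that
    # row must be present.
    return pd.DataFrame(
        {"Name": ["Alex", "Bob", "Clarke", "Dave"], "Age": [31, 12, 65, 29]}
    )


@pytest.fixture
def sample_spark_df():
    # `test_save_partition` partitions by a lower-case `name` column.
    schema = StructType(
        [
            StructField("name", StringType(), True),
            StructField("age", IntegerType(), True),
        ]
    )
    data = [("Alex", 31), ("Bob", 12), ("Clarke", 65), ("Dave", 29)]
    return SparkSession.builder.getOrCreate().createDataFrame(data, schema)


@pytest.fixture
def version():
    # Load the latest available version; save under a fresh timestamp.
    return Version(None, generate_timestamp())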
def test_exists(self, file_format, tmp_path, sample_spark_df):
    # `exists()` should be False before the first save and True afterwards.
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(filepath=filepath, file_format=file_format)
    assert not spark_data_set.exists()

    spark_data_set.save(sample_spark_df)
    assert spark_data_set.exists()
@pytest.fixture
def versioned_dataset_dbfs(tmp_path, version):
    return SparkDataSet(filepath="/dbfs" + str(tmp_path / FILENAME), version=version)
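# On Databricks, DBFS is mounted at `/dbfs`, so prefixing a local-style path
# with `/dbfs` reaches the same files through the ordinary filesystem API.
# With versioning enabled, a save lands under
# `<filepath>/<save_version>/<filename>` -- the layout the HDFS test below
# also asserts. A hypothetical usage sketch (path and version are made up):
#
#   ds = SparkDataSet(
#       filepath="/dbfs/tmp/test.parquet",
#       version=Version(None, "2019-01-01T00.00.00.000Z"),
#   )
#   # ds.save(df) would then write to
#   # /dbfs/tmp/test.parquet/2019-01-01T00.00.00.000Z/test.parquet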
def test_load_options_csv(self, tmp_path, sample_pandas_df):
    # Save a CSV with pandas, then read it back through Spark, forwarding
    # `header=True` to the Spark reader via `load_args`.
    filepath = str(tmp_path / "data")
    local_csv_data_set = CSVLocalDataSet(filepath=filepath)
    local_csv_data_set.save(sample_pandas_df)

    spark_data_set = SparkDataSet(
        filepath=filepath, file_format="csv", load_args={"header": True}
    )
    spark_df = spark_data_set.load()
    assert spark_df.filter(col("Name") == "Alex").count() == 1
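# `load_args` are forwarded to Spark's DataFrameReader, so the load above is
# roughly equivalent to this plain PySpark call (a sketch, assuming an active
# SparkSession named `spark`):
#
#   spark.read.load(filepath, format="csv", header=True)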
def test_prevent_overwrite(self, mocker, version):
    # Saving to an already-existing versioned path must raise, not overwrite.
    hdfs_status = mocker.patch(
        "kedro.contrib.io.pyspark.spark_data_set.InsecureClient.status"
    )
    hdfs_status.return_value = True

    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save path `.+` for SparkDataSet\(.+\) must not exist "
        r"if versioning is enabled"
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    hdfs_status.assert_called_once_with(
        "{fn}/{f}/{v}/{f}".format(fn=FOLDER_NAME, v=version.save, f=FILENAME),
        strict=False,
    )
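# The assertion above pins down the versioned path layout used on HDFS:
#
#   <folder>/<filename>/<save_version>/<filename>
#
# i.e. the filename is nested under a directory named after itself, with the
# save version's timestamp in between.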
def test_save_partition(self, tmp_path, sample_spark_df):
    # To verify partitioning, partition the data by one of the columns and
    # then check whether the partition column's value is added to the save
    # path.
    filepath = Path(str(tmp_path / "test_data"))
    spark_data_set = SparkDataSet(
        filepath=str(filepath),
        save_args={"mode": "overwrite", "partitionBy": ["name"]},
    )
    spark_data_set.save(sample_spark_df)

    expected_path = filepath / "name=Alex"
    assert expected_path.exists()
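# `partitionBy` makes Spark write one sub-directory per distinct value of the
# partition column, so the output directory looks roughly like this (file
# names are illustrative):
#
#   test_data/
#   ├── name=Alex/part-00000-<uuid>.snappy.parquet
#   ├── name=Bob/...
#   └── _SUCCESS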
def test_repr(self, version):
    # The repr should expose the filepath, and the version only when one is set.
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=version
    )
    assert "filepath=hdfs://" in str(versioned_hdfs)
    assert "version=Version(load=None, save='{}')".format(version.save) in str(
        versioned_hdfs
    )

    dataset_hdfs = SparkDataSet(filepath="hdfs://{}".format(HDFS_PREFIX))
    assert "filepath=hdfs://" in str(dataset_hdfs)
    assert "version=" not in str(dataset_hdfs)
@pytest.fixture
def versioned_dataset_s3(version):
    return SparkDataSet(
        filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME),
        version=version,
        credentials=AWS_CREDENTIALS,
    )
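# Presumably the `credentials` dict configures the S3 client the dataset uses
# for its existence/version checks, while Spark itself reads and writes the
# data through the `s3a://` scheme. Hypothetical construction for a pinned
# version (all values are placeholders):
#
#   ds = SparkDataSet(
#       filepath="s3a://test_bucket/test.parquet",
#       version=Version(None, "2019-01-01T00.00.00.000Z"),
#       credentials={"aws_access_key_id": "...", "aws_secret_access_key": "..."},
#   )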
def test_save_version_warning(self, mocker):
    # A save version that differs from the pinned load version should trigger
    # a UserWarning, but the save still goes through under the save version.
    exact_version = Version("2019-01-01T23.59.59.999Z", "2019-01-02T00.00.00.000Z")
    versioned_hdfs = SparkDataSet(
        filepath="hdfs://{}".format(HDFS_PREFIX), version=exact_version
    )
    mocker.patch.object(versioned_hdfs, "_exists_function", return_value=False)
    mocked_spark_df = mocker.Mock()

    pattern = (
        r"Save version `{ev.save}` did not match load version "
        r"`{ev.load}` for SparkDataSet\(.+\)".format(ev=exact_version)
    )
    with pytest.warns(UserWarning, match=pattern):
        versioned_hdfs.save(mocked_spark_df)

    # DataFrameWriter.save is called with the target path and the file format;
    # "parquet" is the SparkDataSet default.
    mocked_spark_df.write.save.assert_called_once_with(
        "hdfs://{fn}/{f}/{sv}/{f}".format(
            fn=FOLDER_NAME, f=FILENAME, sv=exact_version.save
        ),
        "parquet",
    )
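# For context: `Version` is a `(load, save)` named tuple, and `load=None`
# means "latest available". Pinning both fields to different timestamps, as
# above, is what provokes the mismatch warning. A minimal sketch of that
# check (an illustration, not the library's actual code):
#
#   if exact_version.save != exact_version.load:
#       warnings.warn(
#           "Save version `{}` did not match load version `{}`".format(
#               exact_version.save, exact_version.load
#           )
#       )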
def test_repr(self, versioned_dataset_s3, version):
    assert "filepath=s3a://" in str(versioned_dataset_s3)
    assert "version=Version(load=None, save='{}')".format(version.save) in str(
        versioned_dataset_s3
    )

    dataset_s3 = SparkDataSet(filepath="s3a://{}/{}".format(BUCKET_NAME, FILENAME))
    assert "filepath=s3a://" in str(dataset_s3)
    assert "version=" not in str(dataset_s3)