def test_prevent_override(self, versioned_feather_data_set, dummy_dataframe):
    """Check the error when attempting to override the same data set
    version."""
    versioned_feather_data_set.save(dummy_dataframe)
    pattern = (
        r"Save path \`.+\` for FeatherLocalDataSet\(.+\) must not "
        r"exist if versioning is enabled"
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_feather_data_set.save(dummy_dataframe)

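The versioned_feather_data_set and dummy_dataframe fixtures are defined elsewhere in the suite; a minimal sketch of what they might look like (the FeatherLocalDataSet and Version import paths vary by kedro version and are assumptions here):

import pandas as pd
import pytest

from kedro.io.core import Version  # assumed import path
from kedro.io import FeatherLocalDataSet  # assumed import path


@pytest.fixture
def dummy_dataframe():
    # A tiny frame is enough to exercise save/load.
    return pd.DataFrame({"col1": [1, 2], "col2": [4, 5]})


@pytest.fixture
def versioned_feather_data_set(tmp_path):
    # Pinning the save version means the second save targets the same
    # path, which is exactly what the test expects to fail.
    return FeatherLocalDataSet(
        filepath=str(tmp_path / "test.feather"),
        version=Version(None, "2019-01-01T00.00.00.000Z"),
    )
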
def test_non_existent_bucket(self):
    """Check the error when loading from a non-existent bucket."""
    pattern = r"Failed while loading data from data set CSVGCSDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        CSVGCSDataSet(
            filepath=FILENAME,
            bucket_name="not-existing-bucket",
            project=GCP_PROJECT,
            credentials=None,
        ).load()

def test_node_returning_none(self, saving_none_pipeline):
    """Check the error when a node in the pipeline returns ``None``."""
    pattern = "Saving `None` to a `DataSet` is not allowed"
    with pytest.raises(DataSetError, match=pattern):
        SequentialRunner().run(saving_none_pipeline, DataCatalog())

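A saving_none_pipeline fixture can be as small as one node whose function returns None; a minimal sketch using kedro's pipeline API (fixture layout assumed):

import pytest

from kedro.pipeline import Pipeline, node


@pytest.fixture
def saving_none_pipeline():
    # A single node that produces None for output "A"; the runner
    # should refuse to persist it to the catalog.
    return Pipeline([node(lambda: None, inputs=None, outputs="A")])
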
def test_load_missing(self, hdf_data_set):
    """Check the error when trying to load a missing HDF file."""
    pattern = r"Failed while loading data from data set HDFLocalDataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        hdf_data_set.load()

def test_load_not_callable(self):
    """Check the error when the ``load`` argument is not a callable."""
    pattern = (
        r"`load` function for LambdaDataSet must be a Callable\. "
        r"Object of type `str` provided instead\."
    )
    with pytest.raises(DataSetError, match=pattern):
        LambdaDataSet("load", None)

def test_save_overwrite_fail(self, tmp_path, sample_spark_df):
    # Writes a data frame twice and expects the second write to fail.
    filepath = str(tmp_path / "test_data")
    spark_data_set = SparkDataSet(filepath=filepath)
    spark_data_set.save(sample_spark_df)
    with pytest.raises(DataSetError):
        spark_data_set.save(sample_spark_df)

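The failure comes from Spark's default errorifexists save mode. If overwriting is actually intended, the mode can be passed through save_args instead; a sketch, reusing the filepath and fixture from the test above:

# mode="overwrite" is forwarded to Spark's DataFrameWriter, so repeated
# saves to the same path succeed.
spark_data_set = SparkDataSet(
    filepath=filepath,
    save_args={"mode": "overwrite"},
)
spark_data_set.save(sample_spark_df)
spark_data_set.save(sample_spark_df)  # no DataSetError this time
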
def test_empty_credentials_load(self, bad_credentials):
    """Check the error when loading with empty credentials."""
    parquet_data_set = ParquetS3DataSet(
        filepath=FILENAME, bucket_name=BUCKET_NAME, credentials=bad_credentials
    )
    pattern = r"Failed while loading data from data set ParquetS3DataSet\(.+\)"
    with pytest.raises(DataSetError, match=pattern):
        parquet_data_set.load()

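The bad_credentials fixture comes from elsewhere in the suite; something along these lines would trigger the failure (the exact key names are an assumption):

import pytest


@pytest.fixture
def bad_credentials():
    # Present but unusable credentials: the client can be constructed,
    # the actual S3 call cannot succeed.
    return {"aws_access_key_id": None, "aws_secret_access_key": None}
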
def test_prevent_overwrite(self, versioned_json_data_set, json_data):
    """Check the error when attempting to override the data set if the
    corresponding JSON file for a given save version already exists."""
    versioned_json_data_set.save(json_data)
    pattern = (
        r"Save path \`.+\` for JSONLocalDataSet\(.+\) must "
        r"not exist if versioning is enabled\."
    )
    with pytest.raises(DataSetError, match=pattern):
        versioned_json_data_set.save(json_data)

def test_load_missing_file(self, yaml_data_set):
    """Check the error when trying to load a missing file."""
    pattern = r"Failed while loading data from data set YAMLLocalDataSet\(.*\)"
    with pytest.raises(DataSetError, match=pattern):
        yaml_data_set.load()

def save(self, data: Any):
    """Calls save method of a shared MemoryDataSet in SyncManager."""
    try:
        self.shared_memory_dataset.save(data)
    except Exception as exc:  # pylint: disable=broad-except
        # Checks if the error is due to serialisation or not
        try:
            pickle.dumps(data)
        except Exception:
            raise DataSetError(
                "{} cannot be serialized. ParallelRunner implicit memory datasets "
                "can only be used with serializable data".format(str(data.__class__))
            )
        else:
            raise exc

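The pickle.dumps probe is what separates "the data is unpicklable" from "the save itself failed". A quick standalone illustration of the first case:

import pickle

# A lambda cannot be pickled, so a node output holding one would be
# rejected by ParallelRunner's implicit memory datasets.
try:
    pickle.dumps(lambda x: x)
except Exception as exc:
    print("not serializable:", repr(exc))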