pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                               preserve_index=preserve_index,
                                                               indexes_position=indexes_position)
schema_built = []
partition_cols_types = {}
for name, dtype in pyarrow_schema:
    if (cast_columns is not None) and (name in cast_columns.keys()):
        # An explicit cast always wins over the inferred Pyarrow type.
        if name in partition_cols:
            partition_cols_types[name] = cast_columns[name]
        else:
            schema_built.append((name, cast_columns[name]))
    else:
        try:
            athena_type = data_types.pyarrow2athena(dtype)
        except UndetectedType:
            raise UndetectedType(f"We can't infer the data type from an entire null object column ({name}). "
                                 f"Please consider passing the type of this column explicitly using the "
                                 f"cast_columns argument")
        except UnsupportedType:
            raise UnsupportedType(f"Unsupported Pyarrow type for column {name}: {dtype}")
        if name in partition_cols:
            partition_cols_types[name] = athena_type
        else:
            schema_built.append((name, athena_type))
# Partition columns are emitted separately, preserving the order given in partition_cols.
partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
logger.debug(f"schema_built:\n{schema_built}")
logger.debug(f"partition_cols_schema_built:\n{partition_cols_schema_built}")
return schema_built, partition_cols_schema_built
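
# Illustration (not part of the library): the split performed above can be reproduced with plain
# Python structures. The hypothetical helper below shows the two rules at work: an entry in
# cast_columns overrides the inferred type, and the partition schema is rebuilt at the end so it
# follows the order of partition_cols rather than the column order of the dataframe.
from typing import Dict, List, Tuple


def split_schema(columns: List[Tuple[str, str]],
                 partition_cols: List[str],
                 cast_columns: Dict[str, str]) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    schema: List[Tuple[str, str]] = []
    partition_types: Dict[str, str] = {}
    for name, inferred_type in columns:
        athena_type = cast_columns.get(name, inferred_type)
        if name in partition_cols:
            partition_types[name] = athena_type
        else:
            schema.append((name, athena_type))
    partition_schema = [(name, partition_types[name]) for name in partition_cols]
    return schema, partition_schema


schema, partition_schema = split_schema(columns=[("id", "int"), ("value", "double"), ("year", "int")],
                                        partition_cols=["year"],
                                        cast_columns={"id": "bigint"})
print(schema)            # [('id', 'bigint'), ('value', 'double')]
print(partition_schema)  # [('year', 'int')]
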
:param diststyle: Redshift distribution styles. Must be in ["AUTO", "EVEN", "ALL", "KEY"] (https://docs.aws.amazon.com/redshift/latest/dg/t_Distributing_data.html)
:param distkey: Specifies a column name or positional number for the distribution key
:param sortstyle: Sorting can be "COMPOUND" or "INTERLEAVED" (https://docs.aws.amazon.com/redshift/latest/dg/t_Sorting_data.html)
:param sortkey: List of columns to be sorted
:param primary_keys: Primary keys
:param preserve_index: Should we preserve the Dataframe index?
:param mode: append, overwrite or upsert
:param cast_columns: Dictionary of column names and Redshift types to cast them to. (E.g. {"col name": "SMALLINT", "col2 name": "FLOAT4"})
:return: None
"""
if cast_columns is None:
    cast_columns = {}
    cast_columns_parquet: Dict = {}
else:
    cast_columns_tuples: List[Tuple[str, str]] = [(k, v) for k, v in cast_columns.items()]
    cast_columns_parquet = data_types.convert_schema(func=data_types.redshift2athena,
                                                     schema=cast_columns_tuples)
if path[-1] != "/":
    path += "/"
self._session.s3.delete_objects(path=path)
num_rows: int = len(dataframe.index)
logger.debug(f"Number of rows: {num_rows}")
if num_rows < MIN_NUMBER_OF_ROWS_TO_DISTRIBUTE:
    num_partitions: int = 1
else:
    num_slices: int = self._session.redshift.get_number_of_slices(redshift_conn=connection)
    logger.debug(f"Number of slices on Redshift: {num_slices}")
    num_partitions = num_slices
logger.debug(f"Number of partitions calculated: {num_partitions}")
objects_paths: List[str] = self.to_parquet(dataframe=dataframe,
                                           path=path,
                                           preserve_index=preserve_index,
                                           procs_cpu_bound=num_partitions,
                                           cast_columns=cast_columns_parquet)
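
# Usage sketch for the to_redshift parameters documented above. This assumes the 0.x-style
# awswrangler.Session().pandas entry point; parameters not listed in the docstring above, as well
# as the bucket, connection, schema, table and IAM role values, are placeholders/assumptions,
# not values taken from the library.
import awswrangler
import pandas as pd

df = pd.DataFrame({"id": [1, 2, 3], "name": ["foo", "bar", "baz"]})

redshift_connection = None  # placeholder: supply an open connection to the target Redshift cluster

session = awswrangler.Session()
session.pandas.to_redshift(dataframe=df,
                           path="s3://my-bucket/staging/",  # hypothetical staging prefix
                           connection=redshift_connection,
                           schema="public",
                           table="my_table",
                           iam_role="arn:aws:iam::123456789012:role/my-redshift-copy-role",  # hypothetical
                           diststyle="KEY",
                           distkey="id",
                           sortstyle="COMPOUND",
                           sortkey=["id"],
                           preserve_index=False,
                           mode="overwrite",
                           cast_columns={"id": "BIGINT"})
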
# Casting on Pandas: nullable Int64 columns are widened to float64 here and cast back to bigint on Pyarrow.
casted_in_pandas = []
dtypes = copy.deepcopy(dataframe.dtypes.to_dict())
for name, dtype in dtypes.items():
    if str(dtype) == "Int64":
        dataframe[name] = dataframe[name].astype("float64")
        casted_in_pandas.append(name)
        cast_columns[name] = "bigint"
        logger.debug(f"Casting column {name} Int64 to float64")
# Converting Pandas Dataframe to Pyarrow's Table
table = pa.Table.from_pandas(df=dataframe, preserve_index=preserve_index, safe=False)
# Casting on Pyarrow
if cast_columns:
    for col_name, dtype in cast_columns.items():
        col_index = table.column_names.index(col_name)
        pyarrow_dtype = data_types.athena2pyarrow(dtype)
        field = pa.field(name=col_name, type=pyarrow_dtype)
        table = table.set_column(col_index, field, table.column(col_name).cast(pyarrow_dtype))
        logger.debug(f"Casting column {col_name} ({col_index}) to {dtype} ({pyarrow_dtype})")
# Persisting on S3
Pandas._write_parquet_to_s3_retrying(fs=fs, path=path, table=table, compression=compression)
# Casting back on Pandas if necessary
if isolated_dataframe is False:
    for col in casted_in_pandas:
        dataframe[col] = dataframe[col].astype("Int64")
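
# Self-contained sketch of the Int64 round trip above, using only pandas and pyarrow: the nullable
# Int64 column is widened to float64, converted to an Arrow table, cast back to a 64-bit integer on
# the Arrow side (pa.int64() stands in for athena2pyarrow("bigint")), and the pandas frame is
# finally restored to its nullable dtype.
import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"id": pd.array([1, 2, None], dtype="Int64"), "value": [1.5, 2.5, 3.5]})

casted_in_pandas = []
for name, dtype in df.dtypes.to_dict().items():
    if str(dtype) == "Int64":
        df[name] = df[name].astype("float64")
        casted_in_pandas.append(name)

table = pa.Table.from_pandas(df=df, preserve_index=False, safe=False)
for name in casted_in_pandas:
    index = table.column_names.index(name)
    field = pa.field(name=name, type=pa.int64())
    table = table.set_column(index, field, table.column(name).cast(pa.int64()))

for name in casted_in_pandas:
    df[name] = df[name].astype("Int64")

print(table.schema)  # id: int64, value: double (plus pandas metadata)
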
if cast_columns is None:
    cast_columns = {}
schema_built = []
if dataframe_type == "pandas":
    pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                   preserve_index=preserve_index,
                                                                   indexes_position="right")
    for name, dtype in pyarrow_schema:
        if (cast_columns is not None) and (name in cast_columns.keys()):
            schema_built.append((name, cast_columns[name]))
        else:
            redshift_type = data_types.pyarrow2redshift(dtype)
            schema_built.append((name, redshift_type))
elif dataframe_type == "spark":
    # Spark's DataFrame.dtypes already yields (column, dtype) tuples.
    for name, dtype in dataframe.dtypes:
        if name in cast_columns.keys():
            redshift_type = data_types.athena2redshift(cast_columns[name])
        else:
            redshift_type = data_types.spark2redshift(dtype)
        schema_built.append((name, redshift_type))
else:
    raise InvalidDataframeType(dataframe_type)
return schema_built
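
# The pandas branch above boils down to: an explicit entry in cast_columns wins, otherwise the
# Pyarrow-inferred type is translated to a Redshift type. A dependency-free sketch of that rule,
# with a hand-written map standing in for data_types.pyarrow2redshift:
from typing import Dict, List, Tuple

PYARROW_TO_REDSHIFT: Dict[str, str] = {  # illustrative subset only
    "int64": "BIGINT",
    "double": "FLOAT8",
    "string": "VARCHAR(256)",
    "bool": "BOOLEAN",
}


def build_redshift_schema(pyarrow_schema: List[Tuple[str, str]],
                          cast_columns: Dict[str, str]) -> List[Tuple[str, str]]:
    schema_built = []
    for name, dtype in pyarrow_schema:
        redshift_type = cast_columns.get(name, PYARROW_TO_REDSHIFT[dtype])
        schema_built.append((name, redshift_type))
    return schema_built


print(build_redshift_schema(pyarrow_schema=[("id", "int64"), ("score", "double"), ("name", "string")],
                            cast_columns={"score": "FLOAT4"}))
# [('id', 'BIGINT'), ('score', 'FLOAT4'), ('name', 'VARCHAR(256)')]
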
def get_table_python_types(self, database, table):
    """
    Get all column names and the related Python types

    :param database: Glue database's name
    :param table: Glue table's name
    :return: A dictionary as {"col name": "col python type"}
    """
    dtypes = self.get_table_athena_types(database=database, table=table)
    return {k: data_types.athena2python(v) for k, v in dtypes.items()}
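
# Usage sketch for get_table_python_types, assuming the 0.x-style Session().glue accessor and a
# hypothetical Glue database and table name.
import awswrangler

session = awswrangler.Session()
python_types = session.glue.get_table_python_types(database="my_database", table="my_table")
print(python_types)  # e.g. {"id": int, "score": float, "name": str}
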
def _build_schema(dataframe, partition_cols, preserve_index, indexes_position, cast_columns=None):
    if cast_columns is None:
        cast_columns = {}
    logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
    if not partition_cols:
        partition_cols = []
    pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                   preserve_index=preserve_index,
                                                                   indexes_position=indexes_position)
    schema_built = []
    partition_cols_types = {}
    for name, dtype in pyarrow_schema:
        if (cast_columns is not None) and (name in cast_columns.keys()):
            if name in partition_cols:
                partition_cols_types[name] = cast_columns[name]
            else:
                schema_built.append((name, cast_columns[name]))
        else:
            try:
                athena_type = data_types.pyarrow2athena(dtype)
            except UndetectedType:
                raise UndetectedType(f"We can't infer the data type from an entire null object column ({name}). "
                                     f"Please consider passing the type of this column explicitly using the "
                                     f"cast_columns argument")
def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False, cast_columns=None):
    if cast_columns is None:
        cast_columns = {}
    schema_built = []
    if dataframe_type == "pandas":
        pyarrow_schema = data_types.extract_pyarrow_schema_from_pandas(dataframe=dataframe,
                                                                       preserve_index=preserve_index,
                                                                       indexes_position="right")
        for name, dtype in pyarrow_schema:
            if (cast_columns is not None) and (name in cast_columns.keys()):
                schema_built.append((name, cast_columns[name]))
            else:
                redshift_type = data_types.pyarrow2redshift(dtype)
                schema_built.append((name, redshift_type))
    elif dataframe_type == "spark":
        for name, dtype in dataframe.dtypes:
            if name in cast_columns.keys():
                redshift_type = data_types.athena2redshift(cast_columns[name])
            else:
                redshift_type = data_types.spark2redshift(dtype)
            schema_built.append((name, redshift_type))
    else:
        raise InvalidDataframeType(dataframe_type)
    return schema_built
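
# Downstream, a schema expressed as (column, redshift_type) pairs is typically rendered into DDL.
# A minimal, hypothetical sketch of that consumption step (not the library's own table-creation code):
from typing import List, Tuple


def render_create_table(schema: str, table: str, columns: List[Tuple[str, str]]) -> str:
    cols_sql = ",\n    ".join(f"{name} {redshift_type}" for name, redshift_type in columns)
    return f"CREATE TABLE IF NOT EXISTS {schema}.{table} (\n    {cols_sql}\n)"


print(render_create_table(schema="public",
                          table="my_table",
                          columns=[("id", "BIGINT"), ("score", "FLOAT4"), ("name", "VARCHAR(256)")]))
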