Commit 278b7c9

Improving Spark.read_csv tests
1 parent: 65904ce

File tree

2 files changed: +34 -8 lines changed

awswrangler/redshift.py

Lines changed: 4 additions & 2 deletions

@@ -341,9 +341,11 @@ def _get_redshift_schema(dataframe, dataframe_type, preserve_index=False):
             dtype = str(dataframe.index.dtype)
             redshift_type = Redshift._type_pandas2redshift(dtype)
             schema_built.append((name, redshift_type))
-        for col, dtype in dataframe.dtypes:
+        for col in dataframe.columns:
+            name = str(col)
+            dtype = str(dataframe[name].dtype)
             redshift_type = Redshift._type_pandas2redshift(dtype)
-            schema_built.append((col, redshift_type))
+            schema_built.append((name, redshift_type))
     elif dataframe_type == "spark":
         for name, dtype in dataframe.dtypes:
             redshift_type = Redshift._type_spark2redshift(dtype)
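
Why the redshift.py hunk is needed: pandas exposes DataFrame.dtypes as a
Series keyed by column name, and iterating a Series yields its values (the
dtype objects), not (column, dtype) pairs, so the old tuple-unpacking loop
raises a TypeError. A minimal sketch, using plain pandas and hypothetical
sample data (not part of the commit), reproducing the failure and the
column-based pattern the commit switches to:

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "name": ["a", "b"]})

# Old pattern -- broken: iterating a Series yields dtype objects,
# which cannot be unpacked into (col, dtype):
# for col, dtype in df.dtypes:  # TypeError: dtype object is not iterable
#     ...

# New pattern from the commit: walk the columns, look up each dtype,
# and coerce the column name to str before building the schema.
schema_built = []
for col in df.columns:
    name = str(col)
    dtype = str(df[name].dtype)
    schema_built.append((name, dtype))

print(schema_built)  # [('id', 'int64'), ('name', 'object')]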

testing/test_awswrangler/test_spark.py

Lines changed: 30 additions & 6 deletions

@@ -39,9 +39,33 @@ def bucket(session, cloudformation_outputs):
     session.s3.delete_objects(path=f"s3://{bucket}/")


-def test_read_csv(session, bucket):
-    boto3.client("s3").upload_file("data_samples/small.csv", bucket,
-                                   "data_samples/small.csv")
-    path = f"s3://{bucket}/data_samples/small.csv"
-    dataframe = session.spark.read_csv(path=path)
-    assert dataframe.count() == 100
+@pytest.mark.parametrize(
+    "sample_name",
+    ["nano", "micro", "small"],
+)
+def test_read_csv(session, bucket, sample_name):
+    path = f"data_samples/{sample_name}.csv"
+    if sample_name == "micro":
+        schema = "id SMALLINT, name STRING, value FLOAT, date TIMESTAMP"
+        timestamp_format = "yyyy-MM-dd"
+    elif sample_name == "small":
+        schema = "id BIGINT, name STRING, date DATE"
+        timestamp_format = "dd-MM-yy"
+    elif sample_name == "nano":
+        schema = "id INTEGER, name STRING, value DOUBLE, date TIMESTAMP, time TIMESTAMP"
+        timestamp_format = "yyyy-MM-dd"
+    dataframe = session.spark.read_csv(path=path,
+                                       schema=schema,
+                                       timestampFormat=timestamp_format,
+                                       dateFormat=timestamp_format,
+                                       header=True)
+
+    boto3.client("s3").upload_file(path, bucket, path)
+    path2 = f"s3://{bucket}/{path}"
+    dataframe2 = session.spark.read_csv(path=path2,
+                                        schema=schema,
+                                        timestampFormat=timestamp_format,
+                                        dateFormat=timestamp_format,
+                                        header=True)
+    assert dataframe.count() == dataframe2.count()
+    assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
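
For context on the options the test exercises: Spark's CSV reader accepts a
DDL-formatted string as the schema, along with timestampFormat, dateFormat,
and header options, which is presumably what session.spark.read_csv forwards
to under the hood. A hedged sketch in plain PySpark (the wrapper-free
equivalent, assuming the small.csv layout implied by the test):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# DDL schema strings like the ones built in the test are accepted
# directly by DataFrameReader.csv.
df = spark.read.csv("data_samples/small.csv",
                    schema="id BIGINT, name STRING, date DATE",
                    timestampFormat="dd-MM-yy",
                    dateFormat="dd-MM-yy",
                    header=True)
df.printSchema()

Because of @pytest.mark.parametrize, the rewritten test expands into three
cases (nano, micro, small), each reading the sample both from the local path
and from S3 and asserting the two DataFrames agree on row and column counts.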
