Skip to content

Commit 5d9c525

Browse files
authored
Merge pull request #115 from awslabs/categorical-partitions
Handling categorical partitions
2 parents 5cc6360 + a81769c commit 5d9c525

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

awswrangler/pandas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,7 +1004,7 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
10041004
objects_paths.append(object_path)
10051005
else:
10061006
dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
1007-
for keys, subgroup in dataframe.groupby(partition_cols):
1007+
for keys, subgroup in dataframe.groupby(by=partition_cols, observed=True):
10081008
subgroup = subgroup.drop(partition_cols, axis="columns")
10091009
if not isinstance(keys, tuple):
10101010
keys = (keys, )
@@ -1407,7 +1407,7 @@ def read_parquet(self,
14071407
if len(dfs) == 1:
14081408
df: pd.DataFrame = dfs[0]
14091409
else:
1410-
df = pd.concat(objs=dfs, ignore_index=True)
1410+
df = pd.concat(objs=dfs, ignore_index=True, sort=False)
14111411
return df
14121412

14131413
@staticmethod
@@ -1870,7 +1870,7 @@ def read_csv_list(
18701870
logger.debug(f"Closing proc number: {i}")
18711871
receive_pipes[i].close()
18721872
logger.debug(f"Concatenating all {len(paths)} DataFrames...")
1873-
df = pd.concat(objs=dfs, ignore_index=True)
1873+
df = pd.concat(objs=dfs, ignore_index=True, sort=False)
18741874
return df
18751875

18761876
def _read_csv_list_iterator(

testing/test_awswrangler/test_pandas.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,3 +2173,16 @@ def test_aurora_mysql_load_special2(bucket, mysql_parameters):
21732173
assert rows[0][2] is None
21742174
assert rows[1][3] is None
21752175
conn.close()
2176+
2177+
2178+
def test_to_parquet_categorical_partitions(bucket):
    """Regression test: to_parquet must handle categorical partition columns.

    Writing a filtered DataFrame whose partition column is a pandas
    ``category`` dtype used to fail/misbehave because ``groupby`` on a
    categorical emits a group for every category level, including levels
    absent from the filtered frame; the fix passes ``observed=True`` so
    only groups actually present are written. This test round-trips such
    a frame through S3 and checks the row count survives.
    """
    # Local import: numpy is only needed here, and the deprecated
    # ``pd.np`` accessor was removed in pandas 2.0.
    import numpy as np

    path = f"s3://{bucket}/test_to_parquet_categorical_partitions"
    wr.s3.delete_objects(path=path)

    # 10000 daily timestamps -> partition column spans many years.
    dates = pd.date_range("1990-01-01", freq="D", periods=10000)
    values = np.random.randn(len(dates), 4)
    frame = pd.DataFrame(values, index=dates, columns=["A", "B", "C", "D"])
    frame["Year"] = frame.index.year
    # Cast the partition column to categorical — the condition under test.
    frame["Year"] = frame["Year"].astype("category")

    # Filter to a single year so most category levels are *unobserved*.
    subset = frame[frame.Year == 1990]
    wr.pandas.to_parquet(subset, path=path, partition_cols=["Year"])

    roundtrip = wr.pandas.read_parquet(path=path)
    assert len(subset.index) == len(roundtrip.index)

0 commit comments

Comments (0)