Skip to content

Commit 5111b0d

Browse files
committed
Correctly sort Categorical facets.
There's a bug where you get an error if you facet a geom_bar with a single facet variable which is Categorical. This is an old bug, but it's more visible now that there's a reason to use Categorical facet variables. I think it's due to a bug in pandas: pandas-dev/pandas#14011
1 parent 0ff7ac0 commit 5111b0d

File tree

4 files changed

+23
-10
lines changed

4 files changed

+23
-10
lines changed

ggplot/facets.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import pprint as pp
55
from collections import OrderedDict
66

7+
from .utils import sorted_unique
8+
79
class Facet(object):
810
def __init__(self, data, is_wrap, rowvar=None, colvar=None, nrow=None, ncol=None, scales=None):
911
self.rowvar = rowvar
@@ -54,14 +56,14 @@ def __init__(self, data, is_wrap, rowvar=None, colvar=None, nrow=None, ncol=None
5456

5557
def generate_subplot_index(self, data, rowvar, colvar):
5658
if rowvar and colvar:
57-
for row_idx, row in enumerate(sorted(data[rowvar].unique())):
58-
for col_idx, col in enumerate(sorted(data[colvar].unique())):
59+
for row in sorted_unique(data[rowvar]):
60+
for col in sorted_unique(data[colvar]):
5961
yield (row, col)
6062
elif rowvar:
61-
for row_idx, row in enumerate(sorted(data[rowvar].unique())):
63+
for row in sorted_unique(data[rowvar]):
6264
yield row
6365
elif colvar:
64-
for col_idx, col in enumerate(sorted(data[colvar].unique())):
66+
for col in sorted_unique(data[colvar]):
6567
yield col
6668

6769
def calculate_ndimensions(self, data, rowvar, colvar):

ggplot/ggplot.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .themes import theme_gray
1616
from .themes import element_text
1717
from . import discretemappers
18-
from .utils import format_ticks
18+
from .utils import format_ticks, sorted_unique
1919
import urllib
2020
import base64
2121
import os
@@ -148,10 +148,8 @@ def add_labels(self):
148148
if self.facets.is_wrap:
149149
return
150150
if self.facets.rowvar:
151-
for row, name in enumerate(sorted(self.data[self.facets.rowvar].unique())):
152-
if self.facets.is_wrap==True:
153-
continue
154-
elif self.facets.colvar:
151+
for row, name in enumerate(sorted_unique(self.data[self.facets.rowvar])):
152+
if self.facets.colvar:
155153
ax = self.subplots[row][-1]
156154
else:
157155
ax = self.subplots[row]
@@ -160,7 +158,7 @@ def add_labels(self):
160158
ax.set_ylabel(name, fontsize=10, rotation=-90)
161159

162160
if self.facets.colvar:
163-
for col, name in enumerate(sorted(self.data[self.facets.colvar].unique())):
161+
for col, name in enumerate(sorted_unique(self.data[self.facets.colvar])):
164162
if len(self.subplots.shape) > 1:
165163
col = col % self.facets.ncol
166164
ax = self.subplots[0][col]

ggplot/utils.py

+6
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,9 @@ def calc_n_bins(series):
9696
h = (2 * iqr) / (len(series)**(1/3.))
9797
k = (series.max() - series.min()) / h
9898
return k
99+
100+
def sorted_unique(series):
    """Return the distinct values of *series* as a list, in sorted order.

    The distinct values are wrapped in a new Series and ordered with
    ``sort_values`` rather than the builtin ``sorted``, because
    ``sorted(series.unique())`` fails on Categorical data types.
    ``Series(series.unique())`` is used instead of
    ``series.drop_duplicates()`` because the latter is slower.
    """
    distinct = pd.Series(series.unique())
    ordered = distinct.sort_values()
    return list(ordered)

tests/test_facets.py

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
import pandas as pd
12
from ggplot import *
23

4+
diamonds['cut'] = pd.Categorical(diamonds['cut'], ordered=True,
5+
categories=['Fair', 'Good', 'Very Good', 'Premium', 'Ideal'])
6+
diamonds['clarity'] = pd.Categorical(diamonds['clarity'], ordered=True,
7+
categories='I1 SI2 SI1 VS2 VS1 VVS2 VVS1 IF'.split())
38

49
diaa = diamonds[['cut','color','table']]
510
diab = diaa.groupby(['cut','color']).quantile([x/100.0 for x in range(0,100,5)])
@@ -9,3 +14,5 @@
914

1015
print ggplot(diamonds, aes(x='clarity', weight='price')) + geom_bar() + facet_grid('color', 'cut')
1116
print ggplot(diamonds, aes(x='clarity', weight='price', fill='color')) + geom_bar() + facet_grid('color', 'cut')
17+
18+
print ggplot(diamonds, aes(x='clarity', y='price')) + geom_boxplot() + facet_wrap('cut')

0 commit comments

Comments
 (0)