Skip to content

preload val_map from orders #2105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jan 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 29 additions & 13 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,11 +1136,19 @@ def infer_config(args, constructor, trace_patch):
def get_orderings(args, grouper, grouped):
"""
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
ordering appended if the column is used for grouping)
ordering appended if the column is used for grouping). It includes anything the user
gave, for any variable, including values not present in the dataset. It is used
downstream to set e.g. `categoryarray` for cartesian axes

`group_names` is the set of groups, ordered by the order above

`group_values` is a subset of `orders` in both keys and values. It contains a key
for every grouped mapping and its values are the sorted *data* values for these
mappings.
"""
orders = {} if "category_orders" not in args else args["category_orders"].copy()
group_names = []
group_values = {}
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
Expand All @@ -1154,6 +1162,7 @@ def get_orderings(args, grouper, grouped):
for val in uniques:
if val not in orders[col]:
orders[col].append(val)
group_values[col] = sorted(uniques, key=orders[col].index)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
Expand All @@ -1162,7 +1171,7 @@ def get_orderings(args, grouper, grouped):
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)

return orders, group_names
return orders, group_names, group_values


def make_figure(args, constructor, trace_patch={}, layout_patch={}):
Expand All @@ -1174,16 +1183,31 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names = get_orderings(args, grouper, grouped)
orders, sorted_group_names, sorted_group_values = get_orderings(
args, grouper, grouped
)

col_labels = []
row_labels = []

for m in grouped_mappings:
if m.grouper:
if m.facet == "col":
prefix = get_label(args, args["facet_col"]) + "="
col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
if m.facet == "row":
prefix = get_label(args, args["facet_row"]) + "="
row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
for val in sorted_group_values[m.grouper]:
if val not in m.val_map:
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]

subplot_type = _subplot_type_for_trace_type(constructor().type)

trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
nrows = ncols = 1
col_labels = []
row_labels = []
trace_name_labels = None
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
Expand Down Expand Up @@ -1281,10 +1305,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Find row for trace, handling facet_row and marginal_x
if m.facet == "row":
row = m.val_map[val]
if args["facet_row"] and len(row_labels) < row:
row_labels.append(
get_label(args, args["facet_row"]) + "=" + str(val)
)
else:
if (
bool(args.get("marginal_x", False))
Expand All @@ -1298,10 +1318,6 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
# Find col for trace, handling facet_col and marginal_y
if m.facet == "col":
col = m.val_map[val]
if args["facet_col"] and len(col_labels) < col:
col_labels.append(
get_label(args, args["facet_col"]) + "=" + str(val)
)
if facet_col_wrap: # assumes no facet_row, no marginals
row = 1 + ((col - 1) // facet_col_wrap)
col = 1 + ((col - 1) % facet_col_wrap)
Expand Down
50 changes: 50 additions & 0 deletions packages/python/plotly/plotly/tests/test_core/test_px/test_px.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,53 @@ def test_px_templates():
assert fig.layout.xaxis3.showgrid is None
assert fig.layout.yaxis2.showgrid
assert fig.layout.yaxis3.showgrid


def test_orthogonal_orderings():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great test!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth testing that adding a value to category_orders that is not present in the data does not break things?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I tried, but it's a harder test to write and much slower :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK I think I have a way to spot-check a few options

from itertools import permutations

df = px.data.tips()

symbol_sequence = ["circle", "diamond", "square", "cross"]
color_sequence = ["red", "blue"]

def assert_orderings(days_order, days_check, times_order, times_check):
    """Plot the tips data with the given category orders and verify the result.

    `days_order` and `times_order` are passed to `px.scatter` as the
    `category_orders` for the "day" and "time" columns. `days_check` and
    `times_check` are the orderings the resulting figure is expected to
    follow: they are checked against facet column/row placement (via each
    trace's hovertemplate) and against the symbol/color assigned to each trace.
    """
    fig = px.scatter(
        df,
        x="total_bill",
        y="tip",
        facet_row="time",
        facet_col="day",
        color="time",
        symbol="day",
        symbol_sequence=symbol_sequence,
        color_discrete_sequence=color_sequence,
        category_orders=dict(day=days_order, time=times_order),
    )

    # The i-th expected day must appear in every trace of facet column i + 1.
    for col, day in enumerate(days_check, start=1):
        for trace in fig.select_traces(col=col):
            assert day in trace.hovertemplate

    # The i-th expected time is looked up at row (n_rows - i), i.e. the first
    # expected value sits at the highest row number. The row count is derived
    # from `times_check` instead of being hard-coded.
    # NOTE(review): presumes one facet row per entry in `times_check` and that
    # rows are numbered from the bottom -- confirm against make_figure.
    n_rows = len(times_check)
    for row, time in enumerate(times_check):
        for trace in fig.select_traces(row=n_rows - row):
            assert time in trace.hovertemplate

    # Symbols and colors must be handed out in the same order as the
    # expected orderings, independently of one another.
    for trace in fig.data:
        for i, day in enumerate(days_check):
            if day in trace.name:
                assert trace.marker.symbol == symbol_sequence[i]
        for i, time in enumerate(times_check):
            if time in trace.name:
                assert trace.marker.color == color_sequence[i]

assert_orderings(
["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur
["Sun", "Sat", "Fri", "Thur"], # Thur is at the back
["a", "Lunch", "b"], # add extra noise, missing Dinner
["Lunch", "Dinner"], # Dinner is at the back
)

for days in permutations(df["day"].unique()):
for times in permutations(df["time"].unique()):
assert_orderings(days, days, times, times)