Skip to content

PX val_map now respects category_orders #3247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 77 additions & 69 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
if hover_is_dict and not attr_value[col]:
continue
if col in [
args.get("x", None),
args.get("y", None),
args.get("z", None),
args.get("base", None),
args.get("x"),
args.get("y"),
args.get("z"),
args.get("base"),
]:
continue
try:
Expand Down Expand Up @@ -552,8 +552,10 @@ def set_cartesian_axis_opts(args, axis, letter, orders):
axis["categoryarray"] = (
orders[args[letter]]
if isinstance(axis, go.layout.XAxis)
else list(reversed(orders[args[letter]]))
else list(reversed(orders[args[letter]])) # top down for Y axis
)
if "range" not in axis:
axis["range"] = [-0.5, len(orders[args[letter]]) - 0.5]


def configure_cartesian_marginal_axes(args, fig, orders):
Expand Down Expand Up @@ -1284,8 +1286,8 @@ def build_dataframe(args, constructor):

# now we handle special cases like wide-mode or x-xor-y specification
# by rearranging args to tee things up for process_args_into_dataframe to work
no_x = args.get("x", None) is None
no_y = args.get("y", None) is None
no_x = args.get("x") is None
no_y = args.get("y") is None
wide_x = False if no_x else _is_col_list(df_input, args["x"])
wide_y = False if no_y else _is_col_list(df_input, args["y"])

Expand All @@ -1312,9 +1314,9 @@ def build_dataframe(args, constructor):
if var_name in [None, "value", "index"] or var_name in df_input:
var_name = "variable"
if constructor == go.Funnel:
wide_orientation = args.get("orientation", None) or "h"
wide_orientation = args.get("orientation") or "h"
else:
wide_orientation = args.get("orientation", None) or "v"
wide_orientation = args.get("orientation") or "v"
args["orientation"] = wide_orientation
args["wide_cross"] = None
elif wide_x != wide_y:
Expand Down Expand Up @@ -1345,7 +1347,7 @@ def build_dataframe(args, constructor):
if constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types:
if not wide_mode and (no_x != no_y):
for ax in ["x", "y"]:
if args.get(ax, None) is None:
if args.get(ax) is None:
args[ax] = df_input.index if df_provided else Range()
if constructor == go.Bar:
missing_bar_dim = ax
Expand All @@ -1369,7 +1371,7 @@ def build_dataframe(args, constructor):
)

no_color = False
if type(args.get("color", None)) == str and args["color"] == NO_COLOR:
if type(args.get("color")) == str and args["color"] == NO_COLOR:
no_color = True
args["color"] = None
# now that things have been prepped, we do the systematic rewriting of `args`
Expand Down Expand Up @@ -1777,25 +1779,25 @@ def infer_config(args, constructor, trace_patch, layout_patch):
else args["geojson"].__geo_interface__
)

# Compute marginal attribute
# Compute marginal attribute: copy to appropriate marginal_*
if "marginal" in args:
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
args[position] = args["marginal"]
args[other_position] = None

# If both marginals and faceting are specified, faceting wins
if args.get("facet_col", None) is not None and args.get("marginal_y", None):
if args.get("facet_col") is not None and args.get("marginal_y") is not None:
args["marginal_y"] = None

if args.get("facet_row", None) is not None and args.get("marginal_x", None):
if args.get("facet_row") is not None and args.get("marginal_x") is not None:
args["marginal_x"] = None

# facet_col_wrap only works if no marginals or row faceting is used
if (
args.get("marginal_x", None) is not None
or args.get("marginal_y", None) is not None
or args.get("facet_row", None) is not None
args.get("marginal_x") is not None
or args.get("marginal_y") is not None
or args.get("facet_row") is not None
):
args["facet_col_wrap"] = 0

Expand All @@ -1814,43 +1816,41 @@ def infer_config(args, constructor, trace_patch, layout_patch):

def get_orderings(args, grouper, grouped):
"""
`orders` is the user-supplied ordering (with the remaining data-frame-supplied
ordering appended if the column is used for grouping). It includes anything the user
gave, for any variable, including values not present in the dataset. It is used
downstream to set e.g. `categoryarray` for cartesian axes

`group_names` is the set of groups, ordered by the order above

`group_values` is a subset of `orders` in both keys and values. It contains a key
for every grouped mapping and its values are the sorted *data* values for these
mappings.
`orders` is the user-supplied ordering with the remaining data-frame-supplied
ordering appended if the column is used for grouping. It includes anything the user
gave, for any variable, including values not present in the dataset. It's a dict
where the keys are e.g. "x" or "color"

`sorted_group_names` is the set of groups, ordered by the order above. It's a list
of tuples like [("value1", ""), ("value2", "")] where each tuple contains the name
of a single dimension-group
"""

orders = {} if "category_orders" not in args else args["category_orders"].copy()
group_names = []
group_values = {}
for col in grouper:
if col != one_group:
uniques = args["data_frame"][col].unique()
if col not in orders:
orders[col] = list(uniques)
else:
orders[col] = list(orders[col])
for val in uniques:
if val not in orders[col]:
orders[col].append(val)

sorted_group_names = []
for group_name in grouped.groups:
if len(grouper) == 1:
group_name = (group_name,)
group_names.append(group_name)
for col in grouper:
if col != one_group:
uniques = args["data_frame"][col].unique()
if col not in orders:
orders[col] = list(uniques)
else:
for val in uniques:
if val not in orders[col]:
orders[col].append(val)
group_values[col] = sorted(uniques, key=orders[col].index)
sorted_group_names.append(group_name)

for i, col in reversed(list(enumerate(grouper))):
if col != one_group:
group_names = sorted(
group_names,
sorted_group_names = sorted(
sorted_group_names,
key=lambda g: orders[col].index(g[i]) if g[i] in orders[col] else -1,
)

return orders, group_names, group_values
return orders, sorted_group_names


def make_figure(args, constructor, trace_patch=None, layout_patch=None):
Expand All @@ -1871,32 +1871,35 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
grouped = args["data_frame"].groupby(grouper, sort=False)

orders, sorted_group_names, sorted_group_values = get_orderings(
args, grouper, grouped
)
orders, sorted_group_names = get_orderings(args, grouper, grouped)

col_labels = []
row_labels = []

nrows = ncols = 1
for m in grouped_mappings:
if m.grouper:
if m.grouper not in orders:
m.val_map[""] = m.sequence[0]
else:
sorted_values = orders[m.grouper]
if m.facet == "col":
prefix = get_label(args, args["facet_col"]) + "="
col_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
col_labels = [prefix + str(s) for s in sorted_values]
ncols = len(col_labels)
if m.facet == "row":
prefix = get_label(args, args["facet_row"]) + "="
row_labels = [prefix + str(s) for s in sorted_group_values[m.grouper]]
for val in sorted_group_values[m.grouper]:
if val not in m.val_map:
row_labels = [prefix + str(s) for s in sorted_values]
nrows = len(row_labels)
for val in sorted_values:
if val not in m.val_map: # always False if it's an IdentityMap
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]

subplot_type = _subplot_type_for_trace_type(constructor().type)

trace_names_by_frame = {}
frames = OrderedDict()
trendline_rows = []
nrows = ncols = 1
trace_name_labels = None
facet_col_wrap = args.get("facet_col_wrap", 0)
for group_name in sorted_group_names:
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
mapping_labels = OrderedDict()
Expand Down Expand Up @@ -1943,8 +1946,6 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):

for i, m in enumerate(grouped_mappings):
val = group_name[i]
if val not in m.val_map:
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
try:
m.updater(trace, m.val_map[val]) # covers most cases
except ValueError:
Expand Down Expand Up @@ -1979,14 +1980,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
row = m.val_map[val]
else:
if (
bool(args.get("marginal_x", False))
and trace_spec.marginal != "x"
args.get("marginal_x") is not None # there is a marginal
and trace_spec.marginal != "x" # and we're not it
):
row = 2
else:
row = 1

facet_col_wrap = args.get("facet_col_wrap", 0)
# Find col for trace, handling facet_col and marginal_y
if m.facet == "col":
col = m.val_map[val]
Expand All @@ -1999,11 +1999,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
else:
col = 1

nrows = max(nrows, row)
if row > 1:
trace._subplot_row = row

ncols = max(ncols, col)
if col > 1:
trace._subplot_col = col
if (
Expand Down Expand Up @@ -2062,6 +2060,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
):
layout_patch["legend"]["itemsizing"] = "constant"

if facet_col_wrap:
nrows = math.ceil(ncols / facet_col_wrap)
ncols = min(ncols, facet_col_wrap)

if args.get("marginal_x") is not None:
nrows += 1

if args.get("marginal_y") is not None:
ncols += 1

fig = init_figure(
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
)
Expand Down Expand Up @@ -2106,7 +2114,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la

# Build column_widths/row_heights
if subplot_type == "xy":
if bool(args.get("marginal_x", False)):
if args.get("marginal_x") is not None:
if args["marginal_x"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
Expand All @@ -2115,11 +2123,11 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
vertical_spacing = 0.01
elif facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
vertical_spacing = args.get("facet_row_spacing") or 0.07
else:
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
vertical_spacing = args.get("facet_row_spacing") or 0.03

if bool(args.get("marginal_y", False)):
if args.get("marginal_y") is not None:
if args["marginal_y"] == "histogram" or ("color" in args and args["color"]):
main_size = 0.74
else:
Expand All @@ -2128,18 +2136,18 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
column_widths = [main_size] * (ncols - 1) + [1 - main_size]
horizontal_spacing = 0.005
else:
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
horizontal_spacing = args.get("facet_col_spacing") or 0.02
else:
# Other subplot types:
# 'scene', 'geo', 'polar', 'ternary', 'mapbox', 'domain', None
#
# We can customize subplot spacing per type once we enable faceting
# for all plot types
if facet_col_wrap:
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
vertical_spacing = args.get("facet_row_spacing") or 0.07
else:
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
vertical_spacing = args.get("facet_row_spacing") or 0.03
horizontal_spacing = args.get("facet_col_spacing") or 0.02

if facet_col_wrap:
subplot_labels = [None] * nrows * ncols
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,8 @@ def test_px_defaults():


def assert_orderings(days_order, days_check, times_order, times_check):
symbol_sequence = ["circle", "diamond", "square", "cross"]
color_sequence = ["red", "blue"]
symbol_sequence = ["circle", "diamond", "square", "cross", "circle", "diamond"]
color_sequence = ["red", "blue", "red", "blue", "red", "blue", "red", "blue"]
fig = px.scatter(
px.data.tips(),
x="total_bill",
Expand All @@ -229,7 +229,7 @@ def assert_orderings(days_order, days_check, times_order, times_check):
assert days_check[col] in trace.hovertemplate

for row in range(len(times_check)):
for trace in fig.select_traces(row=2 - row):
for trace in fig.select_traces(row=len(times_check) - row):
assert times_check[row] in trace.hovertemplate

for trace in fig.data:
Expand All @@ -241,13 +241,10 @@ def assert_orderings(days_order, days_check, times_order, times_check):
assert trace.marker.color == color_sequence[i]


def test_noisy_orthogonal_orderings():
assert_orderings(
["x", "Sun", "Sat", "y", "Fri", "z"], # add extra noise, missing Thur
["Sun", "Sat", "Fri", "Thur"], # Thur is at the back
["a", "Lunch", "b"], # add extra noise, missing Dinner
["Lunch", "Dinner"], # Dinner is at the back
)
@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "x"]))
@pytest.mark.parametrize("times", permutations(["Lunch", "x"]))
def test_orthogonal_and_missing_orderings(days, times):
assert_orderings(days, list(days) + ["Thur"], times, list(times) + ["Dinner"])


@pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "Thur"]))
Expand Down