diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index e94a79d3954..dd670c26d40 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -137,7 +137,7 @@ def make_mapping(args, variable):
)
-def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
+def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
"""Populates a dict with arguments to update trace
Parameters
@@ -147,7 +147,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
trace_spec : NamedTuple
which kind of trace to be used (has constructor, marginal etc.
attributes)
- g : pandas DataFrame
+ trace_data : pandas DataFrame
data
mapping_labels : dict
to be used for hovertemplate
@@ -156,87 +156,92 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
Returns
-------
- result : dict
+ trace_patch : dict
dict to be used to update trace
fit_results : dict
fit information to be used for trendlines
"""
if "line_close" in args and args["line_close"]:
- g = g.append(g.iloc[0])
- result = trace_spec.trace_patch.copy() or {}
+ trace_data = trace_data.append(trace_data.iloc[0])
+ trace_patch = trace_spec.trace_patch.copy() or {}
fit_results = None
hover_header = ""
custom_data_len = 0
- for k in trace_spec.attrs:
- v = args[k]
- v_label = get_decorated_label(args, v, k)
- if k == "dimensions":
+ for attr_name in trace_spec.attrs:
+ attr_value = args[attr_name]
+ attr_label = get_decorated_label(args, attr_value, attr_name)
+ if attr_name == "dimensions":
dims = [
(name, column)
- for (name, column) in g.iteritems()
- if ((not v) or (name in v))
+ for (name, column) in trace_data.iteritems()
+ if ((not attr_value) or (name in attr_value))
and (
trace_spec.constructor != go.Parcoords
or args["data_frame"][name].dtype.kind in "bifc"
)
and (
trace_spec.constructor != go.Parcats
- or (v is not None and name in v)
+ or (attr_value is not None and name in attr_value)
or len(args["data_frame"][name].unique())
<= args["dimensions_max_cardinality"]
)
]
- result["dimensions"] = [
+ trace_patch["dimensions"] = [
dict(label=get_label(args, name), values=column.values)
for (name, column) in dims
]
if trace_spec.constructor == go.Splom:
- for d in result["dimensions"]:
+ for d in trace_patch["dimensions"]:
d["axis"] = dict(matches=True)
mapping_labels["%{xaxis.title.text}"] = "%{x}"
mapping_labels["%{yaxis.title.text}"] = "%{y}"
elif (
- v is not None
- or (trace_spec.constructor == go.Histogram and k in ["x", "y"])
+ attr_value is not None
+ or (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"])
or (
trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour]
- and k == "z"
+ and attr_name == "z"
)
):
- if k == "size":
- if "marker" not in result:
- result["marker"] = dict()
- result["marker"]["size"] = g[v]
- result["marker"]["sizemode"] = "area"
- result["marker"]["sizeref"] = sizeref
- mapping_labels[v_label] = "%{marker.size}"
- elif k == "marginal_x":
+ if attr_name == "size":
+ if "marker" not in trace_patch:
+ trace_patch["marker"] = dict()
+ trace_patch["marker"]["size"] = trace_data[attr_value]
+ trace_patch["marker"]["sizemode"] = "area"
+ trace_patch["marker"]["sizeref"] = sizeref
+ mapping_labels[attr_label] = "%{marker.size}"
+ elif attr_name == "marginal_x":
if trace_spec.constructor == go.Histogram:
mapping_labels["count"] = "%{y}"
- elif k == "marginal_y":
+ elif attr_name == "marginal_y":
if trace_spec.constructor == go.Histogram:
mapping_labels["count"] = "%{x}"
- elif k == "trendline":
- if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1:
+ elif attr_name == "trendline":
+ if (
+ attr_value in ["ols", "lowess"]
+ and args["x"]
+ and args["y"]
+ and len(trace_data) > 1
+ ):
import statsmodels.api as sm
# sorting is bad but trace_specs with "trendline" have no other attrs
- g2 = g.sort_values(by=args["x"])
- y = g2[args["y"]]
- x = g2[args["x"]]
- result["x"] = x
+ sorted_trace_data = trace_data.sort_values(by=args["x"])
+ y = sorted_trace_data[args["y"]]
+ x = sorted_trace_data[args["x"]]
+ trace_patch["x"] = x
if x.dtype.type == np.datetime64:
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
- if v == "lowess":
+ if attr_value == "lowess":
trendline = sm.nonparametric.lowess(y, x)
- result["y"] = trendline[:, 1]
+ trace_patch["y"] = trendline[:, 1]
hover_header = "LOWESS trendline
"
- elif v == "ols":
+ elif attr_value == "ols":
fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit()
- result["y"] = fit_results.predict()
+ trace_patch["y"] = fit_results.predict()
hover_header = "OLS trendline
"
hover_header += "%s = %g * %s + %g
" % (
args["y"],
@@ -250,120 +255,127 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref):
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
- elif k.startswith("error"):
- error_xy = k[:7]
- arr = "arrayminus" if k.endswith("minus") else "array"
- if error_xy not in result:
- result[error_xy] = {}
- result[error_xy][arr] = g[v]
- elif k == "custom_data":
- result["customdata"] = g[v].values
- custom_data_len = len(v) # number of custom data columns
- elif k == "hover_name":
+ elif attr_name.startswith("error"):
+ error_xy = attr_name[:7]
+ arr = "arrayminus" if attr_name.endswith("minus") else "array"
+ if error_xy not in trace_patch:
+ trace_patch[error_xy] = {}
+ trace_patch[error_xy][arr] = trace_data[attr_value]
+ elif attr_name == "custom_data":
+ trace_patch["customdata"] = trace_data[attr_value].values
+ custom_data_len = len(attr_value) # number of custom data columns
+ elif attr_name == "hover_name":
if trace_spec.constructor not in [
go.Histogram,
go.Histogram2d,
go.Histogram2dContour,
]:
- result["hovertext"] = g[v]
+ trace_patch["hovertext"] = trace_data[attr_value]
if hover_header == "":
hover_header = "%{hovertext}
"
- elif k == "hover_data":
+ elif attr_name == "hover_data":
if trace_spec.constructor not in [
go.Histogram,
go.Histogram2d,
go.Histogram2dContour,
]:
- for col in v:
+ for col in attr_value:
try:
position = args["custom_data"].index(col)
except (ValueError, AttributeError, KeyError):
position = custom_data_len
custom_data_len += 1
- if "customdata" in result:
- result["customdata"] = np.hstack(
- (result["customdata"], g[col].values[:, None])
+ if "customdata" in trace_patch:
+ trace_patch["customdata"] = np.hstack(
+ (
+ trace_patch["customdata"],
+ trace_data[col].values[:, None],
+ )
)
else:
- result["customdata"] = g[col].values[:, None]
- v_label_col = get_decorated_label(args, col, None)
- mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position)
- elif k == "color":
+ trace_patch["customdata"] = trace_data[col].values[
+ :, None
+ ]
+ attr_label_col = get_decorated_label(args, col, None)
+ mapping_labels[attr_label_col] = "%%{customdata[%d]}" % (
+ position
+ )
+ elif attr_name == "color":
if trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]:
- result["z"] = g[v]
- result["coloraxis"] = "coloraxis1"
- mapping_labels[v_label] = "%{z}"
+ trace_patch["z"] = trace_data[attr_value]
+ trace_patch["coloraxis"] = "coloraxis1"
+ mapping_labels[attr_label] = "%{z}"
elif trace_spec.constructor in [
go.Sunburst,
go.Treemap,
go.Pie,
go.Funnelarea,
]:
- if "marker" not in result:
- result["marker"] = dict()
+ if "marker" not in trace_patch:
+ trace_patch["marker"] = dict()
if args.get("color_is_continuous"):
- result["marker"]["colors"] = g[v]
- result["marker"]["coloraxis"] = "coloraxis1"
- mapping_labels[v_label] = "%{color}"
+ trace_patch["marker"]["colors"] = trace_data[attr_value]
+ trace_patch["marker"]["coloraxis"] = "coloraxis1"
+ mapping_labels[attr_label] = "%{color}"
else:
- result["marker"]["colors"] = []
+ trace_patch["marker"]["colors"] = []
mapping = {}
- for cat in g[v]:
+ for cat in trace_data[attr_value]:
if mapping.get(cat) is None:
mapping[cat] = args["color_discrete_sequence"][
len(mapping) % len(args["color_discrete_sequence"])
]
- result["marker"]["colors"].append(mapping[cat])
+ trace_patch["marker"]["colors"].append(mapping[cat])
else:
colorable = "marker"
if trace_spec.constructor in [go.Parcats, go.Parcoords]:
colorable = "line"
- if colorable not in result:
- result[colorable] = dict()
- result[colorable]["color"] = g[v]
- result[colorable]["coloraxis"] = "coloraxis1"
- mapping_labels[v_label] = "%%{%s.color}" % colorable
- elif k == "animation_group":
- result["ids"] = g[v]
- elif k == "locations":
- result[k] = g[v]
- mapping_labels[v_label] = "%{location}"
- elif k == "values":
- result[k] = g[v]
- _label = "value" if v_label == "values" else v_label
+ if colorable not in trace_patch:
+ trace_patch[colorable] = dict()
+ trace_patch[colorable]["color"] = trace_data[attr_value]
+ trace_patch[colorable]["coloraxis"] = "coloraxis1"
+ mapping_labels[attr_label] = "%%{%s.color}" % colorable
+ elif attr_name == "animation_group":
+ trace_patch["ids"] = trace_data[attr_value]
+ elif attr_name == "locations":
+ trace_patch[attr_name] = trace_data[attr_value]
+ mapping_labels[attr_label] = "%{location}"
+ elif attr_name == "values":
+ trace_patch[attr_name] = trace_data[attr_value]
+ _label = "value" if attr_label == "values" else attr_label
mapping_labels[_label] = "%{value}"
- elif k == "parents":
- result[k] = g[v]
- _label = "parent" if v_label == "parents" else v_label
+ elif attr_name == "parents":
+ trace_patch[attr_name] = trace_data[attr_value]
+ _label = "parent" if attr_label == "parents" else attr_label
mapping_labels[_label] = "%{parent}"
- elif k == "ids":
- result[k] = g[v]
- _label = "id" if v_label == "ids" else v_label
+ elif attr_name == "ids":
+ trace_patch[attr_name] = trace_data[attr_value]
+ _label = "id" if attr_label == "ids" else attr_label
mapping_labels[_label] = "%{id}"
- elif k == "names":
+ elif attr_name == "names":
if trace_spec.constructor in [
go.Sunburst,
go.Treemap,
go.Pie,
go.Funnelarea,
]:
- result["labels"] = g[v]
- _label = "label" if v_label == "names" else v_label
+ trace_patch["labels"] = trace_data[attr_value]
+ _label = "label" if attr_label == "names" else attr_label
mapping_labels[_label] = "%{label}"
else:
- result[k] = g[v]
+ trace_patch[attr_name] = trace_data[attr_value]
else:
- if v:
- result[k] = g[v]
- mapping_labels[v_label] = "%%{%s}" % k
+ if attr_value:
+ trace_patch[attr_name] = trace_data[attr_value]
+ mapping_labels[attr_label] = "%%{%s}" % attr_name
if trace_spec.constructor not in [
go.Parcoords,
go.Parcats,
]:
hover_lines = [k + "=" + v for k, v in mapping_labels.items()]
- result["hovertemplate"] = hover_header + "
".join(hover_lines)
- return result, fit_results
+ trace_patch["hovertemplate"] = hover_header + "
".join(hover_lines)
+ return trace_patch, fit_results
def configure_axes(args, constructor, fig, orders):
@@ -1015,8 +1027,8 @@ def _check_dataframe_all_leaves(df):
null_indices = np.nonzero(null_mask.any(axis=1).values)[0]
for null_row_index in null_indices:
row = null_mask.iloc[null_row_index]
- indices = np.nonzero(row.values)[0]
- if not row[indices[0] :].all():
+ i = np.nonzero(row.values)[0][0]
+ if not row[i:].all():
raise ValueError(
"None entries cannot have not-None children",
df_sorted.iloc[null_row_index],
@@ -1058,6 +1070,7 @@ def process_dataframe_hierarchy(args):
path = [new_col_name if x == col_name else x for x in path]
df[new_col_name] = series_to_copy
# ------------ Define aggregation functions --------------------------------
+
def aggfunc_discrete(x):
uniques = x.unique()
if len(uniques) == 1: