diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e94a79d3954..dd670c26d40 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -137,7 +137,7 @@ def make_mapping(args, variable): ) -def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): +def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): """Populates a dict with arguments to update trace Parameters @@ -147,7 +147,7 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): trace_spec : NamedTuple which kind of trace to be used (has constructor, marginal etc. attributes) - g : pandas DataFrame + trace_data : pandas DataFrame data mapping_labels : dict to be used for hovertemplate @@ -156,87 +156,92 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): Returns ------- - result : dict + trace_patch : dict dict to be used to update trace fit_results : dict fit information to be used for trendlines """ if "line_close" in args and args["line_close"]: - g = g.append(g.iloc[0]) - result = trace_spec.trace_patch.copy() or {} + trace_data = trace_data.append(trace_data.iloc[0]) + trace_patch = trace_spec.trace_patch.copy() or {} fit_results = None hover_header = "" custom_data_len = 0 - for k in trace_spec.attrs: - v = args[k] - v_label = get_decorated_label(args, v, k) - if k == "dimensions": + for attr_name in trace_spec.attrs: + attr_value = args[attr_name] + attr_label = get_decorated_label(args, attr_value, attr_name) + if attr_name == "dimensions": dims = [ (name, column) - for (name, column) in g.iteritems() - if ((not v) or (name in v)) + for (name, column) in trace_data.iteritems() + if ((not attr_value) or (name in attr_value)) and ( trace_spec.constructor != go.Parcoords or args["data_frame"][name].dtype.kind in "bifc" ) and ( trace_spec.constructor != go.Parcats - or (v is not None and name in v) + or (attr_value is not None and name in attr_value) or len(args["data_frame"][name].unique()) <= args["dimensions_max_cardinality"] ) ] - result["dimensions"] = [ + trace_patch["dimensions"] = [ dict(label=get_label(args, name), values=column.values) for (name, column) in dims ] if trace_spec.constructor == go.Splom: - for d in result["dimensions"]: + for d in trace_patch["dimensions"]: d["axis"] = dict(matches=True) mapping_labels["%{xaxis.title.text}"] = "%{x}" mapping_labels["%{yaxis.title.text}"] = "%{y}" elif ( - v is not None - or (trace_spec.constructor == go.Histogram and k in ["x", "y"]) + attr_value is not None + or (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or ( trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour] - and k == "z" + and attr_name == "z" ) ): - if k == "size": - if "marker" not in result: - result["marker"] = dict() - result["marker"]["size"] = g[v] - result["marker"]["sizemode"] = "area" - result["marker"]["sizeref"] = sizeref - mapping_labels[v_label] = "%{marker.size}" - elif k == "marginal_x": + if attr_name == "size": + if "marker" not in trace_patch: + trace_patch["marker"] = dict() + trace_patch["marker"]["size"] = trace_data[attr_value] + trace_patch["marker"]["sizemode"] = "area" + trace_patch["marker"]["sizeref"] = sizeref + mapping_labels[attr_label] = "%{marker.size}" + elif attr_name == "marginal_x": if trace_spec.constructor == go.Histogram: mapping_labels["count"] = "%{y}" - elif k == "marginal_y": + elif attr_name == "marginal_y": if trace_spec.constructor == go.Histogram: mapping_labels["count"] = "%{x}" - elif k == "trendline": - if v in ["ols", "lowess"] and args["x"] and args["y"] and len(g) > 1: + elif attr_name == "trendline": + if ( + attr_value in ["ols", "lowess"] + and args["x"] + and args["y"] + and len(trace_data) > 1 + ): import statsmodels.api as sm # sorting is bad but trace_specs with "trendline" have no other attrs - g2 = g.sort_values(by=args["x"]) - y = g2[args["y"]] - x = g2[args["x"]] - result["x"] = x + sorted_trace_data = trace_data.sort_values(by=args["x"]) + y = sorted_trace_data[args["y"]] + x = sorted_trace_data[args["x"]] + trace_patch["x"] = x if x.dtype.type == np.datetime64: x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds - if v == "lowess": + if attr_value == "lowess": trendline = sm.nonparametric.lowess(y, x) - result["y"] = trendline[:, 1] + trace_patch["y"] = trendline[:, 1] hover_header = "LOWESS trendline

" - elif v == "ols": + elif attr_value == "ols": fit_results = sm.OLS(y.values, sm.add_constant(x.values)).fit() - result["y"] = fit_results.predict() + trace_patch["y"] = fit_results.predict() hover_header = "OLS trendline
" hover_header += "%s = %g * %s + %g
" % ( args["y"], @@ -250,120 +255,127 @@ def make_trace_kwargs(args, trace_spec, g, mapping_labels, sizeref): mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" - elif k.startswith("error"): - error_xy = k[:7] - arr = "arrayminus" if k.endswith("minus") else "array" - if error_xy not in result: - result[error_xy] = {} - result[error_xy][arr] = g[v] - elif k == "custom_data": - result["customdata"] = g[v].values - custom_data_len = len(v) # number of custom data columns - elif k == "hover_name": + elif attr_name.startswith("error"): + error_xy = attr_name[:7] + arr = "arrayminus" if attr_name.endswith("minus") else "array" + if error_xy not in trace_patch: + trace_patch[error_xy] = {} + trace_patch[error_xy][arr] = trace_data[attr_value] + elif attr_name == "custom_data": + trace_patch["customdata"] = trace_data[attr_value].values + custom_data_len = len(attr_value) # number of custom data columns + elif attr_name == "hover_name": if trace_spec.constructor not in [ go.Histogram, go.Histogram2d, go.Histogram2dContour, ]: - result["hovertext"] = g[v] + trace_patch["hovertext"] = trace_data[attr_value] if hover_header == "": hover_header = "%{hovertext}

" - elif k == "hover_data": + elif attr_name == "hover_data": if trace_spec.constructor not in [ go.Histogram, go.Histogram2d, go.Histogram2dContour, ]: - for col in v: + for col in attr_value: try: position = args["custom_data"].index(col) except (ValueError, AttributeError, KeyError): position = custom_data_len custom_data_len += 1 - if "customdata" in result: - result["customdata"] = np.hstack( - (result["customdata"], g[col].values[:, None]) + if "customdata" in trace_patch: + trace_patch["customdata"] = np.hstack( + ( + trace_patch["customdata"], + trace_data[col].values[:, None], + ) ) else: - result["customdata"] = g[col].values[:, None] - v_label_col = get_decorated_label(args, col, None) - mapping_labels[v_label_col] = "%%{customdata[%d]}" % (position) - elif k == "color": + trace_patch["customdata"] = trace_data[col].values[ + :, None + ] + attr_label_col = get_decorated_label(args, col, None) + mapping_labels[attr_label_col] = "%%{customdata[%d]}" % ( + position + ) + elif attr_name == "color": if trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox]: - result["z"] = g[v] - result["coloraxis"] = "coloraxis1" - mapping_labels[v_label] = "%{z}" + trace_patch["z"] = trace_data[attr_value] + trace_patch["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%{z}" elif trace_spec.constructor in [ go.Sunburst, go.Treemap, go.Pie, go.Funnelarea, ]: - if "marker" not in result: - result["marker"] = dict() + if "marker" not in trace_patch: + trace_patch["marker"] = dict() if args.get("color_is_continuous"): - result["marker"]["colors"] = g[v] - result["marker"]["coloraxis"] = "coloraxis1" - mapping_labels[v_label] = "%{color}" + trace_patch["marker"]["colors"] = trace_data[attr_value] + trace_patch["marker"]["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%{color}" else: - result["marker"]["colors"] = [] + trace_patch["marker"]["colors"] = [] mapping = {} - for cat in g[v]: + for cat in trace_data[attr_value]: if mapping.get(cat) is None: mapping[cat] = args["color_discrete_sequence"][ len(mapping) % len(args["color_discrete_sequence"]) ] - result["marker"]["colors"].append(mapping[cat]) + trace_patch["marker"]["colors"].append(mapping[cat]) else: colorable = "marker" if trace_spec.constructor in [go.Parcats, go.Parcoords]: colorable = "line" - if colorable not in result: - result[colorable] = dict() - result[colorable]["color"] = g[v] - result[colorable]["coloraxis"] = "coloraxis1" - mapping_labels[v_label] = "%%{%s.color}" % colorable - elif k == "animation_group": - result["ids"] = g[v] - elif k == "locations": - result[k] = g[v] - mapping_labels[v_label] = "%{location}" - elif k == "values": - result[k] = g[v] - _label = "value" if v_label == "values" else v_label + if colorable not in trace_patch: + trace_patch[colorable] = dict() + trace_patch[colorable]["color"] = trace_data[attr_value] + trace_patch[colorable]["coloraxis"] = "coloraxis1" + mapping_labels[attr_label] = "%%{%s.color}" % colorable + elif attr_name == "animation_group": + trace_patch["ids"] = trace_data[attr_value] + elif attr_name == "locations": + trace_patch[attr_name] = trace_data[attr_value] + mapping_labels[attr_label] = "%{location}" + elif attr_name == "values": + trace_patch[attr_name] = trace_data[attr_value] + _label = "value" if attr_label == "values" else attr_label mapping_labels[_label] = "%{value}" - elif k == "parents": - result[k] = g[v] - _label = "parent" if v_label == "parents" else v_label + elif attr_name == "parents": + trace_patch[attr_name] = trace_data[attr_value] + _label = "parent" if attr_label == "parents" else attr_label mapping_labels[_label] = "%{parent}" - elif k == "ids": - result[k] = g[v] - _label = "id" if v_label == "ids" else v_label + elif attr_name == "ids": + trace_patch[attr_name] = trace_data[attr_value] + _label = "id" if attr_label == "ids" else attr_label mapping_labels[_label] = "%{id}" - elif k == "names": + elif attr_name == "names": if trace_spec.constructor in [ go.Sunburst, go.Treemap, go.Pie, go.Funnelarea, ]: - result["labels"] = g[v] - _label = "label" if v_label == "names" else v_label + trace_patch["labels"] = trace_data[attr_value] + _label = "label" if attr_label == "names" else attr_label mapping_labels[_label] = "%{label}" else: - result[k] = g[v] + trace_patch[attr_name] = trace_data[attr_value] else: - if v: - result[k] = g[v] - mapping_labels[v_label] = "%%{%s}" % k + if attr_value: + trace_patch[attr_name] = trace_data[attr_value] + mapping_labels[attr_label] = "%%{%s}" % attr_name if trace_spec.constructor not in [ go.Parcoords, go.Parcats, ]: hover_lines = [k + "=" + v for k, v in mapping_labels.items()] - result["hovertemplate"] = hover_header + "
".join(hover_lines) - return result, fit_results + trace_patch["hovertemplate"] = hover_header + "
".join(hover_lines) + return trace_patch, fit_results def configure_axes(args, constructor, fig, orders): @@ -1015,8 +1027,8 @@ def _check_dataframe_all_leaves(df): null_indices = np.nonzero(null_mask.any(axis=1).values)[0] for null_row_index in null_indices: row = null_mask.iloc[null_row_index] - indices = np.nonzero(row.values)[0] - if not row[indices[0] :].all(): + i = np.nonzero(row.values)[0][0] + if not row[i:].all(): raise ValueError( "None entries cannot have not-None children", df_sorted.iloc[null_row_index], @@ -1058,6 +1070,7 @@ def process_dataframe_hierarchy(args): path = [new_col_name if x == col_name else x for x in path] df[new_col_name] = series_to_copy # ------------ Define aggregation functions -------------------------------- + def aggfunc_discrete(x): uniques = x.unique() if len(uniques) == 1: