From 152f87f60fa3079682440de49af584467ecdf572 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Sep 2019 09:18:12 -0400 Subject: [PATCH 01/69] more flexible type of input arguments for px functions --- .../plotly/plotly/express/_chart_types.py | 50 +++++++++---------- .../python/plotly/plotly/express/_core.py | 43 +++++++++++++++- .../tests/test_core/test_px/test_px_input.py | 29 +++++++++++ 3 files changed, 95 insertions(+), 27 deletions(-) create mode 100644 packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 7a08d2e6ac4..c17c5a657e5 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -4,7 +4,7 @@ def scatter( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -57,7 +57,7 @@ def scatter( def density_contour( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -113,7 +113,7 @@ def density_contour( def density_heatmap( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -167,7 +167,7 @@ def density_heatmap( def line( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -212,7 +212,7 @@ def line( def area( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -257,7 +257,7 @@ def area( def bar( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -307,7 +307,7 @@ def bar( def histogram( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -364,7 +364,7 @@ def histogram( def violin( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -415,7 +415,7 @@ def violin( def box( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -461,7 +461,7 @@ def box( def strip( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -511,7 +511,7 @@ def strip( def scatter_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -562,7 +562,7 @@ def scatter_3d( def line_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -608,7 +608,7 @@ def line_3d( def scatter_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -647,7 +647,7 @@ def scatter_ternary( def line_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -682,7 +682,7 @@ def line_ternary( def scatter_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -726,7 +726,7 @@ def scatter_polar( def line_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -766,7 +766,7 @@ def line_polar( def bar_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -804,7 +804,7 @@ def bar_polar( def choropleth( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -844,7 +844,7 @@ def choropleth( def scatter_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -888,7 +888,7 @@ def scatter_geo( def line_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -930,7 +930,7 @@ def line_geo( def scatter_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -966,7 +966,7 @@ def scatter_mapbox( def line_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -997,7 +997,7 @@ def line_mapbox( def scatter_matrix( - data_frame, + data_frame=None, dimensions=None, color=None, symbol=None, @@ -1035,7 +1035,7 @@ def scatter_matrix( def parallel_coordinates( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, @@ -1059,7 +1059,7 @@ def parallel_coordinates( def parallel_categories( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 3d55d3004da..ba28ad71ea8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -5,7 +5,7 @@ from _plotly_utils.basevalidators import ColorscaleValidator from .colors import qualitative, sequential import math -import pandas +import pandas as pd import numpy as np from plotly.subplots import ( @@ -754,6 +754,42 @@ def apply_default_cascade(args): args["marginal_x"] = None +def build_or_augment_dataframe(args, attrables, array_attrables): + """ + Constructs an implicit dataframe and modifies `args` in-place. + `attrables` is a list of keys into `args`, all of whose corresponding + values are converted into columns of a dataframe. + Used to be support calls to plotting function that elide a dataframe + argument; for example `scatter(x=[1,2], y=[3,4])`. + """ + if args.get("data_frame") is None: + df = pd.DataFrame() + else: + df = args["data_frame"] + df = df.reset_index() + data_frame_columns = {} + for field in attrables: + if field in array_attrables: + continue + argument = args.get(field) + if argument is None: + continue + elif isinstance(argument, str) and argument in df.columns: + continue + else: # args[field] should be an array or df or index now + try: + col_name = argument.name # pandas df + except AttributeError: + labels = args.get("labels") + col_name = labels[field] if labels and labels.get(field) else field + df[col_name] = argument + # This sets the label of an attribute to be + # the name of the attribute. + args[field] = col_name + args["data_frame"] = df + return args + + def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( @@ -766,6 +802,9 @@ def infer_config(args, constructor, trace_patch): array_attrables = ["dimensions", "custom_data", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + all_attrables = attrables + group_attrables + ["color"] + build_or_augment_dataframe(args, all_attrables, array_attrables) + # Validate that the strings provided as attribute values reference columns # in the provided data_frame df_columns = args["data_frame"].columns @@ -1095,7 +1134,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): fig.layout.update(layout_patch) fig.frames = frame_list if len(frames) > 1 else [] - fig._px_trendlines = pandas.DataFrame(trendline_rows) + fig._px_trendlines = pd.DataFrame(trendline_rows) configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py new file mode 100644 index 00000000000..d488af08a2f --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -0,0 +1,29 @@ +import plotly.express as px +import numpy as np +import pandas as pd + + +def test_numpy(): + fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) + + +def test_numpy_labels(): + fig = px.scatter( + x=[1, 2, 3], y=[2, 3, 4], labels={"x": "time"} + ) # other labels will be kw arguments + assert fig.data[0]["hovertemplate"] == "time=%{x}
y=%{y}" + + +def test_with_index(): + tips = px.data.tips() + fig = px.scatter(tips, x="index", y="total_bill") + fig = px.scatter(tips, x="index", y=tips.total_bill) + assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + # I was not expecting this to work but it does... + fig = px.scatter(tips, x="index", y=10 * tips.total_bill) + assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + + +def test_mixed_case(): + df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) + fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) From fe4eda3d43f3c097d9de5a49fc9e2d0a26bb7ffa Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Sep 2019 11:26:25 -0400 Subject: [PATCH 02/69] wrong col name case --- packages/python/plotly/plotly/express/_core.py | 2 +- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index ba28ad71ea8..baf6783c909 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -774,7 +774,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): argument = args.get(field) if argument is None: continue - elif isinstance(argument, str) and argument in df.columns: + elif isinstance(argument, str): continue else: # args[field] should be an array or df or index now try: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index d488af08a2f..6c076a54032 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -1,6 +1,7 @@ import plotly.express as px import numpy as np import pandas as pd +import pytest def test_numpy(): @@ -27,3 +28,8 @@ def test_with_index(): def test_mixed_case(): df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) + + +def test_wrong_column_name(): + with pytest.raises(ValueError): + fig = px.scatter(px.data.tips(), x="bla", y="wrong") From 86937ea107247d20e2ce6c1ae56abf7200d81da3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Sep 2019 11:45:11 -0400 Subject: [PATCH 03/69] corner case of functions grabbing all cols --- packages/python/plotly/plotly/express/_core.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index baf6783c909..0a0fa1cd066 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,7 +754,7 @@ def apply_default_cascade(args): args["marginal_x"] = None -def build_or_augment_dataframe(args, attrables, array_attrables): +def build_or_augment_dataframe(args, attrables, array_attrables, constructor): """ Constructs an implicit dataframe and modifies `args` in-place. `attrables` is a list of keys into `args`, all of whose corresponding @@ -766,7 +766,9 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df = pd.DataFrame() else: df = args["data_frame"] - df = df.reset_index() + # we don't want to add an index to functions grabbing all cols + if constructor != go.Splom and constructor != go.Parcoords: + df = df.reset_index() data_frame_columns = {} for field in attrables: if field in array_attrables: @@ -803,7 +805,7 @@ def infer_config(args, constructor, trace_patch): group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] all_attrables = attrables + group_attrables + ["color"] - build_or_augment_dataframe(args, all_attrables, array_attrables) + build_or_augment_dataframe(args, all_attrables, array_attrables, constructor) # Validate that the strings provided as attribute values reference columns # in the provided data_frame From 1ad0f5ce846fd6f71163987a1442169fbb16c79b Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Sep 2019 21:06:38 -0400 Subject: [PATCH 04/69] better behavior of index, more tests --- .../python/plotly/plotly/express/_core.py | 20 ++++-- .../tests/test_core/test_px/test_px_input.py | 71 ++++++++++++++++++- 2 files changed, 82 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0a0fa1cd066..0bbc394a6ce 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -766,11 +766,9 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): df = pd.DataFrame() else: df = args["data_frame"] - # we don't want to add an index to functions grabbing all cols - if constructor != go.Splom and constructor != go.Parcoords: - df = df.reset_index() data_frame_columns = {} for field in attrables: + labels = args.get("labels") if field in array_attrables: continue argument = args.get(field) @@ -778,16 +776,24 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): continue elif isinstance(argument, str): continue + elif isinstance(argument, pd.core.indexes.range.RangeIndex): + col_name = argument.name if argument.name else "index" + print (col_name, labels) + col_name = labels[field] if labels and labels.get(field) else col_name + print (col_name) + try: + df.insert(0, col_name, argument) + except ValueError: + pass else: # args[field] should be an array or df or index now try: col_name = argument.name # pandas df except AttributeError: - labels = args.get("labels") col_name = labels[field] if labels and labels.get(field) else field df[col_name] = argument - # This sets the label of an attribute to be - # the name of the attribute. - args[field] = col_name + # This sets the label of an attribute to be + # the name of the attribute. + args[field] = col_name args["data_frame"] = df return args diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 6c076a54032..5f553f5a808 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -2,6 +2,21 @@ import numpy as np import pandas as pd import pytest +import plotly.graph_objects as go +import plotly +from plotly.express._core import build_or_augment_dataframe + +attrables = ( + ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] + + ["custom_data", "hover_name", "hover_data", "text"] + + ["error_x", "error_x_minus"] + + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + + ["lat", "lon", "locations", "animation_group"] +) +array_attrables = ["dimensions", "custom_data", "hover_data"] +group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + +all_attrables = attrables + group_attrables + ["color"] def test_numpy(): @@ -17,12 +32,16 @@ def test_numpy_labels(): def test_with_index(): tips = px.data.tips() - fig = px.scatter(tips, x="index", y="total_bill") - fig = px.scatter(tips, x="index", y=tips.total_bill) + fig = px.scatter(tips, x=tips.index, y="total_bill") + fig = px.scatter(tips, x=tips.index, y=tips.total_bill) + tips = px.data.tips() + fig = px.scatter(tips, x=tips.index, y=tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" # I was not expecting this to work but it does... fig = px.scatter(tips, x="index", y=10 * tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"}) + assert fig.data[0]["hovertemplate"] == "number=%{x}
total_bill=%{y}" def test_mixed_case(): @@ -33,3 +52,51 @@ def test_mixed_case(): def test_wrong_column_name(): with pytest.raises(ValueError): fig = px.scatter(px.data.tips(), x="bla", y="wrong") + + +def test_build_df_from_lists(): + # Just lists + args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) + output = {key: key for key in args} + df = pd.DataFrame(args) + out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + assert df.equals(out["data_frame"]) + out.pop("data_frame") + assert out == output + + # Arrays + args = dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9]) + output = {key: key for key in args} + df = pd.DataFrame(args) + out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + assert df.equals(out["data_frame"]) + out.pop("data_frame") + assert out == output + + # Lists, changing one label + labels = {"x": "time"} + args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9], labels=labels) + output = {key: key for key in args} + output.update(labels) + args_wo_labels = args.copy() + _ = args_wo_labels.pop("labels") + df = pd.DataFrame(args_wo_labels).rename(columns=labels) + out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + assert df.equals(out["data_frame"]) + + +def test_build_df_with_index(): + tips = px.data.tips() + args = dict(data_frame=tips, x=tips.index, y="total_bill") + changed_output = dict(x="index") + out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + assert out["data_frame"].equals(tips) + out.pop("data_frame") + assert out == args + + tips = px.data.tips() + args = dict(data_frame=tips, x="index", y=tips.total_bill) + out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + assert out["data_frame"].equals(tips) + out.pop("data_frame") + assert out == args From 7d0e985a5e655c18023d53245759370ee440968d Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 13 Sep 2019 21:25:06 -0400 Subject: [PATCH 05/69] comment code --- packages/python/plotly/plotly/express/_core.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0bbc394a6ce..a79bb8ffd93 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -762,30 +762,30 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): Used to be support calls to plotting function that elide a dataframe argument; for example `scatter(x=[1,2], y=[3,4])`. """ + # This will be changed so that we start from an empty dataframe if args.get("data_frame") is None: df = pd.DataFrame() else: df = args["data_frame"] - data_frame_columns = {} for field in attrables: - labels = args.get("labels") + labels = args.get("labels") # labels or None + # hack, needs to be changed when we start from empty df if field in array_attrables: continue argument = args.get(field) if argument is None: continue - elif isinstance(argument, str): + elif isinstance(argument, str): # needs to change continue + # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" - print (col_name, labels) col_name = labels[field] if labels and labels.get(field) else col_name - print (col_name) try: df.insert(0, col_name, argument) - except ValueError: - pass - else: # args[field] should be an array or df or index now + except ValueError: # if col named index already exists, replace + df['col_name'] = argument + else: # args[field] should be an array or df column try: col_name = argument.name # pandas df except AttributeError: From 915a5a1210d9d8ae1203d21c91382e9aefa7880c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 14 Sep 2019 21:45:11 -0400 Subject: [PATCH 06/69] black --- packages/python/plotly/plotly/express/_core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index a79bb8ffd93..09fd473d447 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -768,14 +768,14 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): else: df = args["data_frame"] for field in attrables: - labels = args.get("labels") # labels or None + labels = args.get("labels") # labels or None # hack, needs to be changed when we start from empty df if field in array_attrables: continue argument = args.get(field) if argument is None: continue - elif isinstance(argument, str): # needs to change + elif isinstance(argument, str): # needs to change continue # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): @@ -783,8 +783,8 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): col_name = labels[field] if labels and labels.get(field) else col_name try: df.insert(0, col_name, argument) - except ValueError: # if col named index already exists, replace - df['col_name'] = argument + except ValueError: # if col named index already exists, replace + df["col_name"] = argument else: # args[field] should be an array or df column try: col_name = argument.name # pandas df From 5d5ab81f7c9476d4780bd1b290488c37bce506c6 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 14 Sep 2019 21:50:52 -0400 Subject: [PATCH 07/69] debugging --- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 5f553f5a808..d3ece62b5f9 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -60,6 +60,8 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + print(df) + print(out["data_frame"]) assert df.equals(out["data_frame"]) out.pop("data_frame") assert out == output From ea0fa6a1777aa3c71e4fdd8260b03e6bb72e5b47 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 14 Sep 2019 22:02:30 -0400 Subject: [PATCH 08/69] relax column ordering in tests --- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index d3ece62b5f9..3f870f0a7f4 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -60,9 +60,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - print(df) - print(out["data_frame"]) - assert df.equals(out["data_frame"]) + assert df.sort_index(axis=1) == out["data_frame"].sort_index(axis=1) out.pop("data_frame") assert out == output @@ -71,7 +69,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert df.equals(out["data_frame"]) + assert df.sort_index(axis=1) == out["data_frame"].sort_index(axis=1) out.pop("data_frame") assert out == output From e5f69536dc5090667a88d3dbe938808715249f66 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 14 Sep 2019 22:10:46 -0400 Subject: [PATCH 09/69] tests --- .../plotly/tests/test_core/test_px/test_px_input.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 3f870f0a7f4..63d0183bd67 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -5,6 +5,7 @@ import plotly.graph_objects as go import plotly from plotly.express._core import build_or_augment_dataframe +from pandas.util.testing import assert_frame_equal attrables = ( ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] @@ -60,7 +61,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert df.sort_index(axis=1) == out["data_frame"].sort_index(axis=1) + assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -69,7 +70,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert df.sort_index(axis=1) == out["data_frame"].sort_index(axis=1) + assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -82,7 +83,7 @@ def test_build_df_from_lists(): _ = args_wo_labels.pop("labels") df = pd.DataFrame(args_wo_labels).rename(columns=labels) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert df.equals(out["data_frame"]) + assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) def test_build_df_with_index(): @@ -90,13 +91,13 @@ def test_build_df_with_index(): args = dict(data_frame=tips, x=tips.index, y="total_bill") changed_output = dict(x="index") out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert out["data_frame"].equals(tips) + assert_frame_equal(tips.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == args tips = px.data.tips() args = dict(data_frame=tips, x="index", y=tips.total_bill) out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert out["data_frame"].equals(tips) + assert_frame_equal(tips.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == args From 29d7e18408b7a0b85b8a0359c3ab41245ea1e882 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Sep 2019 10:55:10 -0400 Subject: [PATCH 10/69] array arguments --- .../python/plotly/plotly/express/_core.py | 63 +++++++++++-------- .../tests/test_core/test_px/test_px_input.py | 36 +++++++++++ 2 files changed, 73 insertions(+), 26 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 09fd473d447..f721052af67 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -767,33 +767,44 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): df = pd.DataFrame() else: df = args["data_frame"] - for field in attrables: - labels = args.get("labels") # labels or None - # hack, needs to be changed when we start from empty df - if field in array_attrables: - continue - argument = args.get(field) - if argument is None: - continue - elif isinstance(argument, str): # needs to change + labels = args.get("labels") # labels or None + for field_name in attrables: + argument_list = ( + [args.get(field_name)] + if field_name not in array_attrables + else args.get(field_name) + ) + if argument_list is None: continue - # Case of index - elif isinstance(argument, pd.core.indexes.range.RangeIndex): - col_name = argument.name if argument.name else "index" - col_name = labels[field] if labels and labels.get(field) else col_name - try: - df.insert(0, col_name, argument) - except ValueError: # if col named index already exists, replace - df["col_name"] = argument - else: # args[field] should be an array or df column - try: - col_name = argument.name # pandas df - except AttributeError: - col_name = labels[field] if labels and labels.get(field) else field - df[col_name] = argument - # This sets the label of an attribute to be - # the name of the attribute. - args[field] = col_name + field_list = ( + [field_name] + if field_name not in array_attrables + else [field_name + "_" + str(i) for i in range(len(argument_list))] + ) + for i, (argument, field) in enumerate(zip(argument_list, field_list)): + if argument is None: + continue + elif isinstance(argument, str): # needs to change + continue + # Case of index + elif isinstance(argument, pd.core.indexes.range.RangeIndex): + col_name = argument.name if argument.name else "index" + col_name = labels[field] if labels and labels.get(field) else col_name + try: + df.insert(0, col_name, argument) + except ValueError: # if col named index already exists, replace + df["col_name"] = argument + else: # args[field] should be an array or df column + try: + col_name = argument.name # pandas df + except AttributeError: + col_name = labels[field] if labels and labels.get(field) else field + df[col_name] = argument + # This sets the label of an attribute to be + if field_name not in array_attrables: + args[field_name] = col_name + else: + args[field_name][i] = col_name args["data_frame"] = df return args diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 63d0183bd67..ace3e61f4bf 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -50,11 +50,47 @@ def test_mixed_case(): fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) +def test_arrayattrable_numpy(): + tips = px.data.tips() + fig = px.scatter( + tips, x="total_bill", y="tip", hover_data=[np.random.random(tips.shape[0])] + ) + assert ( + fig.data[0]["hovertemplate"] + == "total_bill=%{x}
tip=%{y}
hover_data_0=%{customdata[0]}" + ) + tips = px.data.tips() + fig = px.scatter( + tips, + x="total_bill", + y="tip", + hover_data=[np.random.random(tips.shape[0])], + labels={"hover_data_0": "suppl"}, + ) + assert ( + fig.data[0]["hovertemplate"] + == "total_bill=%{x}
tip=%{y}
suppl=%{customdata[0]}" + ) + + def test_wrong_column_name(): with pytest.raises(ValueError): fig = px.scatter(px.data.tips(), x="bla", y="wrong") +def test_wrong_dimensions_of_array(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5]) + assert "Length of values does not match length of index" in str(err_msg.value) + + +def test_wrong_dimensions_mixed_cqse(): + with pytest.raises(ValueError) as err_msg: + df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) + fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5]) + assert "Length of values does not match length of index" in str(err_msg.value) + + def test_build_df_from_lists(): # Just lists args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) From 4a028a2365cb430bb15204908ba8e70a9d8b5598 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Sep 2019 14:09:44 -0400 Subject: [PATCH 11/69] move column checks --- .../python/plotly/plotly/express/_core.py | 21 +++++++++++++------ .../tests/test_core/test_px/test_px_input.py | 19 ++++++++--------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f721052af67..4ff5fe4c071 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -762,12 +762,13 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): Used to be support calls to plotting function that elide a dataframe argument; for example `scatter(x=[1,2], y=[3,4])`. """ - # This will be changed so that we start from an empty dataframe - if args.get("data_frame") is None: - df = pd.DataFrame() - else: - df = args["data_frame"] + df = pd.DataFrame() labels = args.get("labels") # labels or None + df_columns = ( + args["data_frame"].columns if args.get("data_frame") is not None else None + ) + if "symbol" in args: + attrables += ["symbol"] for field_name in attrables: argument_list = ( [args.get(field_name)] @@ -785,7 +786,15 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): if argument is None: continue elif isinstance(argument, str): # needs to change - continue + try: + df[argument] = args["data_frame"][argument] + continue + except KeyError: + raise ValueError( + "Value of '%s' is not the name of a column in 'data_frame'. " + "Expected one of %s but received: %s" + % (field, str(list(df_columns)), argument) + ) # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index ace3e61f4bf..f945d29dbb8 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -35,14 +35,20 @@ def test_with_index(): tips = px.data.tips() fig = px.scatter(tips, x=tips.index, y="total_bill") fig = px.scatter(tips, x=tips.index, y=tips.total_bill) - tips = px.data.tips() fig = px.scatter(tips, x=tips.index, y=tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" # I was not expecting this to work but it does... - fig = px.scatter(tips, x="index", y=10 * tips.total_bill) + fig = px.scatter(tips, x=tips.index, y=10 * tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"}) assert fig.data[0]["hovertemplate"] == "number=%{x}
total_bill=%{y}" + # We do not allow "x=index" + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(tips, x="index", y="total_bill") + assert ( + "ValueError: Value of 'x' is not the name of a column in 'data_frame'" + in str(err_msg.value) + ) def test_mixed_case(): @@ -127,13 +133,6 @@ def test_build_df_with_index(): args = dict(data_frame=tips, x=tips.index, y="total_bill") changed_output = dict(x="index") out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert_frame_equal(tips.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) - out.pop("data_frame") - assert out == args - - tips = px.data.tips() - args = dict(data_frame=tips, x="index", y=tips.total_bill) - out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert_frame_equal(tips.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) + assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) out.pop("data_frame") assert out == args From ab35b4294ad603e8558ab19b87f10e2ffa6cebff Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Sep 2019 14:54:32 -0400 Subject: [PATCH 12/69] case of dimensions --- packages/python/plotly/plotly/express/_core.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 4ff5fe4c071..64a4fdc0740 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -762,13 +762,18 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): Used to be support calls to plotting function that elide a dataframe argument; for example `scatter(x=[1,2], y=[3,4])`. """ - df = pd.DataFrame() + if constructor in [go.Splom, go.Parcats, go.Parcoords]: # we take all dimensions + df = args["data_frame"] + else: + df = pd.DataFrame() labels = args.get("labels") # labels or None df_columns = ( args["data_frame"].columns if args.get("data_frame") is not None else None ) - if "symbol" in args: - attrables += ["symbol"] + group_attrs = ["symbol", "line_dash"] + for group_attr in group_attrs: + if group_attr in args: + attrables += [group_attr] for field_name in attrables: argument_list = ( [args.get(field_name)] From e4b8835ad22c7c0dab97ea9cc323ba7948c219a3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Sep 2019 15:22:29 -0400 Subject: [PATCH 13/69] clean code + black --- .../python/plotly/plotly/express/_core.py | 47 ++++++++----------- .../tests/test_core/test_px/test_px_input.py | 6 +++ 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 64a4fdc0740..cb9406748ac 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -762,11 +762,19 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): Used to be support calls to plotting function that elide a dataframe argument; for example `scatter(x=[1,2], y=[3,4])`. """ + + # We start from an empty DataFrame except for the case of functions which + # implicitely need all dimensions: Splom, Parcats, Parcoords + # This could be refined when dimensions is given if constructor in [go.Splom, go.Parcats, go.Parcoords]: # we take all dimensions df = args["data_frame"] else: df = pd.DataFrame() + + # Retrieve labels (to change column names) labels = args.get("labels") # labels or None + + # Valid column names df_columns = ( args["data_frame"].columns if args.get("data_frame") is not None else None ) @@ -774,23 +782,29 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): for group_attr in group_attrs: if group_attr in args: attrables += [group_attr] + + # Loop over possible arguments for field_name in attrables: argument_list = ( [args.get(field_name)] if field_name not in array_attrables else args.get(field_name) ) - if argument_list is None: + if argument_list is None: # argument not specified, continue continue + # Argument name: field_name if the argument is a list + # Else we give names like ["hover_data_0, hover_data_1"] etc. field_list = ( [field_name] if field_name not in array_attrables else [field_name + "_" + str(i) for i in range(len(argument_list))] ) + # argument_list and field_list ready, iterate over them for i, (argument, field) in enumerate(zip(argument_list, field_list)): if argument is None: continue - elif isinstance(argument, str): # needs to change + elif isinstance(argument, str): # just a column name + # Check validity of column name try: df[argument] = args["data_frame"][argument] continue @@ -807,14 +821,15 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): try: df.insert(0, col_name, argument) except ValueError: # if col named index already exists, replace - df["col_name"] = argument - else: # args[field] should be an array or df column + df[col_name] = argument + # Case of numpy array or df column + else: try: col_name = argument.name # pandas df except AttributeError: col_name = labels[field] if labels and labels.get(field) else field df[col_name] = argument - # This sets the label of an attribute to be + # Update argument with column name now that column exists if field_name not in array_attrables: args[field_name] = col_name else: @@ -838,28 +853,6 @@ def infer_config(args, constructor, trace_patch): all_attrables = attrables + group_attrables + ["color"] build_or_augment_dataframe(args, all_attrables, array_attrables, constructor) - # Validate that the strings provided as attribute values reference columns - # in the provided data_frame - df_columns = args["data_frame"].columns - - for attr in attrables + group_attrables + ["color"]: - if attr in args and args[attr] is not None: - maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr] - for maybe_col in maybe_col_list: - try: - in_cols = maybe_col in df_columns - except TypeError: - in_cols = False - if not in_cols: - value_str = ( - "Element of value" if attr in array_attrables else "Value" - ) - raise ValueError( - "%s of '%s' is not the name of a column in 'data_frame'. " - "Expected one of %s but received: %s" - % (value_str, attr, str(list(df_columns)), str(maybe_col)) - ) - attrs = [k for k in attrables if k in args] grouped_attrs = [] diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index f945d29dbb8..0107a1f3614 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -136,3 +136,9 @@ def test_build_df_with_index(): assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) out.pop("data_frame") assert out == args + + +def test_splom_case(): + iris = px.data.iris() + fig = px.scatter_matrix(iris) + assert len(fig.data[0].dimensions) == len(iris.columns) From 3b0d21abc402389cce26b1fc53d91c0e76187d89 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 16 Sep 2019 17:03:23 -0400 Subject: [PATCH 14/69] deduplicated labels logics --- packages/python/plotly/plotly/express/_core.py | 12 +++++++++--- .../plotly/tests/test_core/test_px/test_px_input.py | 11 ----------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index cb9406748ac..c273c6a48e9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -757,6 +757,13 @@ def apply_default_cascade(args): def build_or_augment_dataframe(args, attrables, array_attrables, constructor): """ Constructs an implicit dataframe and modifies `args` in-place. + + Parameters + ---------- + args : OrderedDict + + constructor : go Trace object + trace function. It is used to fine-tune some options. `attrables` is a list of keys into `args`, all of whose corresponding values are converted into columns of a dataframe. Used to be support calls to plotting function that elide a dataframe @@ -772,7 +779,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): df = pd.DataFrame() # Retrieve labels (to change column names) - labels = args.get("labels") # labels or None + # labels = args.get("labels") # labels or None # Valid column names df_columns = ( @@ -817,7 +824,6 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" - col_name = labels[field] if labels and labels.get(field) else col_name try: df.insert(0, col_name, argument) except ValueError: # if col named index already exists, replace @@ -827,7 +833,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): try: col_name = argument.name # pandas df except AttributeError: - col_name = labels[field] if labels and labels.get(field) else field + col_name = field df[col_name] = argument # Update argument with column name now that column exists if field_name not in array_attrables: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 0107a1f3614..c254e675d65 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -116,17 +116,6 @@ def test_build_df_from_lists(): out.pop("data_frame") assert out == output - # Lists, changing one label - labels = {"x": "time"} - args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9], labels=labels) - output = {key: key for key in args} - output.update(labels) - args_wo_labels = args.copy() - _ = args_wo_labels.pop("labels") - df = pd.DataFrame(args_wo_labels).rename(columns=labels) - out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) - assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) - def test_build_df_with_index(): tips = px.data.tips() From 2a6ff71b4c8dd915280d2f863b20d5d32ba64184 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 08:43:03 -0400 Subject: [PATCH 15/69] better handling of dimensions --- .../python/plotly/plotly/express/_core.py | 20 +++++++++---------- .../tests/test_core/test_px/test_px_input.py | 6 +++--- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index c273c6a48e9..4555d3652b8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,16 +754,13 @@ def apply_default_cascade(args): args["marginal_x"] = None -def build_or_augment_dataframe(args, attrables, array_attrables, constructor): +def build_or_augment_dataframe(args, attrables, array_attrables): """ Constructs an implicit dataframe and modifies `args` in-place. Parameters ---------- args : OrderedDict - - constructor : go Trace object - trace function. It is used to fine-tune some options. `attrables` is a list of keys into `args`, all of whose corresponding values are converted into columns of a dataframe. Used to be support calls to plotting function that elide a dataframe @@ -773,13 +770,14 @@ def build_or_augment_dataframe(args, attrables, array_attrables, constructor): # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given - if constructor in [go.Splom, go.Parcats, go.Parcoords]: # we take all dimensions - df = args["data_frame"] - else: - df = pd.DataFrame() + df = pd.DataFrame() - # Retrieve labels (to change column names) - # labels = args.get("labels") # labels or None + if "dimensions" in args and args["dimensions"] is None: + if args.get("data_frame") is None or args["data_frame"] is None: + raise ValueError("No data were provided") + else: + df_args = args["data_frame"] + df[df_args.columns] = df_args[df_args.columns] # Valid column names df_columns = ( @@ -857,7 +855,7 @@ def infer_config(args, constructor, trace_patch): group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] all_attrables = attrables + group_attrables + ["color"] - build_or_augment_dataframe(args, all_attrables, array_attrables, constructor) + build_or_augment_dataframe(args, all_attrables, array_attrables) attrs = [k for k in attrables if k in args] grouped_attrs = [] diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index c254e675d65..21b3daf12a1 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -102,7 +102,7 @@ def test_build_df_from_lists(): args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) output = {key: key for key in args} df = pd.DataFrame(args) - out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + out = build_or_augment_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -111,7 +111,7 @@ def test_build_df_from_lists(): args = dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9]) output = {key: key for key in args} df = pd.DataFrame(args) - out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + out = build_or_augment_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -121,7 +121,7 @@ def test_build_df_with_index(): tips = px.data.tips() args = dict(data_frame=tips, x=tips.index, y="total_bill") changed_output = dict(x="index") - out = build_or_augment_dataframe(args, all_attrables, array_attrables, go.Scatter) + out = build_or_augment_dataframe(args, all_attrables, array_attrables) assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) out.pop("data_frame") assert out == args From 72b5d1c9d0d53d82165eb5024de478419a0d6a4c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 09:53:35 -0400 Subject: [PATCH 16/69] corner case when column was modified --- packages/python/plotly/plotly/express/_core.py | 17 +++++++++++++++-- .../tests/test_core/test_px/test_px_input.py | 4 ++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 4555d3652b8..5e3d58b68be 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -774,7 +774,9 @@ def build_or_augment_dataframe(args, attrables, array_attrables): if "dimensions" in args and args["dimensions"] is None: if args.get("data_frame") is None or args["data_frame"] is None: - raise ValueError("No data were provided") + raise ValueError( + "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." + ) else: df_args = args["data_frame"] df[df_args.columns] = df_args[df_args.columns] @@ -830,7 +832,18 @@ def build_or_augment_dataframe(args, attrables, array_attrables): else: try: col_name = argument.name # pandas df - except AttributeError: + if ( + args.get("data_frame") is not None + and col_name in args["data_frame"] + ): + # If the name exists but the values have changed + # we do not want to keep the name, revert to field + col_name = ( + col_name + if args["data_frame"][col_name].equals(argument) + else field + ) + except AttributeError: # numpy array, list... col_name = field df[col_name] = argument # Update argument with column name now that column exists diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 21b3daf12a1..e7ee3da903c 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -37,9 +37,9 @@ def test_with_index(): fig = px.scatter(tips, x=tips.index, y=tips.total_bill) fig = px.scatter(tips, x=tips.index, y=tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" - # I was not expecting this to work but it does... + # If we tinker with the column then the name is the one of the kw argument fig = px.scatter(tips, x=tips.index, y=10 * tips.total_bill) - assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + assert fig.data[0]["hovertemplate"] == "index=%{x}
y=%{y}" fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"}) assert fig.data[0]["hovertemplate"] == "number=%{x}
total_bill=%{y}" # We do not allow "x=index" From 19431dde33a8537014c6068402f9864ae5528a23 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 10:13:12 -0400 Subject: [PATCH 17/69] modified docs --- packages/python/plotly/plotly/express/_doc.py | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index fbefe4e3860..a914aa2ae9c 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -1,7 +1,7 @@ import inspect -colref = "(string: name of column in `data_frame`)" -colref_list = "(list of string: names of columns in `data_frame`)" +colref = "(string: name of column in `data_frame`, or array_like object)" +colref_list = "(list of string: names of columns in `data_frame`, or array_like objects)" # TODO contents of columns # TODO explain categorical @@ -15,50 +15,50 @@ data_frame=["A 'tidy' `pandas.DataFrame`"], x=[ colref, - "Values from this column are used to position marks along the x axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", "For horizontal `histogram`s, these values are used as inputs to `histfunc`.", ], y=[ colref, - "Values from this column are used to position marks along the y axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.", "For vertical `histogram`s, these values are used as inputs to `histfunc`.", ], z=[ colref, - "Values from this column are used to position marks along the z axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.", "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", ], a=[ colref, - "Values from this column are used to position marks along the a axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the a axis in ternary coordinates.", ], b=[ colref, - "Values from this column are used to position marks along the b axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the b axis in ternary coordinates.", ], c=[ colref, - "Values from this column are used to position marks along the c axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the c axis in ternary coordinates.", ], r=[ colref, - "Values from this column are used to position marks along the radial axis in polar coordinates.", + "Values from this column or array_like are used to position marks along the radial axis in polar coordinates.", ], theta=[ colref, - "Values from this column are used to position marks along the angular axis in polar coordinates.", + "Values from this column or array_like are used to position marks along the angular axis in polar coordinates.", ], lat=[ colref, - "Values from this column are used to position marks according to latitude on a map.", + "Values from this column or array_like are used to position marks according to latitude on a map.", ], lon=[ colref, - "Values from this column are used to position marks according to longitude on a map.", + "Values from this column or array_like are used to position marks according to longitude on a map.", ], locations=[ colref, - "Values from this column are be interpreted according to `locationmode` and mapped to longitude/latitude.", + "Values from this column or array_like are be interpreted according to `locationmode` and mapped to longitude/latitude.", ], dimensions=[ "(list of strings, names of columns in `data_frame`)", @@ -66,47 +66,47 @@ ], error_x=[ colref, - "Values from this column are used to size x-axis error bars.", + "Values from this column or array_like are used to size x-axis error bars.", "If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.", ], error_x_minus=[ colref, - "Values from this column are used to size x-axis error bars in the negative direction.", + "Values from this column or array_like are used to size x-axis error bars in the negative direction.", "Ignored if `error_x` is `None`.", ], error_y=[ colref, - "Values from this column are used to size y-axis error bars.", + "Values from this column or array_like are used to size y-axis error bars.", "If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.", ], error_y_minus=[ colref, - "Values from this column are used to size y-axis error bars in the negative direction.", + "Values from this column or array_like are used to size y-axis error bars in the negative direction.", "Ignored if `error_y` is `None`.", ], error_z=[ colref, - "Values from this column are used to size z-axis error bars.", + "Values from this column or array_like are used to size z-axis error bars.", "If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.", ], error_z_minus=[ colref, - "Values from this column are used to size z-axis error bars in the negative direction.", + "Values from this column or array_like are used to size z-axis error bars in the negative direction.", "Ignored if `error_z` is `None`.", ], - color=[colref, "Values from this column are used to assign color to marks."], + color=[colref, "Values from this column or array_like are used to assign color to marks."], opacity=["(number, between 0 and 1) Sets the opacity for markers."], line_dash=[ colref, - "Values from this column are used to assign dash-patterns to lines.", + "Values from this column or array_like are used to assign dash-patterns to lines.", ], line_group=[ colref, - "Values from this column are used to group rows of `data_frame` into lines.", + "Values from this column or array_like are used to group rows of `data_frame` into lines.", ], - symbol=[colref, "Values from this column are used to assign symbols to marks."], - size=[colref, "Values from this column are used to assign mark sizes."], - hover_name=[colref, "Values from this column appear in bold in the hover tooltip."], + symbol=[colref, "Values from this column or array_like are used to assign symbols to marks."], + size=[colref, "Values from this column or array_like are used to assign mark sizes."], + hover_name=[colref, "Values from this column or array_like appear in bold in the hover tooltip."], hover_data=[ colref_list, "Values from these columns appear as extra data in the hover tooltip.", @@ -115,26 +115,26 @@ colref_list, "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)", ], - text=[colref, "Values from this column appear in the figure as text labels."], + text=[colref, "Values from this column or array_like appear in the figure as text labels."], locationmode=[ "(string, one of 'ISO-3', 'USA-states', 'country names')", "Determines the set of locations used to match entries in `locations` to regions on the map.", ], facet_row=[ colref, - "Values from this column are used to assign marks to facetted subplots in the vertical direction.", + "Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.", ], facet_col=[ colref, - "Values from this column are used to assign marks to facetted subplots in the horizontal direction.", + "Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.", ], animation_frame=[ colref, - "Values from this column are used to assign marks to animation frames.", + "Values from this column or array_like are used to assign marks to animation frames.", ], animation_group=[ colref, - "Values from this column are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", + "Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", ], symbol_sequence=[ "(list of strings defining plotly.js symbols)", From 7297a7f81fe91ff07021e61f1a14a483237e3109 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 10:15:15 -0400 Subject: [PATCH 18/69] black --- packages/python/plotly/plotly/express/_doc.py | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index a914aa2ae9c..44abcfe6478 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -1,7 +1,9 @@ import inspect colref = "(string: name of column in `data_frame`, or array_like object)" -colref_list = "(list of string: names of columns in `data_frame`, or array_like objects)" +colref_list = ( + "(list of string: names of columns in `data_frame`, or array_like objects)" +) # TODO contents of columns # TODO explain categorical @@ -94,7 +96,10 @@ "Values from this column or array_like are used to size z-axis error bars in the negative direction.", "Ignored if `error_z` is `None`.", ], - color=[colref, "Values from this column or array_like are used to assign color to marks."], + color=[ + colref, + "Values from this column or array_like are used to assign color to marks.", + ], opacity=["(number, between 0 and 1) Sets the opacity for markers."], line_dash=[ colref, @@ -104,9 +109,18 @@ colref, "Values from this column or array_like are used to group rows of `data_frame` into lines.", ], - symbol=[colref, "Values from this column or array_like are used to assign symbols to marks."], - size=[colref, "Values from this column or array_like are used to assign mark sizes."], - hover_name=[colref, "Values from this column or array_like appear in bold in the hover tooltip."], + symbol=[ + colref, + "Values from this column or array_like are used to assign symbols to marks.", + ], + size=[ + colref, + "Values from this column or array_like are used to assign mark sizes.", + ], + hover_name=[ + colref, + "Values from this column or array_like appear in bold in the hover tooltip.", + ], hover_data=[ colref_list, "Values from these columns appear as extra data in the hover tooltip.", @@ -115,7 +129,10 @@ colref_list, "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)", ], - text=[colref, "Values from this column or array_like appear in the figure as text labels."], + text=[ + colref, + "Values from this column or array_like appear in the figure as text labels.", + ], locationmode=[ "(string, one of 'ISO-3', 'USA-states', 'country names')", "Determines the set of locations used to match entries in `locations` to regions on the map.", From dadd64555f8764241eb76f4cc0a67e412efb8db8 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 11:13:17 -0400 Subject: [PATCH 19/69] case when no df provided and str arguments --- .../python/plotly/plotly/express/_core.py | 21 ++++++++++++++----- .../tests/test_core/test_px/test_px_input.py | 8 +++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 5e3d58b68be..c6463d3417b 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -756,15 +756,21 @@ def apply_default_cascade(args): def build_or_augment_dataframe(args, attrables, array_attrables): """ - Constructs an implicit dataframe and modifies `args` in-place. + Constructs a dataframe and modifies `args` in-place. + + The argument values in `args` can be either strings corresponding to + existing columns of a dataframe, or data arrays (lists, numpy arrays, + pandas columns, series). Parameters ---------- args : OrderedDict - `attrables` is a list of keys into `args`, all of whose corresponding - values are converted into columns of a dataframe. - Used to be support calls to plotting function that elide a dataframe - argument; for example `scatter(x=[1,2], y=[3,4])`. + arguments passed to the px function and subsequently modified + attrables : list + list of keys into `args`, all of whose corresponding values are + converted into columns of a dataframe. + array_attrables : list + argument names corresponding to iterables, such as `hover_data`, ... """ # We start from an empty DataFrame except for the case of functions which @@ -821,6 +827,11 @@ def build_or_augment_dataframe(args, attrables, array_attrables): "Expected one of %s but received: %s" % (field, str(list(df_columns)), argument) ) + except TypeError: + raise ValueError( + "String arguments are only possible when a DataFrame" + "is provided in the `data_frame` argument." + ) # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index e7ee3da903c..376621cde8e 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -84,6 +84,14 @@ def test_wrong_column_name(): fig = px.scatter(px.data.tips(), x="bla", y="wrong") +def test_missing_data_frame(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(x="arg1", y="arg2") + assert "String arguments are only possible when a DataFrame" in str( + err_msg.value + ) + + def test_wrong_dimensions_of_array(): with pytest.raises(ValueError) as err_msg: fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5]) From 2cb9dba08f39563bb57136e7a6ca28fb01c801d2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 11:15:10 -0400 Subject: [PATCH 20/69] argument=None --- packages/python/plotly/plotly/express/_core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index c6463d3417b..19905f6c47d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -803,7 +803,8 @@ def build_or_augment_dataframe(args, attrables, array_attrables): if field_name not in array_attrables else args.get(field_name) ) - if argument_list is None: # argument not specified, continue + # argument not specified, continue + if argument_list is None or argument_list is [None]: continue # Argument name: field_name if the argument is a list # Else we give names like ["hover_data_0, hover_data_1"] etc. From b09dc58df012043ad999d7d28389562786293dde Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 12:11:26 -0400 Subject: [PATCH 21/69] better handling of problematic case --- packages/python/plotly/plotly/express/_core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 19905f6c47d..0d76c9f0daf 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -818,6 +818,11 @@ def build_or_augment_dataframe(args, attrables, array_attrables): if argument is None: continue elif isinstance(argument, str): # just a column name + if not isinstance(args.get("data_frame"), pd.DataFrame): + raise ValueError( + "String arguments are only possible when a DataFrame" + "is provided in the `data_frame` argument." + ) # Check validity of column name try: df[argument] = args["data_frame"][argument] @@ -828,11 +833,6 @@ def build_or_augment_dataframe(args, attrables, array_attrables): "Expected one of %s but received: %s" % (field, str(list(df_columns)), argument) ) - except TypeError: - raise ValueError( - "String arguments are only possible when a DataFrame" - "is provided in the `data_frame` argument." - ) # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" From 91e6cda3404d2948f56cae3d6e18c9b433005622 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 13:01:00 -0400 Subject: [PATCH 22/69] simplify code --- packages/python/plotly/plotly/express/_core.py | 6 ++---- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 2 ++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0d76c9f0daf..04e6d4e32b4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -779,7 +779,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df = pd.DataFrame() if "dimensions" in args and args["dimensions"] is None: - if args.get("data_frame") is None or args["data_frame"] is None: + if args["data_frame"] is None: raise ValueError( "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." ) @@ -788,9 +788,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df[df_args.columns] = df_args[df_args.columns] # Valid column names - df_columns = ( - args["data_frame"].columns if args.get("data_frame") is not None else None - ) + df_columns = args["data_frame"].columns if args["data_frame"] is not None else None group_attrs = ["symbol", "line_dash"] for group_attr in group_attrs: if group_attr in args: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 376621cde8e..7ba5e88e8e6 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -110,6 +110,7 @@ def test_build_df_from_lists(): args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) output = {key: key for key in args} df = pd.DataFrame(args) + args["data_frame"] = None out = build_or_augment_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") @@ -119,6 +120,7 @@ def test_build_df_from_lists(): args = dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9]) output = {key: key for key in args} df = pd.DataFrame(args) + args["data_frame"] = None out = build_or_augment_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") From 33db488e61cf8af8caea69d25c527fde406ba418 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 13:27:34 -0400 Subject: [PATCH 23/69] symbol and line_dash --- packages/python/plotly/plotly/express/_core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 04e6d4e32b4..de3daa698f9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -764,7 +764,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): Parameters ---------- - args : OrderedDict +placebo args : OrderedDict arguments passed to the px function and subsequently modified attrables : list list of keys into `args`, all of whose corresponding values are @@ -789,11 +789,6 @@ def build_or_augment_dataframe(args, attrables, array_attrables): # Valid column names df_columns = args["data_frame"].columns if args["data_frame"] is not None else None - group_attrs = ["symbol", "line_dash"] - for group_attr in group_attrs: - if group_attr in args: - attrables += [group_attr] - # Loop over possible arguments for field_name in attrables: argument_list = ( @@ -876,8 +871,12 @@ def infer_config(args, constructor, trace_patch): ) array_attrables = ["dimensions", "custom_data", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] - all_attrables = attrables + group_attrables + ["color"] + group_attrs = ["symbol", "line_dash"] + for group_attr in group_attrs: + if group_attr in args: + all_attrables += [group_attr] + build_or_augment_dataframe(args, all_attrables, array_attrables) attrs = [k for k in attrables if k in args] From e58c82e71716c64586b1e2480df5d6c4e3143eca Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 13:47:32 -0400 Subject: [PATCH 24/69] stricter check on column --- packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index de3daa698f9..eed0cb5cd0e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -764,7 +764,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): Parameters ---------- -placebo args : OrderedDict + args : OrderedDict arguments passed to the px function and subsequently modified attrables : list list of keys into `args`, all of whose corresponding values are @@ -845,7 +845,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): # we do not want to keep the name, revert to field col_name = ( col_name - if args["data_frame"][col_name].equals(argument) + if argument is args["data_frame"][col_name] else field ) except AttributeError: # numpy array, list... From 9d65a07dc56d483bc3d81915dab36bffa404a120 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 14:40:36 -0400 Subject: [PATCH 25/69] more readable error message --- packages/python/plotly/plotly/express/_core.py | 8 ++++---- .../plotly/tests/test_core/test_px/test_px_input.py | 5 ++++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index eed0cb5cd0e..4553acd73db 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -817,15 +817,15 @@ def build_or_augment_dataframe(args, attrables, array_attrables): "is provided in the `data_frame` argument." ) # Check validity of column name - try: - df[argument] = args["data_frame"][argument] - continue - except KeyError: + if argument not in df_columns: raise ValueError( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" % (field, str(list(df_columns)), argument) ) + + df[argument] = args["data_frame"][argument] + continue # Case of index elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 7ba5e88e8e6..51b9340ad40 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -80,8 +80,11 @@ def test_arrayattrable_numpy(): def test_wrong_column_name(): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as err_msg: fig = px.scatter(px.data.tips(), x="bla", y="wrong") + assert "Value of 'x' is not the name of a column in 'data_frame'" in str( + err_msg.value + ) def test_missing_data_frame(): From 64010dedf4ce006a980f5d55c6c186786394e205 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 15:43:25 -0400 Subject: [PATCH 26/69] check length --- .../python/plotly/plotly/express/_core.py | 6 +++++- .../tests/test_core/test_px/test_px_input.py | 19 ++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 4553acd73db..7e09319c08d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -808,6 +808,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): ) # argument_list and field_list ready, iterate over them for i, (argument, field) in enumerate(zip(argument_list, field_list)): + length = len(df) if argument is None: continue elif isinstance(argument, str): # just a column name @@ -823,7 +824,8 @@ def build_or_augment_dataframe(args, attrables, array_attrables): "Expected one of %s but received: %s" % (field, str(list(df_columns)), argument) ) - + if length and len(args["data_frame"][argument]) != length: + raise ValueError("All arguments should have the same length.") df[argument] = args["data_frame"][argument] continue # Case of index @@ -850,6 +852,8 @@ def build_or_augment_dataframe(args, attrables, array_attrables): ) except AttributeError: # numpy array, list... col_name = field + if length and len(argument) != length: + raise ValueError("All arguments should have the same length.") df[col_name] = argument # Update argument with column name now that column exists if field_name not in array_attrables: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 51b9340ad40..39b1ae01049 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -98,14 +98,27 @@ def test_missing_data_frame(): def test_wrong_dimensions_of_array(): with pytest.raises(ValueError) as err_msg: fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5]) - assert "Length of values does not match length of index" in str(err_msg.value) + assert "All arguments should have the same length." in str(err_msg.value) -def test_wrong_dimensions_mixed_cqse(): +def test_wrong_dimensions_mixed_case(): with pytest.raises(ValueError) as err_msg: df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5]) - assert "Length of values does not match length of index" in str(err_msg.value) + assert "All arguments should have the same length." in str(err_msg.value) + + +def test_wrong_dimensions(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(px.data.tips(), x="tip", y=[1, 2, 3]) + assert "All arguments should have the same length." in str(err_msg.value) + # the order matters + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(px.data.tips(), x=[1, 2, 3], y="tip") + assert "All arguments should have the same length." in str(err_msg.value) + with pytest.raises(ValueError): + fig = px.scatter(px.data.tips(), x=px.data.iris().index, y="tip") + # assert "All arguments should have the same length." in str(err_msg.value) def test_build_df_from_lists(): From 20bd812b88417d9da880db4d77b44413824dd877 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 17 Sep 2019 16:42:57 -0400 Subject: [PATCH 27/69] error when multiindex --- packages/python/plotly/plotly/express/_core.py | 2 ++ .../tests/test_core/test_px/test_px_input.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 7e09319c08d..a9cce392105 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -829,6 +829,8 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df[argument] = args["data_frame"][argument] continue # Case of index + elif isinstance(argument, pd.core.indexes.multi.MultiIndex): + raise TypeError("pandas MultiIndex is not supported by plotly express") elif isinstance(argument, pd.core.indexes.range.RangeIndex): col_name = argument.name if argument.name else "index" try: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 39b1ae01049..e9cab964a52 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -121,6 +121,20 @@ def test_wrong_dimensions(): # assert "All arguments should have the same length." in str(err_msg.value) +def test_multiindex_raise_error(): + index = pd.MultiIndex.from_product( + [[1, 2, 3], ["a", "b"]], names=["first", "second"] + ) + df = pd.DataFrame(np.random.random((6, 3)), index=index, columns=["A", "B", "C"]) + # This is ok + fig = px.scatter(df, x="A", y="B") + with pytest.raises(TypeError) as err_msg: + fig = px.scatter(df, x=df.index, y="B") + assert "pandas MultiIndex is not supported by plotly express" in str( + err_msg.value + ) + + def test_build_df_from_lists(): # Just lists args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) From a7241a5e70e28c3524f4e463e06a29657ef496cb Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Sep 2019 09:40:20 -0400 Subject: [PATCH 28/69] accept int as column names --- .../python/plotly/plotly/express/_core.py | 23 +++++++++++++++---- .../tests/test_core/test_px/test_px_input.py | 11 +++++++++ 2 files changed, 29 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index a9cce392105..26e21b88fed 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -788,7 +788,11 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df[df_args.columns] = df_args[df_args.columns] # Valid column names - df_columns = args["data_frame"].columns if args["data_frame"] is not None else None + df_columns = ( + args["data_frame"].columns + if isinstance(args["data_frame"], pd.DataFrame) + else None + ) # Loop over possible arguments for field_name in attrables: argument_list = ( @@ -811,14 +815,18 @@ def build_or_augment_dataframe(args, attrables, array_attrables): length = len(df) if argument is None: continue - elif isinstance(argument, str): # just a column name - if not isinstance(args.get("data_frame"), pd.DataFrame): + elif isinstance(argument, str) or isinstance( + argument, int + ): # just a column name + if not isinstance( + args.get("data_frame"), pd.DataFrame + ) and not isinstance(args.get("data_frame"), np.ndarray): raise ValueError( "String arguments are only possible when a DataFrame" "is provided in the `data_frame` argument." ) # Check validity of column name - if argument not in df_columns: + if df_columns is not None and argument not in df_columns: raise ValueError( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" @@ -826,7 +834,12 @@ def build_or_augment_dataframe(args, attrables, array_attrables): ) if length and len(args["data_frame"][argument]) != length: raise ValueError("All arguments should have the same length.") - df[argument] = args["data_frame"][argument] + df[str(argument)] = args["data_frame"][argument] + if isinstance(argument, int): + if field_name not in array_attrables: + args[field_name] = str(argument) + else: + args[field_name][i] = str(argument) continue # Case of index elif isinstance(argument, pd.core.indexes.multi.MultiIndex): diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index e9cab964a52..f0e5def4f09 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -171,3 +171,14 @@ def test_splom_case(): iris = px.data.iris() fig = px.scatter_matrix(iris) assert len(fig.data[0].dimensions) == len(iris.columns) + + +def test_int_col_names(): + # DataFrame with int column names + lengths = pd.DataFrame(np.random.random(100)) + fig = px.histogram(lengths, x=0) + assert np.all(np.array(lengths).flatten() == fig.data[0].x) + # Numpy array + ar = np.arange(100).reshape((10, 10)) + fig = px.scatter(ar, x=2, y=8) + assert np.all(fig.data[0].x == ar[2]) From dc1fc0a4390270795d1c36002a44eac1ad7f7bf1 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Sep 2019 10:18:48 -0400 Subject: [PATCH 29/69] array and dict as dataframe --- packages/python/plotly/plotly/express/_core.py | 11 +++++++++-- .../plotly/tests/test_core/test_px/test_px_input.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 26e21b88fed..9eba41505b3 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -787,6 +787,12 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df_args = args["data_frame"] df[df_args.columns] = df_args[df_args.columns] + # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) + if args["data_frame"] is not None and not isinstance( + args["data_frame"], pd.DataFrame + ): + args["data_frame"] = pd.DataFrame(args["data_frame"]) + # Valid column names df_columns = ( args["data_frame"].columns @@ -822,8 +828,9 @@ def build_or_augment_dataframe(args, attrables, array_attrables): args.get("data_frame"), pd.DataFrame ) and not isinstance(args.get("data_frame"), np.ndarray): raise ValueError( - "String arguments are only possible when a DataFrame" - "is provided in the `data_frame` argument." + "String or int arguments are only possible when a" + "DataFrame or an array is provided in the `data_frame`" + "argument." ) # Check validity of column name if df_columns is not None and argument not in df_columns: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index f0e5def4f09..f7c456f63f4 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -90,9 +90,7 @@ def test_wrong_column_name(): def test_missing_data_frame(): with pytest.raises(ValueError) as err_msg: fig = px.scatter(x="arg1", y="arg2") - assert "String arguments are only possible when a DataFrame" in str( - err_msg.value - ) + assert "String or int arguments are only possible" in str(err_msg.value) def test_wrong_dimensions_of_array(): @@ -181,4 +179,10 @@ def test_int_col_names(): # Numpy array ar = np.arange(100).reshape((10, 10)) fig = px.scatter(ar, x=2, y=8) - assert np.all(fig.data[0].x == ar[2]) + assert np.all(fig.data[0].x == ar[:, 2]) + + +def test_data_frame_from_dict(): + fig = px.scatter({"time": [0, 1], "money": [1, 2]}, x="time", y="money") + assert fig.data[0].hovertemplate == "time=%{x}
money=%{y}" + assert np.all(fig.data[0].x == [0, 1]) From 61678c23f2ecb5ef977c56b72b61baa9b06749a2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Sep 2019 15:45:10 -0400 Subject: [PATCH 30/69] splom case --- packages/python/plotly/plotly/express/_core.py | 12 ++++++------ .../plotly/tests/test_core/test_px/test_px_input.py | 6 ++++++ 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 9eba41505b3..e4d3bc85640 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -772,6 +772,12 @@ def build_or_augment_dataframe(args, attrables, array_attrables): array_attrables : list argument names corresponding to iterables, such as `hover_data`, ... """ + # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) + if args["data_frame"] is not None and not isinstance( + args["data_frame"], pd.DataFrame + ): + args["data_frame"] = pd.DataFrame(args["data_frame"]) + # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords @@ -787,12 +793,6 @@ def build_or_augment_dataframe(args, attrables, array_attrables): df_args = args["data_frame"] df[df_args.columns] = df_args[df_args.columns] - # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) - if args["data_frame"] is not None and not isinstance( - args["data_frame"], pd.DataFrame - ): - args["data_frame"] = pd.DataFrame(args["data_frame"]) - # Valid column names df_columns = ( args["data_frame"].columns diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index f7c456f63f4..969ceb3c68b 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -169,6 +169,12 @@ def test_splom_case(): iris = px.data.iris() fig = px.scatter_matrix(iris) assert len(fig.data[0].dimensions) == len(iris.columns) + dic = {'a':[1, 2, 3], 'b':[4, 5, 6], 'c':[7, 8, 9]} + fig = px.scatter_matrix(dic) + assert np.all(fig.data[0].dimensions[0].values == np.array(dic['a'])) + ar = np.arange(9).reshape((3, 3)) + fig = px.scatter_matrix(ar) + assert np.all(fig.data[0].dimensions[0].values == ar[:, 0]) def test_int_col_names(): From 71756a3808e76df916fb113467058ea5eae74143 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Sep 2019 15:54:00 -0400 Subject: [PATCH 31/69] simplified index case --- packages/python/plotly/plotly/express/_core.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e4d3bc85640..0d9dfeb35c9 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -778,7 +778,6 @@ def build_or_augment_dataframe(args, attrables, array_attrables): ): args["data_frame"] = pd.DataFrame(args["data_frame"]) - # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given @@ -851,16 +850,14 @@ def build_or_augment_dataframe(args, attrables, array_attrables): # Case of index elif isinstance(argument, pd.core.indexes.multi.MultiIndex): raise TypeError("pandas MultiIndex is not supported by plotly express") - elif isinstance(argument, pd.core.indexes.range.RangeIndex): - col_name = argument.name if argument.name else "index" - try: - df.insert(0, col_name, argument) - except ValueError: # if col named index already exists, replace - df[col_name] = argument # Case of numpy array or df column else: try: col_name = argument.name # pandas df + if col_name is None and isinstance( + argument, pd.core.indexes.range.RangeIndex + ): + col_name = "index" if ( args.get("data_frame") is not None and col_name in args["data_frame"] From 40ba5b91b59e05da2258eb9a3e8ac5a5a5a1bc9d Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 18 Sep 2019 15:56:51 -0400 Subject: [PATCH 32/69] black --- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 969ceb3c68b..e00de7d9848 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -169,9 +169,9 @@ def test_splom_case(): iris = px.data.iris() fig = px.scatter_matrix(iris) assert len(fig.data[0].dimensions) == len(iris.columns) - dic = {'a':[1, 2, 3], 'b':[4, 5, 6], 'c':[7, 8, 9]} + dic = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} fig = px.scatter_matrix(dic) - assert np.all(fig.data[0].dimensions[0].values == np.array(dic['a'])) + assert np.all(fig.data[0].dimensions[0].values == np.array(dic["a"])) ar = np.arange(9).reshape((3, 3)) fig = px.scatter_matrix(ar) assert np.all(fig.data[0].dimensions[0].values == ar[:, 0]) From 749db6ce03c6dec90113d62e716bbd392137f596 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 19 Sep 2019 09:21:57 -0400 Subject: [PATCH 33/69] better error messages --- .../python/plotly/plotly/express/_core.py | 30 +++++++++++++------ 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0d9dfeb35c9..9dacab35d13 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -823,13 +823,12 @@ def build_or_augment_dataframe(args, attrables, array_attrables): elif isinstance(argument, str) or isinstance( argument, int ): # just a column name - if not isinstance( - args.get("data_frame"), pd.DataFrame - ) and not isinstance(args.get("data_frame"), np.ndarray): + if not isinstance(args.get("data_frame"), pd.DataFrame): raise ValueError( "String or int arguments are only possible when a" "DataFrame or an array is provided in the `data_frame`" - "argument." + "argument. No DataFrame was provided, but argument '%s'" + "is of type str or int." % field ) # Check validity of column name if df_columns is not None and argument not in df_columns: @@ -839,7 +838,12 @@ def build_or_augment_dataframe(args, attrables, array_attrables): % (field, str(list(df_columns)), argument) ) if length and len(args["data_frame"][argument]) != length: - raise ValueError("All arguments should have the same length.") + raise ValueError( + "All arguments should have the same length." + "The length of column argument `df[%s]` is %d, whereas the" + "length of previous arguments is %d" + % (field, len(args["data_frame"][argument]), length) + ) df[str(argument)] = args["data_frame"][argument] if isinstance(argument, int): if field_name not in array_attrables: @@ -849,10 +853,13 @@ def build_or_augment_dataframe(args, attrables, array_attrables): continue # Case of index elif isinstance(argument, pd.core.indexes.multi.MultiIndex): - raise TypeError("pandas MultiIndex is not supported by plotly express") + raise TypeError( + "Argument '%s' is a pandas MultiIndex." + "pandas MultiIndex is not supported by plotly express" % field + ) # Case of numpy array or df column else: - try: + if hasattr(argument, "name"): col_name = argument.name # pandas df if col_name is None and isinstance( argument, pd.core.indexes.range.RangeIndex @@ -869,10 +876,15 @@ def build_or_augment_dataframe(args, attrables, array_attrables): if argument is args["data_frame"][col_name] else field ) - except AttributeError: # numpy array, list... + else: # numpy array, list... col_name = field if length and len(argument) != length: - raise ValueError("All arguments should have the same length.") + raise ValueError( + "All arguments should have the same length." + "The length of argument `%s` is %d, whereas the" + "length of previous arguments is %d" + % (field, len(argument), length) + ) df[col_name] = argument # Update argument with column name now that column exists if field_name not in array_attrables: From cd8574a1ee81961bba6922fbfbaa424c7cd60949 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 19 Sep 2019 16:37:04 -0400 Subject: [PATCH 34/69] better handling of Series (name is None) --- packages/python/plotly/plotly/express/_core.py | 10 ++++++---- .../plotly/tests/test_core/test_px/test_px_input.py | 7 +++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 9dacab35d13..d7f98d0c30e 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -859,11 +859,13 @@ def build_or_augment_dataframe(args, attrables, array_attrables): ) # Case of numpy array or df column else: - if hasattr(argument, "name"): + is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) + # pandas series have a name but it's None + if ( + hasattr(argument, "name") and argument.name is not None + ) or is_index: col_name = argument.name # pandas df - if col_name is None and isinstance( - argument, pd.core.indexes.range.RangeIndex - ): + if col_name is None and is_index: col_name = "index" if ( args.get("data_frame") is not None diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index e00de7d9848..d5f685d0fd5 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -51,6 +51,13 @@ def test_with_index(): ) +def test_pandas_series(): + tips = px.data.tips() + before_tip = tips.total_bill - tips.tip + fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) + assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" + + def test_mixed_case(): df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) From 299fddeeb588baeffc96707db6a1a14b93183bf2 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 19 Sep 2019 18:11:43 -0400 Subject: [PATCH 35/69] handle case of name conflicts --- .../python/plotly/plotly/express/_core.py | 32 +++++++++++++++---- .../tests/test_core/test_px/test_px_input.py | 19 ++++++++--- 2 files changed, 41 insertions(+), 10 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d7f98d0c30e..0aa8ca9379d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,7 +754,23 @@ def apply_default_cascade(args): args["marginal_x"] = None -def build_or_augment_dataframe(args, attrables, array_attrables): +def _name_heuristic(argument, field_name, df): + print (argument, field_name, df.columns) + if isinstance(argument, int): + argument = str(argument) + if argument not in df.columns: + print ("no pb", argument) + return argument + elif field_name not in df.columns: + print ("fallback field", field_name) + return field_name + elif field_name + argument not in df.columns: + return field_name + "_" + argument + else: + raise NameError("A name conflict was encountered.") + + +def build_dataframe(args, attrables, array_attrables): """ Constructs a dataframe and modifies `args` in-place. @@ -844,13 +860,14 @@ def build_or_augment_dataframe(args, attrables, array_attrables): "length of previous arguments is %d" % (field, len(args["data_frame"][argument]), length) ) - df[str(argument)] = args["data_frame"][argument] + col_name = _name_heuristic(argument, field_name, df) + df[col_name] = args["data_frame"][argument] if isinstance(argument, int): if field_name not in array_attrables: - args[field_name] = str(argument) + args[field_name] = col_name else: - args[field_name][i] = str(argument) - continue + args[field_name][i] = col_name + # continue # Case of index elif isinstance(argument, pd.core.indexes.multi.MultiIndex): raise TypeError( @@ -878,8 +895,10 @@ def build_or_augment_dataframe(args, attrables, array_attrables): if argument is args["data_frame"][col_name] else field ) + col_name = _name_heuristic(col_name, field_name, df) else: # numpy array, list... col_name = field + # col_name = _name_heuristic(field, field, df) if length and len(argument) != length: raise ValueError( "All arguments should have the same length." @@ -894,6 +913,7 @@ def build_or_augment_dataframe(args, attrables, array_attrables): else: args[field_name][i] = col_name args["data_frame"] = df + print (df) return args @@ -914,7 +934,7 @@ def infer_config(args, constructor, trace_patch): if group_attr in args: all_attrables += [group_attr] - build_or_augment_dataframe(args, all_attrables, array_attrables) + build_dataframe(args, all_attrables, array_attrables) attrs = [k for k in attrables if k in args] grouped_attrs = [] diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index d5f685d0fd5..11f9bbb2a0b 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -4,7 +4,7 @@ import pytest import plotly.graph_objects as go import plotly -from plotly.express._core import build_or_augment_dataframe +from plotly.express._core import build_dataframe from pandas.util.testing import assert_frame_equal attrables = ( @@ -58,6 +58,17 @@ def test_pandas_series(): assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" +def test_name_conflict(): + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + fig = px.scatter(df, x=[10, 1], y="y", color="x") + assert np.all(fig.data[0].x == np.array([10, 1])) + fig = px.scatter(df, x=[10, 1], y="y", color=df.x) + assert np.all(fig.data[0].x == np.array([10, 1])) + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], color=[1, 2])) + fig = px.scatter(df, x=[10, 1], y="y", size="color", color=df.x) + assert np.all(fig.data[0].x == np.array([10, 1])) + + def test_mixed_case(): df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) @@ -146,7 +157,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) args["data_frame"] = None - out = build_or_augment_dataframe(args, all_attrables, array_attrables) + out = build_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -156,7 +167,7 @@ def test_build_df_from_lists(): output = {key: key for key in args} df = pd.DataFrame(args) args["data_frame"] = None - out = build_or_augment_dataframe(args, all_attrables, array_attrables) + out = build_dataframe(args, all_attrables, array_attrables) assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) out.pop("data_frame") assert out == output @@ -166,7 +177,7 @@ def test_build_df_with_index(): tips = px.data.tips() args = dict(data_frame=tips, x=tips.index, y="total_bill") changed_output = dict(x="index") - out = build_or_augment_dataframe(args, all_attrables, array_attrables) + out = build_dataframe(args, all_attrables, array_attrables) assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) out.pop("data_frame") assert out == args From 8753797232bc00f14820bbcff3b2b3d7852c6bac Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 19 Sep 2019 22:34:07 -0400 Subject: [PATCH 36/69] removed print statements --- packages/python/plotly/plotly/express/_core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0aa8ca9379d..dd5db380b69 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -755,19 +755,20 @@ def apply_default_cascade(args): def _name_heuristic(argument, field_name, df): - print (argument, field_name, df.columns) if isinstance(argument, int): argument = str(argument) if argument not in df.columns: - print ("no pb", argument) return argument elif field_name not in df.columns: - print ("fallback field", field_name) return field_name elif field_name + argument not in df.columns: return field_name + "_" + argument else: - raise NameError("A name conflict was encountered.") + raise NameError( + "A name conflict was encountered for argument %s." + "Columns with names %s, %s and %s are already used" + % (field_name, argument, field_name, field_name + '_' + argument) + ) def build_dataframe(args, attrables, array_attrables): @@ -913,7 +914,6 @@ def build_dataframe(args, attrables, array_attrables): else: args[field_name][i] = col_name args["data_frame"] = df - print (df) return args From ed24d8bb2deba477cf9a4fee4f3898694a94a93d Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 19 Sep 2019 22:36:37 -0400 Subject: [PATCH 37/69] black --- packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index dd5db380b69..bba84fcefc8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -767,8 +767,8 @@ def _name_heuristic(argument, field_name, df): raise NameError( "A name conflict was encountered for argument %s." "Columns with names %s, %s and %s are already used" - % (field_name, argument, field_name, field_name + '_' + argument) - ) + % (field_name, argument, field_name, field_name + "_" + argument) + ) def build_dataframe(args, attrables, array_attrables): From 2da2b21de2c0370f6ee89ae50f2869a6df4e6860 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 10:36:20 -0400 Subject: [PATCH 38/69] wip --- packages/python/plotly/plotly/express/_core.py | 7 +++++++ .../tests/test_core/test_px/test_px_input.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index bba84fcefc8..77a1add67c5 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -771,6 +771,11 @@ def _name_heuristic(argument, field_name, df): ) +def _get_reserved_names(args, attrables, array_attrables): + df = args['data_frame'] + for field in args: + + def build_dataframe(args, attrables, array_attrables): """ Constructs a dataframe and modifies `args` in-place. @@ -795,6 +800,8 @@ def build_dataframe(args, attrables, array_attrables): ): args["data_frame"] = pd.DataFrame(args["data_frame"]) + if args["data_frame"] is not None: + reserved_names = _get_reserved_names(args, attrables, array_attrables) # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 11f9bbb2a0b..b13b98fa1db 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -62,12 +62,30 @@ def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) fig = px.scatter(df, x=[10, 1], y="y", color="x") assert np.all(fig.data[0].x == np.array([10, 1])) + fig = px.scatter(df, x=[10, 1], y="y", color=df.x) assert np.all(fig.data[0].x == np.array([10, 1])) + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], color=[1, 2])) fig = px.scatter(df, x=[10, 1], y="y", size="color", color=df.x) assert np.all(fig.data[0].x == np.array([10, 1])) + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + fig = px.scatter(x=df.y, y=df2.y) + + +def test_repeated_name(): + iris = px.data.iris() + fig = px.scatter( + iris, + x="sepal_width", + y="sepal_length", + hover_data=["petal_length", "petal_width", "species_id"], + custom_data=["species_id", "species"], + ) + assert fig.data[0].customdata.shape[1] == 4 + def test_mixed_case(): df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) From 102b89fddd4a72142a32df320635a04f38c7a94f Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 16:14:57 -0400 Subject: [PATCH 39/69] name heuristic --- .../python/plotly/plotly/express/_core.py | 88 ++++++++++++++----- .../tests/test_core/test_px/test_px_input.py | 23 +++-- 2 files changed, 81 insertions(+), 30 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 77a1add67c5..85f84374ea8 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -671,7 +671,7 @@ def one_group(x): def apply_default_cascade(args): - # first we apply px.defaults to unspecified args + # https://github.com/plotly/dash-table/issues/597first we apply px.defaults to unspecified args for param in ( ["color_discrete_sequence", "color_continuous_scale"] + ["symbol_sequence", "line_dash_sequence", "template"] @@ -754,14 +754,12 @@ def apply_default_cascade(args): args["marginal_x"] = None -def _name_heuristic(argument, field_name, df): +def _name_heuristic(argument, field_name, reserved_names): if isinstance(argument, int): argument = str(argument) - if argument not in df.columns: - return argument - elif field_name not in df.columns: + elif field_name not in reserved_names: return field_name - elif field_name + argument not in df.columns: + elif field_name + argument not in reserved_names: return field_name + "_" + argument else: raise NameError( @@ -772,8 +770,29 @@ def _name_heuristic(argument, field_name, df): def _get_reserved_names(args, attrables, array_attrables): - df = args['data_frame'] + df = args["data_frame"] + reserved_names = [] for field in args: + if field not in attrables: + continue + names = args[field] if field in array_attrables else [args[field]] + if names is None: + continue + for arg in names: + if arg is None: + continue + if isinstance(arg, str) and arg not in reserved_names: + reserved_names.append(arg) + if isinstance(arg, int) and str(arg) not in reserved_names: + reserved_names.append(str(arg)) + if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): + arg_name = arg.name + if arg_name: + in_df = arg is df[arg_name] + if arg_name not in reserved_names: + reserved_names.append(arg_name) + + return reserved_names def build_dataframe(args, attrables, array_attrables): @@ -802,6 +821,10 @@ def build_dataframe(args, attrables, array_attrables): if args["data_frame"] is not None: reserved_names = _get_reserved_names(args, attrables, array_attrables) + else: + reserved_names = [] + canbechanged_names = {} + forbidden_names = reserved_names.copy() # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given @@ -868,14 +891,13 @@ def build_dataframe(args, attrables, array_attrables): "length of previous arguments is %d" % (field, len(args["data_frame"][argument]), length) ) - col_name = _name_heuristic(argument, field_name, df) - df[col_name] = args["data_frame"][argument] + df[str(argument)] = args["data_frame"][argument] if isinstance(argument, int): if field_name not in array_attrables: - args[field_name] = col_name + args[field_name] = str(argument) else: - args[field_name][i] = col_name - # continue + args[field_name][i] = str(argument) + continue # Case of index elif isinstance(argument, pd.core.indexes.multi.MultiIndex): raise TypeError( @@ -892,21 +914,35 @@ def build_dataframe(args, attrables, array_attrables): col_name = argument.name # pandas df if col_name is None and is_index: col_name = "index" - if ( - args.get("data_frame") is not None - and col_name in args["data_frame"] - ): + # revert previous argument + if col_name in canbechanged_names: + if argument is not df[col_name]: + old_field, old_i = canbechanged_names[col_name] + # old_field_name = old_field + str(i) if else old_field + df.rename(columns={col_name: old_field}, inplace=True) + args[old_field] = old_field + del canbechanged_names[col_name] + reserved_names.remove(col_name) + if col_name in forbidden_names: # If the name exists but the values have changed # we do not want to keep the name, revert to field + name_in_dataframe = ( + args["data_frame"] is not None + and col_name in args["data_frame"].columns + ) + keep_name = ( + (argument is args["data_frame"][col_name]) + if name_in_dataframe + else (col_name in df and argument is df[col_name]) + ) col_name = ( col_name - if argument is args["data_frame"][col_name] - else field + if keep_name + else _name_heuristic(col_name, field, reserved_names) ) - col_name = _name_heuristic(col_name, field_name, df) else: # numpy array, list... - col_name = field - # col_name = _name_heuristic(field, field, df) + # col_name = field + col_name = _name_heuristic(field, field, reserved_names) if length and len(argument) != length: raise ValueError( "All arguments should have the same length." @@ -914,12 +950,16 @@ def build_dataframe(args, attrables, array_attrables): "length of previous arguments is %d" % (field, len(argument), length) ) - df[col_name] = argument + df[str(col_name)] = argument + if col_name not in reserved_names: + reserved_names.append(str(col_name)) + forbidden_names.append(str(col_name)) + canbechanged_names[str(col_name)] = (field_name, i) # Update argument with column name now that column exists if field_name not in array_attrables: - args[field_name] = col_name + args[field_name] = str(col_name) else: - args[field_name][i] = col_name + args[field_name][i] = str(col_name) args["data_frame"] = df return args diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index b13b98fa1db..743df14f9a8 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -73,17 +73,28 @@ def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) fig = px.scatter(x=df.y, y=df2.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}" + + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + df3 = pd.DataFrame(dict(y=[0.1, 0.2])) + fig = px.scatter(x=df.y, y=df2.y, size=df3.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}
size=%{marker.size}" def test_repeated_name(): iris = px.data.iris() fig = px.scatter( - iris, - x="sepal_width", - y="sepal_length", - hover_data=["petal_length", "petal_width", "species_id"], - custom_data=["species_id", "species"], - ) + iris, + x="sepal_width", + y="sepal_length", + hover_data=["petal_length", "petal_width", "species_id"], + custom_data=["species_id", "species"], + ) assert fig.data[0].customdata.shape[1] == 4 From 0a50f6a1977a85c31091f3a9ab2696facad7c3b5 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 16:27:06 -0400 Subject: [PATCH 40/69] Py2 fix --- packages/python/plotly/plotly/express/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 85f84374ea8..7847df0facf 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -824,7 +824,7 @@ def build_dataframe(args, attrables, array_attrables): else: reserved_names = [] canbechanged_names = {} - forbidden_names = reserved_names.copy() + forbidden_names = list(reserved_names) # copy method compatible with Py2 # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given From 58c9ebaa8ab29cdd44d16c7bf74ea836e9714d61 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 16:40:02 -0400 Subject: [PATCH 41/69] more complex test --- packages/python/plotly/plotly/express/_core.py | 5 +++++ .../plotly/tests/test_core/test_px/test_px_input.py | 10 ++++++++++ 2 files changed, 15 insertions(+) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 7847df0facf..97463bcab9d 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -770,6 +770,11 @@ def _name_heuristic(argument, field_name, reserved_names): def _get_reserved_names(args, attrables, array_attrables): + """ + This function builds a list of columns of the data_frame argument used + as arguments, either as str/int arguments or given as columns + (pandas series type). + """ df = args["data_frame"] reserved_names = [] for field in args: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 743df14f9a8..dcb101f4c8d 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -85,6 +85,16 @@ def test_name_conflict(): assert np.all(fig.data[0].y == np.array([23, 24])) assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}
size=%{marker.size}" + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + df3 = pd.DataFrame(dict(y=[0.1, 0.2])) + fig = px.scatter(x=df.y, y=df2.y, hover_data=[df3.y]) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert ( + fig.data[0].hovertemplate == "x=%{x}
y=%{y}
hover_data_0=%{customdata[0]}" + ) + def test_repeated_name(): iris = px.data.iris() From 0a925374df5dab9c236570e78c09557c0c42dfe8 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 16:52:53 -0400 Subject: [PATCH 42/69] use sets instead of lists --- .../python/plotly/plotly/express/_core.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 97463bcab9d..2759e0b07f4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -776,7 +776,7 @@ def _get_reserved_names(args, attrables, array_attrables): (pandas series type). """ df = args["data_frame"] - reserved_names = [] + reserved_names = set() for field in args: if field not in attrables: continue @@ -786,16 +786,15 @@ def _get_reserved_names(args, attrables, array_attrables): for arg in names: if arg is None: continue - if isinstance(arg, str) and arg not in reserved_names: - reserved_names.append(arg) - if isinstance(arg, int) and str(arg) not in reserved_names: - reserved_names.append(str(arg)) + if isinstance(arg, str): + reserved_names.add(arg) + if isinstance(arg, int): + reserved_names.add(str(arg)) if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): arg_name = arg.name if arg_name: in_df = arg is df[arg_name] - if arg_name not in reserved_names: - reserved_names.append(arg_name) + reserved_names.add(arg_name) return reserved_names @@ -827,9 +826,9 @@ def build_dataframe(args, attrables, array_attrables): if args["data_frame"] is not None: reserved_names = _get_reserved_names(args, attrables, array_attrables) else: - reserved_names = [] + reserved_names = set() canbechanged_names = {} - forbidden_names = list(reserved_names) # copy method compatible with Py2 + forbidden_names = set(reserved_names) # copy method compatible with Py2 # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given @@ -956,10 +955,9 @@ def build_dataframe(args, attrables, array_attrables): % (field, len(argument), length) ) df[str(col_name)] = argument - if col_name not in reserved_names: - reserved_names.append(str(col_name)) - forbidden_names.append(str(col_name)) - canbechanged_names[str(col_name)] = (field_name, i) + reserved_names.add(str(col_name)) + forbidden_names.add(str(col_name)) + canbechanged_names[str(col_name)] = (field_name, i) # Update argument with column name now that column exists if field_name not in array_attrables: args[field_name] = str(col_name) From 372bd8f4ca4fa0d219e0cac978616c87289dd132 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 17:36:28 -0400 Subject: [PATCH 43/69] bug correction and new test --- .../python/plotly/plotly/express/_core.py | 68 ++++++++++--------- .../tests/test_core/test_px/test_px_input.py | 6 ++ 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 2759e0b07f4..d14f5716f14 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,12 +754,12 @@ def apply_default_cascade(args): args["marginal_x"] = None -def _name_heuristic(argument, field_name, reserved_names): +def _name_heuristic(argument, field_name, used_col_names): if isinstance(argument, int): argument = str(argument) - elif field_name not in reserved_names: + elif field_name not in used_col_names: return field_name - elif field_name + argument not in reserved_names: + elif field_name + argument not in used_col_names: return field_name + "_" + argument else: raise NameError( @@ -769,14 +769,14 @@ def _name_heuristic(argument, field_name, reserved_names): ) -def _get_reserved_names(args, attrables, array_attrables): +def _initialize_argument_col_names(args, attrables, array_attrables): """ This function builds a list of columns of the data_frame argument used as arguments, either as str/int arguments or given as columns (pandas series type). """ df = args["data_frame"] - reserved_names = set() + used_col_names = set() for field in args: if field not in attrables: continue @@ -787,16 +787,16 @@ def _get_reserved_names(args, attrables, array_attrables): if arg is None: continue if isinstance(arg, str): - reserved_names.add(arg) + used_col_names.add(arg) if isinstance(arg, int): - reserved_names.add(str(arg)) + used_col_names.add(str(arg)) if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): arg_name = arg.name if arg_name: in_df = arg is df[arg_name] - reserved_names.add(arg_name) + used_col_names.add(arg_name) - return reserved_names + return used_col_names def build_dataframe(args, attrables, array_attrables): @@ -818,24 +818,27 @@ def build_dataframe(args, attrables, array_attrables): argument names corresponding to iterables, such as `hover_data`, ... """ # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) - if args["data_frame"] is not None and not isinstance( - args["data_frame"], pd.DataFrame - ): + df_provided = args["data_frame"] is not None + if df_provided and not isinstance(args["data_frame"], pd.DataFrame): args["data_frame"] = pd.DataFrame(args["data_frame"]) - if args["data_frame"] is not None: - reserved_names = _get_reserved_names(args, attrables, array_attrables) - else: - reserved_names = set() - canbechanged_names = {} - forbidden_names = set(reserved_names) # copy method compatible with Py2 # We start from an empty DataFrame except for the case of functions which # implicitely need all dimensions: Splom, Parcats, Parcoords # This could be refined when dimensions is given df = pd.DataFrame() + # Initialize sets of column names + if df_provided: + used_col_names = _initialize_argument_col_names( + args, attrables, array_attrables + ) + else: + used_col_names = set() + canbechanged_names = {} + forbidden_names = set(used_col_names) # copy method compatible with Py2 + if "dimensions" in args and args["dimensions"] is None: - if args["data_frame"] is None: + if not df_provided: raise ValueError( "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." ) @@ -849,6 +852,7 @@ def build_dataframe(args, attrables, array_attrables): if isinstance(args["data_frame"], pd.DataFrame) else None ) + # Loop over possible arguments for field_name in attrables: argument_list = ( @@ -871,10 +875,11 @@ def build_dataframe(args, attrables, array_attrables): length = len(df) if argument is None: continue + ## ----------------- argument is a col name ---------------------- elif isinstance(argument, str) or isinstance( argument, int - ): # just a column name - if not isinstance(args.get("data_frame"), pd.DataFrame): + ): # just a column name given as str or int + if not df_provided: raise ValueError( "String or int arguments are only possible when a" "DataFrame or an array is provided in the `data_frame`" @@ -908,7 +913,7 @@ def build_dataframe(args, attrables, array_attrables): "Argument '%s' is a pandas MultiIndex." "pandas MultiIndex is not supported by plotly express" % field ) - # Case of numpy array or df column + # ----------------- argument is a column / array / list.... ------- else: is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) # pandas series have a name but it's None @@ -920,16 +925,14 @@ def build_dataframe(args, attrables, array_attrables): col_name = "index" # revert previous argument if col_name in canbechanged_names: - if argument is not df[col_name]: + if not argument.equals(df[col_name]): + print ("will revert", col_name) old_field, old_i = canbechanged_names[col_name] - # old_field_name = old_field + str(i) if else old_field df.rename(columns={col_name: old_field}, inplace=True) args[old_field] = old_field del canbechanged_names[col_name] - reserved_names.remove(col_name) + used_col_names.remove(col_name) if col_name in forbidden_names: - # If the name exists but the values have changed - # we do not want to keep the name, revert to field name_in_dataframe = ( args["data_frame"] is not None and col_name in args["data_frame"].columns @@ -942,11 +945,10 @@ def build_dataframe(args, attrables, array_attrables): col_name = ( col_name if keep_name - else _name_heuristic(col_name, field, reserved_names) + else _name_heuristic(col_name, field, used_col_names) ) else: # numpy array, list... - # col_name = field - col_name = _name_heuristic(field, field, reserved_names) + col_name = _name_heuristic(field, field, used_col_names) if length and len(argument) != length: raise ValueError( "All arguments should have the same length." @@ -955,14 +957,16 @@ def build_dataframe(args, attrables, array_attrables): % (field, len(argument), length) ) df[str(col_name)] = argument - reserved_names.add(str(col_name)) + used_col_names.add(str(col_name)) forbidden_names.add(str(col_name)) canbechanged_names[str(col_name)] = (field_name, i) - # Update argument with column name now that column exists + + # Finally, update argument with column name now that column exists if field_name not in array_attrables: args[field_name] = str(col_name) else: args[field_name][i] = str(col_name) + args["data_frame"] = df return args diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index dcb101f4c8d..b658d5df197 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -95,6 +95,12 @@ def test_name_conflict(): fig.data[0].hovertemplate == "x=%{x}
y=%{y}
hover_data_0=%{customdata[0]}" ) + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])) + fig = px.scatter(x=df.y, y=df.x, size=df.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([0, 1])) + assert fig.data[0].hovertemplate == "y=%{x}
x=%{y}
size=%{marker.size}" + def test_repeated_name(): iris = px.data.iris() From 3d0aff0676ce2ce01c6a32bebef68affbb7c6d2a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Fri, 20 Sep 2019 18:03:08 -0400 Subject: [PATCH 44/69] comments --- packages/python/plotly/plotly/express/_core.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d14f5716f14..4cb4e73e0b4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -835,7 +835,8 @@ def build_dataframe(args, attrables, array_attrables): else: used_col_names = set() canbechanged_names = {} - forbidden_names = set(used_col_names) # copy method compatible with Py2 + # Names which are already taken + reserved_names = set(used_col_names) # copy method compatible with Py2 if "dimensions" in args and args["dimensions"] is None: if not df_provided: @@ -855,6 +856,7 @@ def build_dataframe(args, attrables, array_attrables): # Loop over possible arguments for field_name in attrables: + # Massaging variables argument_list = ( [args.get(field_name)] if field_name not in array_attrables @@ -871,6 +873,7 @@ def build_dataframe(args, attrables, array_attrables): else [field_name + "_" + str(i) for i in range(len(argument_list))] ) # argument_list and field_list ready, iterate over them + # Core of the loop starts here for i, (argument, field) in enumerate(zip(argument_list, field_list)): length = len(df) if argument is None: @@ -926,19 +929,18 @@ def build_dataframe(args, attrables, array_attrables): # revert previous argument if col_name in canbechanged_names: if not argument.equals(df[col_name]): - print ("will revert", col_name) old_field, old_i = canbechanged_names[col_name] df.rename(columns={col_name: old_field}, inplace=True) args[old_field] = old_field del canbechanged_names[col_name] used_col_names.remove(col_name) - if col_name in forbidden_names: + if col_name in reserved_names: name_in_dataframe = ( args["data_frame"] is not None and col_name in args["data_frame"].columns ) keep_name = ( - (argument is args["data_frame"][col_name]) + argument is args["data_frame"][col_name] if name_in_dataframe else (col_name in df and argument is df[col_name]) ) @@ -958,7 +960,7 @@ def build_dataframe(args, attrables, array_attrables): ) df[str(col_name)] = argument used_col_names.add(str(col_name)) - forbidden_names.add(str(col_name)) + reserved_names.add(str(col_name)) canbechanged_names[str(col_name)] = (field_name, i) # Finally, update argument with column name now that column exists From 5a329ad3a3b460ef85696e9dca9c5a588128484a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 23 Sep 2019 11:20:37 -0400 Subject: [PATCH 45/69] simpler way to use names --- .../python/plotly/plotly/express/_core.py | 35 ++++--------------- .../tests/test_core/test_px/test_px_input.py | 5 ++- 2 files changed, 8 insertions(+), 32 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 4cb4e73e0b4..55330a0ad06 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -834,9 +834,6 @@ def build_dataframe(args, attrables, array_attrables): ) else: used_col_names = set() - canbechanged_names = {} - # Names which are already taken - reserved_names = set(used_col_names) # copy method compatible with Py2 if "dimensions" in args and args["dimensions"] is None: if not df_provided: @@ -919,6 +916,7 @@ def build_dataframe(args, attrables, array_attrables): # ----------------- argument is a column / array / list.... ------- else: is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) + # First pandas # pandas series have a name but it's None if ( hasattr(argument, "name") and argument.name is not None @@ -926,29 +924,11 @@ def build_dataframe(args, attrables, array_attrables): col_name = argument.name # pandas df if col_name is None and is_index: col_name = "index" - # revert previous argument - if col_name in canbechanged_names: - if not argument.equals(df[col_name]): - old_field, old_i = canbechanged_names[col_name] - df.rename(columns={col_name: old_field}, inplace=True) - args[old_field] = old_field - del canbechanged_names[col_name] - used_col_names.remove(col_name) - if col_name in reserved_names: - name_in_dataframe = ( - args["data_frame"] is not None - and col_name in args["data_frame"].columns - ) - keep_name = ( - argument is args["data_frame"][col_name] - if name_in_dataframe - else (col_name in df and argument is df[col_name]) - ) - col_name = ( - col_name - if keep_name - else _name_heuristic(col_name, field, used_col_names) - ) + if not df_provided: + col_name = field + else: + keep_name = argument is getattr(args["data_frame"], col_name) + col_name = col_name if keep_name else _name_heuristic(col_name, field, used_col_names) else: # numpy array, list... col_name = _name_heuristic(field, field, used_col_names) if length and len(argument) != length: @@ -959,9 +939,6 @@ def build_dataframe(args, attrables, array_attrables): % (field, len(argument), length) ) df[str(col_name)] = argument - used_col_names.add(str(col_name)) - reserved_names.add(str(col_name)) - canbechanged_names[str(col_name)] = (field_name, i) # Finally, update argument with column name now that column exists if field_name not in array_attrables: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index b658d5df197..fd1a736e3b5 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -57,7 +57,6 @@ def test_pandas_series(): fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" - def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) fig = px.scatter(df, x=[10, 1], y="y", color="x") @@ -96,10 +95,10 @@ def test_name_conflict(): ) df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])) - fig = px.scatter(x=df.y, y=df.x, size=df.y) + fig = px.scatter(df, x=df.y, y=df.x, size=df.y) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([0, 1])) - assert fig.data[0].hovertemplate == "y=%{x}
x=%{y}
size=%{marker.size}" + assert fig.data[0].hovertemplate == "y=%{x}
x=%{y}" def test_repeated_name(): From e0e0dd574549afd3605089969ee19fd1fa0f4101 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 23 Sep 2019 13:38:18 -0400 Subject: [PATCH 46/69] comments --- .../python/plotly/plotly/express/_core.py | 23 +++++++------------ .../tests/test_core/test_px/test_px_input.py | 11 ++++----- 2 files changed, 13 insertions(+), 21 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 55330a0ad06..e568efeada1 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -822,12 +822,11 @@ def build_dataframe(args, attrables, array_attrables): if df_provided and not isinstance(args["data_frame"], pd.DataFrame): args["data_frame"] = pd.DataFrame(args["data_frame"]) - # We start from an empty DataFrame except for the case of functions which - # implicitely need all dimensions: Splom, Parcats, Parcoords - # This could be refined when dimensions is given + # We start from an empty DataFrame df = pd.DataFrame() - # Initialize sets of column names + # Initialize set of column names + # These are reserved names if df_provided: used_col_names = _initialize_argument_col_names( args, attrables, array_attrables @@ -835,6 +834,7 @@ def build_dataframe(args, attrables, array_attrables): else: used_col_names = set() + # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords if "dimensions" in args and args["dimensions"] is None: if not df_provided: raise ValueError( @@ -844,13 +844,6 @@ def build_dataframe(args, attrables, array_attrables): df_args = args["data_frame"] df[df_args.columns] = df_args[df_args.columns] - # Valid column names - df_columns = ( - args["data_frame"].columns - if isinstance(args["data_frame"], pd.DataFrame) - else None - ) - # Loop over possible arguments for field_name in attrables: # Massaging variables @@ -862,7 +855,7 @@ def build_dataframe(args, attrables, array_attrables): # argument not specified, continue if argument_list is None or argument_list is [None]: continue - # Argument name: field_name if the argument is a list + # Argument name: field_name if the argument is not a list # Else we give names like ["hover_data_0, hover_data_1"] etc. field_list = ( [field_name] @@ -887,11 +880,11 @@ def build_dataframe(args, attrables, array_attrables): "is of type str or int." % field ) # Check validity of column name - if df_columns is not None and argument not in df_columns: + if argument not in args['data_frame'].columns: raise ValueError( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" - % (field, str(list(df_columns)), argument) + % (field, str(list(args['data_frame'].columns)), argument) ) if length and len(args["data_frame"][argument]) != length: raise ValueError( @@ -907,7 +900,7 @@ def build_dataframe(args, attrables, array_attrables): else: args[field_name][i] = str(argument) continue - # Case of index + # Case of multiindex elif isinstance(argument, pd.core.indexes.multi.MultiIndex): raise TypeError( "Argument '%s' is a pandas MultiIndex." diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index fd1a736e3b5..7fab8a886bc 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -34,7 +34,7 @@ def test_numpy_labels(): def test_with_index(): tips = px.data.tips() fig = px.scatter(tips, x=tips.index, y="total_bill") - fig = px.scatter(tips, x=tips.index, y=tips.total_bill) + assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" fig = px.scatter(tips, x=tips.index, y=tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" # If we tinker with the column then the name is the one of the kw argument @@ -61,13 +61,16 @@ def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) fig = px.scatter(df, x=[10, 1], y="y", color="x") assert np.all(fig.data[0].x == np.array([10, 1])) + assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
x=%{marker.color}" fig = px.scatter(df, x=[10, 1], y="y", color=df.x) assert np.all(fig.data[0].x == np.array([10, 1])) + assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
x=%{marker.color}" df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], color=[1, 2])) fig = px.scatter(df, x=[10, 1], y="y", size="color", color=df.x) assert np.all(fig.data[0].x == np.array([10, 1])) + assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
color=%{marker.size}
x=%{marker.color}" df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) @@ -98,7 +101,7 @@ def test_name_conflict(): fig = px.scatter(df, x=df.y, y=df.x, size=df.y) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([0, 1])) - assert fig.data[0].hovertemplate == "y=%{x}
x=%{y}" + assert fig.data[0].hovertemplate == "y=%{marker.size}
x=%{y}" def test_repeated_name(): @@ -113,10 +116,6 @@ def test_repeated_name(): assert fig.data[0].customdata.shape[1] == 4 -def test_mixed_case(): - df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) - fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9]) - def test_arrayattrable_numpy(): tips = px.data.tips() From 6bced739e681b9956f200c4804b69207b73b2fd7 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 23 Sep 2019 13:44:35 -0400 Subject: [PATCH 47/69] docstring --- packages/python/plotly/plotly/express/_core.py | 10 +++++++--- packages/python/plotly/plotly/express/_doc.py | 6 +++++- .../plotly/tests/test_core/test_px/test_px_input.py | 7 +++++-- 3 files changed, 17 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e568efeada1..675e3fd5e6b 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -880,11 +880,11 @@ def build_dataframe(args, attrables, array_attrables): "is of type str or int." % field ) # Check validity of column name - if argument not in args['data_frame'].columns: + if argument not in args["data_frame"].columns: raise ValueError( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" - % (field, str(list(args['data_frame'].columns)), argument) + % (field, str(list(args["data_frame"].columns)), argument) ) if length and len(args["data_frame"][argument]) != length: raise ValueError( @@ -921,7 +921,11 @@ def build_dataframe(args, attrables, array_attrables): col_name = field else: keep_name = argument is getattr(args["data_frame"], col_name) - col_name = col_name if keep_name else _name_heuristic(col_name, field, used_col_names) + col_name = ( + col_name + if keep_name + else _name_heuristic(col_name, field, used_col_names) + ) else: # numpy array, list... col_name = _name_heuristic(field, field, used_col_names) if length and len(argument) != length: diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 44abcfe6478..1ecd1d4a3c7 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -14,7 +14,11 @@ # TODO standardize positioning and casing of 'default' docs = dict( - data_frame=["A 'tidy' `pandas.DataFrame`"], + data_frame=[ + "A `pandas.DataFrame`, or a `NumPy` array or a dictionary", + "which are tranformed internally to `pandas.DataFrame`. This argument needs" + "to be passed for column names (and not keyword names) to be used.", + ], x=[ colref, "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 7fab8a886bc..c106f32c6f1 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -57,6 +57,7 @@ def test_pandas_series(): fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" + def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) fig = px.scatter(df, x=[10, 1], y="y", color="x") @@ -70,7 +71,10 @@ def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], color=[1, 2])) fig = px.scatter(df, x=[10, 1], y="y", size="color", color=df.x) assert np.all(fig.data[0].x == np.array([10, 1])) - assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
color=%{marker.size}
x=%{marker.color}" + assert ( + fig.data[0].hovertemplate + == "x_x=%{x}
y=%{y}
color=%{marker.size}
x=%{marker.color}" + ) df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) @@ -116,7 +120,6 @@ def test_repeated_name(): assert fig.data[0].customdata.shape[1] == 4 - def test_arrayattrable_numpy(): tips = px.data.tips() fig = px.scatter( From bcc41a24d9e47f07b8729587f237fc6bb198ace0 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 23 Sep 2019 14:08:39 -0400 Subject: [PATCH 48/69] more tests --- packages/python/plotly/plotly/express/_core.py | 17 ++++++++++------- .../tests/test_core/test_px/test_px_input.py | 9 +++++++++ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 675e3fd5e6b..efa732bde84 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -792,9 +792,10 @@ def _initialize_argument_col_names(args, attrables, array_attrables): used_col_names.add(str(arg)) if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): arg_name = arg.name - if arg_name: + if arg_name and hasattr(df, arg_name): in_df = arg is df[arg_name] - used_col_names.add(arg_name) + if in_df: + used_col_names.add(arg_name) return used_col_names @@ -828,11 +829,11 @@ def build_dataframe(args, attrables, array_attrables): # Initialize set of column names # These are reserved names if df_provided: - used_col_names = _initialize_argument_col_names( + reserved_names = _initialize_argument_col_names( args, attrables, array_attrables ) else: - used_col_names = set() + reserved_names = set() # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords if "dimensions" in args and args["dimensions"] is None: @@ -920,14 +921,16 @@ def build_dataframe(args, attrables, array_attrables): if not df_provided: col_name = field else: - keep_name = argument is getattr(args["data_frame"], col_name) + keep_name = hasattr( + args["data_frame"], col_name + ) and argument is getattr(args["data_frame"], col_name) col_name = ( col_name if keep_name - else _name_heuristic(col_name, field, used_col_names) + else _name_heuristic(col_name, field, reserved_names) ) else: # numpy array, list... - col_name = _name_heuristic(field, field, used_col_names) + col_name = _name_heuristic(field, field, reserved_names) if length and len(argument) != length: raise ValueError( "All arguments should have the same length." diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index c106f32c6f1..b4abd580ec4 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -58,6 +58,15 @@ def test_pandas_series(): assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" +def test_several_dataframes(): + df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8])) + df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200])) + fig = px.scatter(df, x="z", y=df2.money, size="y") + assert fig.data[0].hovertemplate == "z=%{x}
y_money=%{y}
y=%{marker.size}" + fig = px.scatter(df2, x=df.z, y=df2.money, size=df.y) + assert fig.data[0].hovertemplate == "x=%{x}
money=%{y}
size=%{marker.size}" + + def test_name_conflict(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) fig = px.scatter(df, x=[10, 1], y="y", color="x") From 382e768da26caf905758c4eda3fe7af62f6eb9d6 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 23 Sep 2019 17:06:36 -0400 Subject: [PATCH 49/69] integer column names --- packages/python/plotly/plotly/express/_core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index efa732bde84..d17977272ef 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -757,7 +757,7 @@ def apply_default_cascade(args): def _name_heuristic(argument, field_name, used_col_names): if isinstance(argument, int): argument = str(argument) - elif field_name not in used_col_names: + if field_name not in used_col_names: return field_name elif field_name + argument not in used_col_names: return field_name + "_" + argument @@ -788,8 +788,6 @@ def _initialize_argument_col_names(args, attrables, array_attrables): continue if isinstance(arg, str): used_col_names.add(arg) - if isinstance(arg, int): - used_col_names.add(str(arg)) if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): arg_name = arg.name if arg_name and hasattr(df, arg_name): @@ -894,12 +892,14 @@ def build_dataframe(args, attrables, array_attrables): "length of previous arguments is %d" % (field, len(args["data_frame"][argument]), length) ) - df[str(argument)] = args["data_frame"][argument] + col_name = argument if isinstance(argument, int): + col_name = _name_heuristic(argument, field, reserved_names) if field_name not in array_attrables: - args[field_name] = str(argument) + args[field_name] = col_name else: - args[field_name][i] = str(argument) + args[field_name][i] = col_name + df[col_name] = args["data_frame"][argument] continue # Case of multiindex elif isinstance(argument, pd.core.indexes.multi.MultiIndex): From 82c6592802cf3106ad9d1a145eb809330ff551dd Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 09:05:25 -0400 Subject: [PATCH 50/69] typo --- packages/python/plotly/plotly/express/_core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d17977272ef..6dcc0b172f4 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -671,7 +671,7 @@ def one_group(x): def apply_default_cascade(args): - # https://github.com/plotly/dash-table/issues/597first we apply px.defaults to unspecified args + # first we apply px.defaults to unspecified args for param in ( ["color_discrete_sequence", "color_continuous_scale"] + ["symbol_sequence", "line_dash_sequence", "template"] From 9caeee1f21575a08523cfad2906a05e8defd03ef Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 09:28:21 -0400 Subject: [PATCH 51/69] better error message --- packages/python/plotly/plotly/express/_core.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 6dcc0b172f4..d7c982d15a0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -887,10 +887,15 @@ def build_dataframe(args, attrables, array_attrables): ) if length and len(args["data_frame"][argument]) != length: raise ValueError( - "All arguments should have the same length." - "The length of column argument `df[%s]` is %d, whereas the" - "length of previous arguments is %d" - % (field, len(args["data_frame"][argument]), length) + "All arguments should have the same length. " + "The length of column argument `df[%s]` is %d, whereas the " + "length of previous arguments %s is %d" + % ( + field, + len(args["data_frame"][argument]), + str(list(df.columns)), + length, + ) ) col_name = argument if isinstance(argument, int): @@ -935,8 +940,8 @@ def build_dataframe(args, attrables, array_attrables): raise ValueError( "All arguments should have the same length." "The length of argument `%s` is %d, whereas the" - "length of previous arguments is %d" - % (field, len(argument), length) + "length of previous arguments %s is %d" + % (field, len(argument), str(list(df.columns)), length) ) df[str(col_name)] = argument From adc68d3c168edfc8161252f56c6ee00b9840a45a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 09:55:52 -0400 Subject: [PATCH 52/69] better test --- packages/python/plotly/plotly/express/_core.py | 3 ++- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d7c982d15a0..d2339b2daa0 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -910,7 +910,8 @@ def build_dataframe(args, attrables, array_attrables): elif isinstance(argument, pd.core.indexes.multi.MultiIndex): raise TypeError( "Argument '%s' is a pandas MultiIndex." - "pandas MultiIndex is not supported by plotly express" % field + "pandas MultiIndex is not supported by plotly express " + "at the moment." % field ) # ----------------- argument is a column / array / list.... ------- else: diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index b4abd580ec4..ea3a0798898 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -22,6 +22,9 @@ def test_numpy(): fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) + assert np.all(fig.data[0].x == np.array([1, 2, 3])) + assert np.all(fig.data[0].y == np.array([2, 3, 4])) + assert np.all(fig.data[0].marker.color == np.array([1, 3, 9])) def test_numpy_labels(): From 5dc0ab1492065e9b92a5cfa4c2fe291b0f565184 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 10:06:55 -0400 Subject: [PATCH 53/69] order of tests --- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index ea3a0798898..15bd4b79c48 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -40,9 +40,6 @@ def test_with_index(): assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" fig = px.scatter(tips, x=tips.index, y=tips.total_bill) assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" - # If we tinker with the column then the name is the one of the kw argument - fig = px.scatter(tips, x=tips.index, y=10 * tips.total_bill) - assert fig.data[0]["hovertemplate"] == "index=%{x}
y=%{y}" fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"}) assert fig.data[0]["hovertemplate"] == "number=%{x}
total_bill=%{y}" # We do not allow "x=index" @@ -57,6 +54,8 @@ def test_with_index(): def test_pandas_series(): tips = px.data.tips() before_tip = tips.total_bill - tips.tip + fig = px.bar(tips, x="day", y=before_tip) + assert fig.data[0].hovertemplate == "day=%{x}
y=%{y}" fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" From f5d056c6a1b10926631bb893b5c7c6438a5db450 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 10:10:45 -0400 Subject: [PATCH 54/69] doc modification --- packages/python/plotly/plotly/express/_doc.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 1ecd1d4a3c7..81274588952 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -1,8 +1,8 @@ import inspect -colref = "(string: name of column in `data_frame`, or array_like object)" +colref = "(string or int: name of column in `data_frame`, or array_like object)" colref_list = ( - "(list of string: names of columns in `data_frame`, or array_like objects)" + "(list of string or int: names of columns in `data_frame`, or array_like objects)" ) # TODO contents of columns From bedf74fe3d73a729b7552530062c91cc8152024f Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 10:23:51 -0400 Subject: [PATCH 55/69] qa --- packages/python/plotly/plotly/express/_core.py | 2 +- packages/python/plotly/plotly/express/_doc.py | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d2339b2daa0..c1c3b43f944 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -788,7 +788,7 @@ def _initialize_argument_col_names(args, attrables, array_attrables): continue if isinstance(arg, str): used_col_names.add(arg) - if isinstance(arg, pd.DataFrame) or isinstance(arg, pd.core.series.Series): + if isinstance(arg, pd.core.series.Series): arg_name = arg.name if arg_name and hasattr(df, arg_name): in_df = arg is df[arg_name] diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 81274588952..7df598dccb3 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -1,9 +1,7 @@ import inspect -colref = "(string or int: name of column in `data_frame`, or array_like object)" -colref_list = ( - "(list of string or int: names of columns in `data_frame`, or array_like objects)" -) +colref = "(string or int: name of column in `data_frame`, or pandas Series, or array_like object)" +colref_list = "(list of string or int: names of columns in `data_frame`, or pandas Series, or array_like objects)" # TODO contents of columns # TODO explain categorical From 074342b72f1bee889049b1e3b0e99df86dc72eb5 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 10:53:56 -0400 Subject: [PATCH 56/69] case of named index --- packages/python/plotly/plotly/express/_core.py | 11 ++++++++--- .../plotly/tests/test_core/test_px/test_px_input.py | 4 ++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index c1c3b43f944..54f71a5d014 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -927,9 +927,14 @@ def build_dataframe(args, attrables, array_attrables): if not df_provided: col_name = field else: - keep_name = hasattr( - args["data_frame"], col_name - ) and argument is getattr(args["data_frame"], col_name) + if is_index: + keep_name = ( + df_provided and argument is args["data_frame"].index + ) + else: + keep_name = hasattr( + args["data_frame"], col_name + ) and argument is getattr(args["data_frame"], col_name) col_name = ( col_name if keep_name diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 15bd4b79c48..1b278532d30 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -49,6 +49,10 @@ def test_with_index(): "ValueError: Value of 'x' is not the name of a column in 'data_frame'" in str(err_msg.value) ) + tips = px.data.tips() + tips.index.name = "item" + fig = px.scatter(tips, x=tips.index, y="total_bill") + assert fig.data[0]["hovertemplate"] == "item=%{x}
total_bill=%{y}" def test_pandas_series(): From ead8430733930442983a7dbcabf11722d7b21fb9 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 11:12:20 -0400 Subject: [PATCH 57/69] be more defensive with names --- .../python/plotly/plotly/express/_core.py | 7 +--- .../tests/test_core/test_px/test_px_input.py | 37 +++++++------------ 2 files changed, 16 insertions(+), 28 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 54f71a5d014..2bdcd07f3de 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -759,13 +759,10 @@ def _name_heuristic(argument, field_name, used_col_names): argument = str(argument) if field_name not in used_col_names: return field_name - elif field_name + argument not in used_col_names: - return field_name + "_" + argument else: raise NameError( - "A name conflict was encountered for argument %s." - "Columns with names %s, %s and %s are already used" - % (field_name, argument, field_name, field_name + "_" + argument) + "A name conflict was encountered for argument %s. " + "A column with name %s is already used." % (field_name, field_name) ) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 1b278532d30..ada29e82e44 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -67,30 +67,19 @@ def test_pandas_series(): def test_several_dataframes(): df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8])) df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200])) - fig = px.scatter(df, x="z", y=df2.money, size="y") - assert fig.data[0].hovertemplate == "z=%{x}
y_money=%{y}
y=%{marker.size}" - fig = px.scatter(df2, x=df.z, y=df2.money, size=df.y) + fig = px.scatter(df, x="z", y=df2.money, size="x") + assert fig.data[0].hovertemplate == "z=%{x}
y=%{y}
x=%{marker.size}" + fig = px.scatter(df2, x=df.z, y=df2.money, size=df.z) assert fig.data[0].hovertemplate == "x=%{x}
money=%{y}
size=%{marker.size}" - - -def test_name_conflict(): - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) - fig = px.scatter(df, x=[10, 1], y="y", color="x") - assert np.all(fig.data[0].x == np.array([10, 1])) - assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
x=%{marker.color}" - - fig = px.scatter(df, x=[10, 1], y="y", color=df.x) - assert np.all(fig.data[0].x == np.array([10, 1])) - assert fig.data[0].hovertemplate == "x_x=%{x}
y=%{y}
x=%{marker.color}" - - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], color=[1, 2])) - fig = px.scatter(df, x=[10, 1], y="y", size="color", color=df.x) - assert np.all(fig.data[0].x == np.array([10, 1])) - assert ( - fig.data[0].hovertemplate - == "x_x=%{x}
y=%{y}
color=%{marker.size}
x=%{marker.color}" - ) - + # Name conflict + with pytest.raises(NameError) as err_msg: + fig = px.scatter(df, x="z", y=df2.money, size="y") + assert "A name conflict was encountered for argument y" in str(err_msg.value) + with pytest.raises(NameError) as err_msg: + fig = px.scatter(df, x="z", y=df2.money, size=df.y) + assert "A name conflict was encountered for argument y" in str(err_msg.value) + + # No conflict when the dataframe is not given, fields are used df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) fig = px.scatter(x=df.y, y=df2.y) @@ -116,6 +105,8 @@ def test_name_conflict(): fig.data[0].hovertemplate == "x=%{x}
y=%{y}
hover_data_0=%{customdata[0]}" ) + +def test_name_heuristics(): df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])) fig = px.scatter(df, x=df.y, y=df.x, size=df.y) assert np.all(fig.data[0].x == np.array([3, 4])) From db109674ea546cae49a5fccc15d46439d5d33c63 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 15:35:27 -0400 Subject: [PATCH 58/69] better error msg with index --- packages/python/plotly/plotly/express/_core.py | 7 ++++++- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 5 ++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 2bdcd07f3de..003e93df6d7 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -877,11 +877,16 @@ def build_dataframe(args, attrables, array_attrables): ) # Check validity of column name if argument not in args["data_frame"].columns: - raise ValueError( + err_msg = ( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" % (field, str(list(args["data_frame"].columns)), argument) ) + if argument == "index": + err_msg += ( + "\n To use the index, pass it in directly as `df.index`." + ) + raise ValueError(err_msg) if length and len(args["data_frame"][argument]) != length: raise ValueError( "All arguments should have the same length. " diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index ada29e82e44..234e6270def 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -45,9 +45,8 @@ def test_with_index(): # We do not allow "x=index" with pytest.raises(ValueError) as err_msg: fig = px.scatter(tips, x="index", y="total_bill") - assert ( - "ValueError: Value of 'x' is not the name of a column in 'data_frame'" - in str(err_msg.value) + assert "To use the index, pass it in directly as `df.index`." in str( + err_msg.value ) tips = px.data.tips() tips.index.name = "item" From ef29378857c50bf2f990c3bde1b669dff6c47165 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 17:22:41 -0400 Subject: [PATCH 59/69] do not modify input arguments --- .../python/plotly/plotly/express/_core.py | 35 ++++++++++--------- .../tests/test_core/test_px/test_px_input.py | 23 ++++++++++-- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 003e93df6d7..22c4da9a966 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -795,7 +795,7 @@ def _initialize_argument_col_names(args, attrables, array_attrables): return used_col_names -def build_dataframe(args, attrables, array_attrables): +def build_dataframe(input_args, attrables, array_attrables): """ Constructs a dataframe and modifies `args` in-place. @@ -813,6 +813,12 @@ def build_dataframe(args, attrables, array_attrables): array_attrables : list argument names corresponding to iterables, such as `hover_data`, ... """ + args = dict(input_args) + for field in args: + if field in array_attrables and isinstance( + args[field], pd.core.indexes.base.Index + ): + args[field] = list(args[field]) # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) df_provided = args["data_frame"] is not None if df_provided and not isinstance(args["data_frame"], pd.DataFrame): @@ -864,8 +870,15 @@ def build_dataframe(args, attrables, array_attrables): length = len(df) if argument is None: continue + # Case of multiindex + if isinstance(argument, pd.core.indexes.multi.MultiIndex): + raise TypeError( + "Argument '%s' is a pandas MultiIndex." + "pandas MultiIndex is not supported by plotly express " + "at the moment." % field + ) ## ----------------- argument is a col name ---------------------- - elif isinstance(argument, str) or isinstance( + if isinstance(argument, str) or isinstance( argument, int ): # just a column name given as str or int if not df_provided: @@ -902,19 +915,7 @@ def build_dataframe(args, attrables, array_attrables): col_name = argument if isinstance(argument, int): col_name = _name_heuristic(argument, field, reserved_names) - if field_name not in array_attrables: - args[field_name] = col_name - else: - args[field_name][i] = col_name df[col_name] = args["data_frame"][argument] - continue - # Case of multiindex - elif isinstance(argument, pd.core.indexes.multi.MultiIndex): - raise TypeError( - "Argument '%s' is a pandas MultiIndex." - "pandas MultiIndex is not supported by plotly express " - "at the moment." % field - ) # ----------------- argument is a column / array / list.... ------- else: is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) @@ -980,7 +981,7 @@ def infer_config(args, constructor, trace_patch): if group_attr in args: all_attrables += [group_attr] - build_dataframe(args, all_attrables, array_attrables) + args = build_dataframe(args, all_attrables, array_attrables) attrs = [k for k in attrables if k in args] grouped_attrs = [] @@ -1058,7 +1059,7 @@ def infer_config(args, constructor, trace_patch): # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) - return trace_specs, grouped_mappings, sizeref, show_colorbar + return args, trace_specs, grouped_mappings, sizeref, show_colorbar def get_orderings(args, grouper, grouped): @@ -1096,7 +1097,7 @@ def get_orderings(args, grouper, grouped): def make_figure(args, constructor, trace_patch={}, layout_patch={}): apply_default_cascade(args) - trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( + args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 234e6270def..e810325fe49 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -227,11 +227,8 @@ def test_build_df_from_lists(): def test_build_df_with_index(): tips = px.data.tips() args = dict(data_frame=tips, x=tips.index, y="total_bill") - changed_output = dict(x="index") out = build_dataframe(args, all_attrables, array_attrables) assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) - out.pop("data_frame") - assert out == args def test_splom_case(): @@ -261,3 +258,23 @@ def test_data_frame_from_dict(): fig = px.scatter({"time": [0, 1], "money": [1, 2]}, x="time", y="money") assert fig.data[0].hovertemplate == "time=%{x}
money=%{y}" assert np.all(fig.data[0].x == [0, 1]) + + +def test_arguments_not_modified(): + iris = px.data.iris() + petal_length = iris.petal_length + fig = px.scatter(iris, x=petal_length, y="petal_width") + assert iris.petal_length.equals(petal_length) + + +def test_pass_df_columns(): + tips = px.data.tips() + fig = px.histogram( + tips, + x="total_bill", + y="tip", + color="sex", + marginal="rug", + hover_data=tips.columns, + ) + assert fig.data[1].hovertemplate.count("customdata") == len(tips.columns) From 5e40653158fb59576797d96106e5f98d07e9c560 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 21:12:29 -0400 Subject: [PATCH 60/69] corrected bug --- packages/python/plotly/plotly/express/_core.py | 6 ++---- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 4 +++- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 22c4da9a966..891295dc3f3 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -813,11 +813,9 @@ def build_dataframe(input_args, attrables, array_attrables): array_attrables : list argument names corresponding to iterables, such as `hover_data`, ... """ - args = dict(input_args) + args = input_args.copy() for field in args: - if field in array_attrables and isinstance( - args[field], pd.core.indexes.base.Index - ): + if field in array_attrables and args[field] is not None: args[field] = list(args[field]) # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) df_provided = args["data_frame"] is not None diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index e810325fe49..1567d7d8bad 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -263,8 +263,10 @@ def test_data_frame_from_dict(): def test_arguments_not_modified(): iris = px.data.iris() petal_length = iris.petal_length - fig = px.scatter(iris, x=petal_length, y="petal_width") + hover_data = [iris.sepal_length] + fig = px.scatter(iris, x=petal_length, y="petal_width", hover_data=hover_data) assert iris.petal_length.equals(petal_length) + assert iris.sepal_length.equals(hover_data[0]) def test_pass_df_columns(): From 78564cc8d1a6b8ade0adc607235764725be2b01e Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 21:19:26 -0400 Subject: [PATCH 61/69] name consistency --- packages/python/plotly/plotly/express/_core.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 891295dc3f3..0e13318f6ae 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,10 +754,10 @@ def apply_default_cascade(args): args["marginal_x"] = None -def _name_heuristic(argument, field_name, used_col_names): +def _name_heuristic(argument, field_name, reserved_names): if isinstance(argument, int): argument = str(argument) - if field_name not in used_col_names: + if field_name not in reserved_names: return field_name else: raise NameError( @@ -766,14 +766,14 @@ def _name_heuristic(argument, field_name, used_col_names): ) -def _initialize_argument_col_names(args, attrables, array_attrables): +def _get_reserved_col_names(args, attrables, array_attrables): """ This function builds a list of columns of the data_frame argument used as arguments, either as str/int arguments or given as columns (pandas series type). """ df = args["data_frame"] - used_col_names = set() + reserved_names = set() for field in args: if field not in attrables: continue @@ -784,15 +784,15 @@ def _initialize_argument_col_names(args, attrables, array_attrables): if arg is None: continue if isinstance(arg, str): - used_col_names.add(arg) + reserved_names.add(arg) if isinstance(arg, pd.core.series.Series): arg_name = arg.name if arg_name and hasattr(df, arg_name): in_df = arg is df[arg_name] if in_df: - used_col_names.add(arg_name) + reserved_names.add(arg_name) - return used_col_names + return reserved_names def build_dataframe(input_args, attrables, array_attrables): @@ -828,9 +828,7 @@ def build_dataframe(input_args, attrables, array_attrables): # Initialize set of column names # These are reserved names if df_provided: - reserved_names = _initialize_argument_col_names( - args, attrables, array_attrables - ) + reserved_names = _get_reserved_col_names(args, attrables, array_attrables) else: reserved_names = set() From 2534ca91201a992bff63cd2444e30ab86881138c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Tue, 24 Sep 2019 21:31:04 -0400 Subject: [PATCH 62/69] name consistency --- .../python/plotly/plotly/express/_core.py | 35 +++++++++---------- 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0e13318f6ae..128fc51cada 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -821,9 +821,10 @@ def build_dataframe(input_args, attrables, array_attrables): df_provided = args["data_frame"] is not None if df_provided and not isinstance(args["data_frame"], pd.DataFrame): args["data_frame"] = pd.DataFrame(args["data_frame"]) + df_input = args["data_frame"] # We start from an empty DataFrame - df = pd.DataFrame() + df_output = pd.DataFrame() # Initialize set of column names # These are reserved names @@ -839,8 +840,7 @@ def build_dataframe(input_args, attrables, array_attrables): "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." ) else: - df_args = args["data_frame"] - df[df_args.columns] = df_args[df_args.columns] + df_output[df_input.columns] = df_input[df_input.columns] # Loop over possible arguments for field_name in attrables: @@ -863,7 +863,7 @@ def build_dataframe(input_args, attrables, array_attrables): # argument_list and field_list ready, iterate over them # Core of the loop starts here for i, (argument, field) in enumerate(zip(argument_list, field_list)): - length = len(df) + length = len(df_output) if argument is None: continue # Case of multiindex @@ -885,33 +885,33 @@ def build_dataframe(input_args, attrables, array_attrables): "is of type str or int." % field ) # Check validity of column name - if argument not in args["data_frame"].columns: + if argument not in df_input.columns: err_msg = ( "Value of '%s' is not the name of a column in 'data_frame'. " "Expected one of %s but received: %s" - % (field, str(list(args["data_frame"].columns)), argument) + % (field, str(list(df_input.columns)), argument) ) if argument == "index": err_msg += ( "\n To use the index, pass it in directly as `df.index`." ) raise ValueError(err_msg) - if length and len(args["data_frame"][argument]) != length: + if length and len(df_input[argument]) != length: raise ValueError( "All arguments should have the same length. " "The length of column argument `df[%s]` is %d, whereas the " "length of previous arguments %s is %d" % ( field, - len(args["data_frame"][argument]), - str(list(df.columns)), + len(df_input[argument]), + str(list(df_output.columns)), length, ) ) col_name = argument if isinstance(argument, int): col_name = _name_heuristic(argument, field, reserved_names) - df[col_name] = args["data_frame"][argument] + df_output[col_name] = df_input[argument] # ----------------- argument is a column / array / list.... ------- else: is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) @@ -927,13 +927,12 @@ def build_dataframe(input_args, attrables, array_attrables): col_name = field else: if is_index: - keep_name = ( - df_provided and argument is args["data_frame"].index - ) + keep_name = df_provided and argument is df_input.index else: + # we use getattr/hasattr because of index keep_name = hasattr( - args["data_frame"], col_name - ) and argument is getattr(args["data_frame"], col_name) + df_input, col_name + ) and argument is getattr(df_input, col_name) col_name = ( col_name if keep_name @@ -946,9 +945,9 @@ def build_dataframe(input_args, attrables, array_attrables): "All arguments should have the same length." "The length of argument `%s` is %d, whereas the" "length of previous arguments %s is %d" - % (field, len(argument), str(list(df.columns)), length) + % (field, len(argument), str(list(df_output.columns)), length) ) - df[str(col_name)] = argument + df_output[str(col_name)] = argument # Finally, update argument with column name now that column exists if field_name not in array_attrables: @@ -956,7 +955,7 @@ def build_dataframe(input_args, attrables, array_attrables): else: args[field_name][i] = str(col_name) - args["data_frame"] = df + args["data_frame"] = df_output return args From c88660f03d8a11a5db1978567f1a97641687781a Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 25 Sep 2019 09:59:32 -0400 Subject: [PATCH 63/69] qa --- packages/python/plotly/plotly/express/_core.py | 3 +-- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 128fc51cada..b99af76c454 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -795,7 +795,7 @@ def _get_reserved_col_names(args, attrables, array_attrables): return reserved_names -def build_dataframe(input_args, attrables, array_attrables): +def build_dataframe(args, attrables, array_attrables): """ Constructs a dataframe and modifies `args` in-place. @@ -813,7 +813,6 @@ def build_dataframe(input_args, attrables, array_attrables): array_attrables : list argument names corresponding to iterables, such as `hover_data`, ... """ - args = input_args.copy() for field in args: if field in array_attrables and args[field] is not None: args[field] = list(args[field]) diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 1567d7d8bad..691b7c1eb90 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -280,3 +280,5 @@ def test_pass_df_columns(): hover_data=tips.columns, ) assert fig.data[1].hovertemplate.count("customdata") == len(tips.columns) + tips_copy = px.data.tips() + assert tips_copy.columns.equals(tips.columns) From b5adbcf140fc9c381eefbfe5aa439bf8e6947495 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Wed, 25 Sep 2019 10:31:21 -0400 Subject: [PATCH 64/69] if args[field] is a dict it should stay a dict --- packages/python/plotly/plotly/express/_core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index b99af76c454..d2a64909e4a 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -815,7 +815,11 @@ def build_dataframe(args, attrables, array_attrables): """ for field in args: if field in array_attrables and args[field] is not None: - args[field] = list(args[field]) + args[field] = ( + dict(args[field]) + if isinstance(args[field], dict) + else list(args[field]) + ) # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) df_provided = args["data_frame"] is not None if df_provided and not isinstance(args["data_frame"], pd.DataFrame): From b375f91f8da712c3050b1b095bbf3800c26cec0c Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Thu, 26 Sep 2019 10:12:27 -0400 Subject: [PATCH 65/69] addressed Jon's comments --- .../python/plotly/plotly/express/_core.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index d2a64909e4a..e1dae2b79dc 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -754,9 +754,7 @@ def apply_default_cascade(args): args["marginal_x"] = None -def _name_heuristic(argument, field_name, reserved_names): - if isinstance(argument, int): - argument = str(argument) +def _check_name_not_reserved(field_name, reserved_names): if field_name not in reserved_names: return field_name else: @@ -783,9 +781,9 @@ def _get_reserved_col_names(args, attrables, array_attrables): for arg in names: if arg is None: continue - if isinstance(arg, str): + elif isinstance(arg, str): # no need to add ints since kw arg are not ints reserved_names.add(arg) - if isinstance(arg, pd.core.series.Series): + elif isinstance(arg, pd.Series): arg_name = arg.name if arg_name and hasattr(df, arg_name): in_df = arg is df[arg_name] @@ -870,7 +868,7 @@ def build_dataframe(args, attrables, array_attrables): if argument is None: continue # Case of multiindex - if isinstance(argument, pd.core.indexes.multi.MultiIndex): + if isinstance(argument, pd.MultiIndex): raise TypeError( "Argument '%s' is a pandas MultiIndex." "pandas MultiIndex is not supported by plotly express " @@ -911,13 +909,11 @@ def build_dataframe(args, attrables, array_attrables): length, ) ) - col_name = argument - if isinstance(argument, int): - col_name = _name_heuristic(argument, field, reserved_names) + col_name = str(argument) df_output[col_name] = df_input[argument] # ----------------- argument is a column / array / list.... ------- else: - is_index = isinstance(argument, pd.core.indexes.range.RangeIndex) + is_index = isinstance(argument, pd.RangeIndex) # First pandas # pandas series have a name but it's None if ( @@ -939,10 +935,10 @@ def build_dataframe(args, attrables, array_attrables): col_name = ( col_name if keep_name - else _name_heuristic(col_name, field, reserved_names) + else _check_name_not_reserved(field, reserved_names) ) else: # numpy array, list... - col_name = _name_heuristic(field, field, reserved_names) + col_name = _check_name_not_reserved(field, reserved_names) if length and len(argument) != length: raise ValueError( "All arguments should have the same length." From c59a4d5c962f16371606afa62cab2db71b79dd78 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 28 Sep 2019 17:21:17 -0400 Subject: [PATCH 66/69] spaces in error messages --- packages/python/plotly/plotly/express/_core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index e1dae2b79dc..337ebfaaff7 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -870,7 +870,7 @@ def build_dataframe(args, attrables, array_attrables): # Case of multiindex if isinstance(argument, pd.MultiIndex): raise TypeError( - "Argument '%s' is a pandas MultiIndex." + "Argument '%s' is a pandas MultiIndex. " "pandas MultiIndex is not supported by plotly express " "at the moment." % field ) @@ -880,10 +880,10 @@ def build_dataframe(args, attrables, array_attrables): ): # just a column name given as str or int if not df_provided: raise ValueError( - "String or int arguments are only possible when a" - "DataFrame or an array is provided in the `data_frame`" - "argument. No DataFrame was provided, but argument '%s'" - "is of type str or int." % field + "String or int arguments are only possible when a " + "DataFrame or an array is provided in the `data_frame` " + "argument. No DataFrame was provided, but argument " + "'%s' is of type str or int." % field ) # Check validity of column name if argument not in df_input.columns: @@ -941,8 +941,8 @@ def build_dataframe(args, attrables, array_attrables): col_name = _check_name_not_reserved(field, reserved_names) if length and len(argument) != length: raise ValueError( - "All arguments should have the same length." - "The length of argument `%s` is %d, whereas the" + "All arguments should have the same length. " + "The length of argument `%s` is %d, whereas the " "length of previous arguments %s is %d" % (field, len(argument), str(list(df_output.columns)), length) ) From 1774ade53fad7f4dbb1f56cc369cf989afabd3f3 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Sat, 28 Sep 2019 17:35:05 -0400 Subject: [PATCH 67/69] case of size column --- packages/python/plotly/plotly/express/_core.py | 6 +++--- .../plotly/plotly/tests/test_core/test_px/test_px_input.py | 6 ++++++ 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 337ebfaaff7..6bab5117012 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -929,9 +929,9 @@ def build_dataframe(args, attrables, array_attrables): keep_name = df_provided and argument is df_input.index else: # we use getattr/hasattr because of index - keep_name = hasattr( - df_input, col_name - ) and argument is getattr(df_input, col_name) + keep_name = col_name in df_input and argument.equals( + df_input[col_name] + ) col_name = ( col_name if keep_name diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py index 691b7c1eb90..08bb1a9cc95 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -282,3 +282,9 @@ def test_pass_df_columns(): assert fig.data[1].hovertemplate.count("customdata") == len(tips.columns) tips_copy = px.data.tips() assert tips_copy.columns.equals(tips.columns) + + +def test_size_column(): + df = px.data.tips() + fig = px.scatter(df, x=df["size"], y=df.tip) + assert fig.data[0].hovertemplate == "size=%{x}
tip=%{y}" From 16d95af996cd59d79f515b0c67eebb6a3372b904 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 30 Sep 2019 12:44:28 -0400 Subject: [PATCH 68/69] qa --- packages/python/plotly/plotly/express/_core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 6bab5117012..1aa009eab59 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -928,7 +928,6 @@ def build_dataframe(args, attrables, array_attrables): if is_index: keep_name = df_provided and argument is df_input.index else: - # we use getattr/hasattr because of index keep_name = col_name in df_input and argument.equals( df_input[col_name] ) From 5c95786432cb0a7dbf06d1748fd4481ab83c4d46 Mon Sep 17 00:00:00 2001 From: Emmanuelle Gouillart Date: Mon, 30 Sep 2019 13:04:03 -0400 Subject: [PATCH 69/69] revert to is --- packages/python/plotly/plotly/express/_core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 1aa009eab59..e5846be0668 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -928,8 +928,8 @@ def build_dataframe(args, attrables, array_attrables): if is_index: keep_name = df_provided and argument is df_input.index else: - keep_name = col_name in df_input and argument.equals( - df_input[col_name] + keep_name = ( + col_name in df_input and argument is df_input[col_name] ) col_name = ( col_name