diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 7a08d2e6ac4..c17c5a657e5 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -4,7 +4,7 @@ def scatter( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -57,7 +57,7 @@ def scatter( def density_contour( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -113,7 +113,7 @@ def density_contour( def density_heatmap( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -167,7 +167,7 @@ def density_heatmap( def line( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -212,7 +212,7 @@ def line( def area( - data_frame, + data_frame=None, x=None, y=None, line_group=None, @@ -257,7 +257,7 @@ def area( def bar( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -307,7 +307,7 @@ def bar( def histogram( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -364,7 +364,7 @@ def histogram( def violin( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -415,7 +415,7 @@ def violin( def box( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -461,7 +461,7 @@ def box( def strip( - data_frame, + data_frame=None, x=None, y=None, color=None, @@ -511,7 +511,7 @@ def strip( def scatter_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -562,7 +562,7 @@ def scatter_3d( def line_3d( - data_frame, + data_frame=None, x=None, y=None, z=None, @@ -608,7 +608,7 @@ def line_3d( def scatter_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -647,7 +647,7 @@ def scatter_ternary( def line_ternary( - data_frame, + data_frame=None, a=None, b=None, c=None, @@ -682,7 +682,7 @@ def line_ternary( def scatter_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -726,7 +726,7 @@ def scatter_polar( def line_polar( - data_frame, + data_frame=None, r=None, 
theta=None, color=None, @@ -766,7 +766,7 @@ def line_polar( def bar_polar( - data_frame, + data_frame=None, r=None, theta=None, color=None, @@ -804,7 +804,7 @@ def bar_polar( def choropleth( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -844,7 +844,7 @@ def choropleth( def scatter_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -888,7 +888,7 @@ def scatter_geo( def line_geo( - data_frame, + data_frame=None, lat=None, lon=None, locations=None, @@ -930,7 +930,7 @@ def line_geo( def scatter_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -966,7 +966,7 @@ def scatter_mapbox( def line_mapbox( - data_frame, + data_frame=None, lat=None, lon=None, color=None, @@ -997,7 +997,7 @@ def line_mapbox( def scatter_matrix( - data_frame, + data_frame=None, dimensions=None, color=None, symbol=None, @@ -1035,7 +1035,7 @@ def scatter_matrix( def parallel_coordinates( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, @@ -1059,7 +1059,7 @@ def parallel_coordinates( def parallel_categories( - data_frame, + data_frame=None, dimensions=None, color=None, labels={}, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 3d55d3004da..e5846be0668 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -5,7 +5,7 @@ from _plotly_utils.basevalidators import ColorscaleValidator from .colors import qualitative, sequential import math -import pandas +import pandas as pd import numpy as np from plotly.subplots import ( @@ -754,6 +754,209 @@ def apply_default_cascade(args): args["marginal_x"] = None +def _check_name_not_reserved(field_name, reserved_names): + if field_name not in reserved_names: + return field_name + else: + raise NameError( + "A name conflict was encountered for argument %s. " + "A column with name %s is already used." 
% (field_name, field_name) + ) + + +def _get_reserved_col_names(args, attrables, array_attrables): + """ + This function builds a list of columns of the data_frame argument used + as arguments, either as str/int arguments or given as columns + (pandas series type). + """ + df = args["data_frame"] + reserved_names = set() + for field in args: + if field not in attrables: + continue + names = args[field] if field in array_attrables else [args[field]] + if names is None: + continue + for arg in names: + if arg is None: + continue + elif isinstance(arg, str): # no need to add ints since kw arg are not ints + reserved_names.add(arg) + elif isinstance(arg, pd.Series): + arg_name = arg.name + if arg_name and hasattr(df, arg_name): + in_df = arg is df[arg_name] + if in_df: + reserved_names.add(arg_name) + + return reserved_names + + +def build_dataframe(args, attrables, array_attrables): + """ + Constructs a dataframe and modifies `args` in-place. + + The argument values in `args` can be either strings corresponding to + existing columns of a dataframe, or data arrays (lists, numpy arrays, + pandas columns, series). + + Parameters + ---------- + args : OrderedDict + arguments passed to the px function and subsequently modified + attrables : list + list of keys into `args`, all of whose corresponding values are + converted into columns of a dataframe. + array_attrables : list + argument names corresponding to iterables, such as `hover_data`, ... + """ + for field in args: + if field in array_attrables and args[field] is not None: + args[field] = ( + dict(args[field]) + if isinstance(args[field], dict) + else list(args[field]) + ) + # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) 
+ df_provided = args["data_frame"] is not None + if df_provided and not isinstance(args["data_frame"], pd.DataFrame): + args["data_frame"] = pd.DataFrame(args["data_frame"]) + df_input = args["data_frame"] + + # We start from an empty DataFrame + df_output = pd.DataFrame() + + # Initialize set of column names + # These are reserved names + if df_provided: + reserved_names = _get_reserved_col_names(args, attrables, array_attrables) + else: + reserved_names = set() + + # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords + if "dimensions" in args and args["dimensions"] is None: + if not df_provided: + raise ValueError( + "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." + ) + else: + df_output[df_input.columns] = df_input[df_input.columns] + + # Loop over possible arguments + for field_name in attrables: + # Massaging variables + argument_list = ( + [args.get(field_name)] + if field_name not in array_attrables + else args.get(field_name) + ) + # argument not specified, continue + if argument_list is None or argument_list is [None]: + continue + # Argument name: field_name if the argument is not a list + # Else we give names like ["hover_data_0, hover_data_1"] etc. + field_list = ( + [field_name] + if field_name not in array_attrables + else [field_name + "_" + str(i) for i in range(len(argument_list))] + ) + # argument_list and field_list ready, iterate over them + # Core of the loop starts here + for i, (argument, field) in enumerate(zip(argument_list, field_list)): + length = len(df_output) + if argument is None: + continue + # Case of multiindex + if isinstance(argument, pd.MultiIndex): + raise TypeError( + "Argument '%s' is a pandas MultiIndex. " + "pandas MultiIndex is not supported by plotly express " + "at the moment." 
% field + ) + ## ----------------- argument is a col name ---------------------- + if isinstance(argument, str) or isinstance( + argument, int + ): # just a column name given as str or int + if not df_provided: + raise ValueError( + "String or int arguments are only possible when a " + "DataFrame or an array is provided in the `data_frame` " + "argument. No DataFrame was provided, but argument " + "'%s' is of type str or int." % field + ) + # Check validity of column name + if argument not in df_input.columns: + err_msg = ( + "Value of '%s' is not the name of a column in 'data_frame'. " + "Expected one of %s but received: %s" + % (field, str(list(df_input.columns)), argument) + ) + if argument == "index": + err_msg += ( + "\n To use the index, pass it in directly as `df.index`." + ) + raise ValueError(err_msg) + if length and len(df_input[argument]) != length: + raise ValueError( + "All arguments should have the same length. " + "The length of column argument `df[%s]` is %d, whereas the " + "length of previous arguments %s is %d" + % ( + field, + len(df_input[argument]), + str(list(df_output.columns)), + length, + ) + ) + col_name = str(argument) + df_output[col_name] = df_input[argument] + # ----------------- argument is a column / array / list.... ------- + else: + is_index = isinstance(argument, pd.RangeIndex) + # First pandas + # pandas series have a name but it's None + if ( + hasattr(argument, "name") and argument.name is not None + ) or is_index: + col_name = argument.name # pandas df + if col_name is None and is_index: + col_name = "index" + if not df_provided: + col_name = field + else: + if is_index: + keep_name = df_provided and argument is df_input.index + else: + keep_name = ( + col_name in df_input and argument is df_input[col_name] + ) + col_name = ( + col_name + if keep_name + else _check_name_not_reserved(field, reserved_names) + ) + else: # numpy array, list... 
+ col_name = _check_name_not_reserved(field, reserved_names) + if length and len(argument) != length: + raise ValueError( + "All arguments should have the same length. " + "The length of argument `%s` is %d, whereas the " + "length of previous arguments %s is %d" + % (field, len(argument), str(list(df_output.columns)), length) + ) + df_output[str(col_name)] = argument + + # Finally, update argument with column name now that column exists + if field_name not in array_attrables: + args[field_name] = str(col_name) + else: + args[field_name][i] = str(col_name) + + args["data_frame"] = df_output + return args + + def infer_config(args, constructor, trace_patch): # Declare all supported attributes, across all plot types attrables = ( @@ -765,28 +968,13 @@ def infer_config(args, constructor, trace_patch): ) array_attrables = ["dimensions", "custom_data", "hover_data"] group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + all_attrables = attrables + group_attrables + ["color"] + group_attrs = ["symbol", "line_dash"] + for group_attr in group_attrs: + if group_attr in args: + all_attrables += [group_attr] - # Validate that the strings provided as attribute values reference columns - # in the provided data_frame - df_columns = args["data_frame"].columns - - for attr in attrables + group_attrables + ["color"]: - if attr in args and args[attr] is not None: - maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr] - for maybe_col in maybe_col_list: - try: - in_cols = maybe_col in df_columns - except TypeError: - in_cols = False - if not in_cols: - value_str = ( - "Element of value" if attr in array_attrables else "Value" - ) - raise ValueError( - "%s of '%s' is not the name of a column in 'data_frame'. 
" - "Expected one of %s but received: %s" - % (value_str, attr, str(list(df_columns)), str(maybe_col)) - ) + args = build_dataframe(args, all_attrables, array_attrables) attrs = [k for k in attrables if k in args] grouped_attrs = [] @@ -864,7 +1052,7 @@ def infer_config(args, constructor, trace_patch): # Create trace specs trace_specs = make_trace_spec(args, constructor, attrs, trace_patch) - return trace_specs, grouped_mappings, sizeref, show_colorbar + return args, trace_specs, grouped_mappings, sizeref, show_colorbar def get_orderings(args, grouper, grouped): @@ -902,7 +1090,7 @@ def get_orderings(args, grouper, grouped): def make_figure(args, constructor, trace_patch={}, layout_patch={}): apply_default_cascade(args) - trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( + args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config( args, constructor, trace_patch ) grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group] @@ -1095,7 +1283,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): fig.layout.update(layout_patch) fig.frames = frame_list if len(frames) > 1 else [] - fig._px_trendlines = pandas.DataFrame(trendline_rows) + fig._px_trendlines = pd.DataFrame(trendline_rows) configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index fbefe4e3860..7df598dccb3 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -1,7 +1,7 @@ import inspect -colref = "(string: name of column in `data_frame`)" -colref_list = "(list of string: names of columns in `data_frame`)" +colref = "(string or int: name of column in `data_frame`, or pandas Series, or array_like object)" +colref_list = "(list of string or int: names of columns in `data_frame`, or pandas Series, or array_like objects)" # TODO 
contents of columns # TODO explain categorical @@ -12,53 +12,57 @@ # TODO standardize positioning and casing of 'default' docs = dict( - data_frame=["A 'tidy' `pandas.DataFrame`"], + data_frame=[ + "A `pandas.DataFrame`, or a `NumPy` array or a dictionary", + "which are transformed internally to `pandas.DataFrame`. This argument needs " + "to be passed for column names (and not keyword names) to be used.", + ], x=[ colref, - "Values from this column are used to position marks along the x axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.", "For horizontal `histogram`s, these values are used as inputs to `histfunc`.", ], y=[ colref, - "Values from this column are used to position marks along the y axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.", "For vertical `histogram`s, these values are used as inputs to `histfunc`.", ], z=[ colref, - "Values from this column are used to position marks along the z axis in cartesian coordinates.", + "Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.", "For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.", ], a=[ colref, - "Values from this column are used to position marks along the a axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the a axis in ternary coordinates.", ], b=[ colref, - "Values from this column are used to position marks along the b axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the b axis in ternary coordinates.", ], c=[ colref, - "Values from this column are used to position marks along the c axis in ternary coordinates.", + "Values from this column or array_like are used to position marks along the c axis in ternary 
coordinates.", ], r=[ colref, - "Values from this column are used to position marks along the radial axis in polar coordinates.", + "Values from this column or array_like are used to position marks along the radial axis in polar coordinates.", ], theta=[ colref, - "Values from this column are used to position marks along the angular axis in polar coordinates.", + "Values from this column or array_like are used to position marks along the angular axis in polar coordinates.", ], lat=[ colref, - "Values from this column are used to position marks according to latitude on a map.", + "Values from this column or array_like are used to position marks according to latitude on a map.", ], lon=[ colref, - "Values from this column are used to position marks according to longitude on a map.", + "Values from this column or array_like are used to position marks according to longitude on a map.", ], locations=[ colref, - "Values from this column are be interpreted according to `locationmode` and mapped to longitude/latitude.", + "Values from this column or array_like are interpreted according to `locationmode` and mapped to longitude/latitude.", ], dimensions=[ "(list of strings, names of columns in `data_frame`)", @@ -66,47 +70,59 @@ ], error_x=[ colref, - "Values from this column are used to size x-axis error bars.", + "Values from this column or array_like are used to size x-axis error bars.", "If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.", ], error_x_minus=[ colref, - "Values from this column are used to size x-axis error bars in the negative direction.", + "Values from this column or array_like are used to size x-axis error bars in the negative direction.", "Ignored if `error_x` is `None`.", ], error_y=[ colref, - "Values from this column are used to size y-axis error bars.", + "Values from this column or array_like are used to size y-axis error bars.", "If `error_y_minus` is `None`, error bars 
will be symmetrical, otherwise `error_y` is used for the positive direction only.", ], error_y_minus=[ colref, - "Values from this column are used to size y-axis error bars in the negative direction.", + "Values from this column or array_like are used to size y-axis error bars in the negative direction.", "Ignored if `error_y` is `None`.", ], error_z=[ colref, - "Values from this column are used to size z-axis error bars.", + "Values from this column or array_like are used to size z-axis error bars.", "If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.", ], error_z_minus=[ colref, - "Values from this column are used to size z-axis error bars in the negative direction.", + "Values from this column or array_like are used to size z-axis error bars in the negative direction.", "Ignored if `error_z` is `None`.", ], - color=[colref, "Values from this column are used to assign color to marks."], + color=[ + colref, + "Values from this column or array_like are used to assign color to marks.", + ], opacity=["(number, between 0 and 1) Sets the opacity for markers."], line_dash=[ colref, - "Values from this column are used to assign dash-patterns to lines.", + "Values from this column or array_like are used to assign dash-patterns to lines.", ], line_group=[ colref, - "Values from this column are used to group rows of `data_frame` into lines.", + "Values from this column or array_like are used to group rows of `data_frame` into lines.", + ], + symbol=[ + colref, + "Values from this column or array_like are used to assign symbols to marks.", + ], + size=[ + colref, + "Values from this column or array_like are used to assign mark sizes.", + ], + hover_name=[ + colref, + "Values from this column or array_like appear in bold in the hover tooltip.", ], - symbol=[colref, "Values from this column are used to assign symbols to marks."], - size=[colref, "Values from this column are used to assign mark sizes."], - 
hover_name=[colref, "Values from this column appear in bold in the hover tooltip."], hover_data=[ colref_list, "Values from these columns appear as extra data in the hover tooltip.", @@ -115,26 +131,29 @@ colref_list, "Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)", ], - text=[colref, "Values from this column appear in the figure as text labels."], + text=[ + colref, + "Values from this column or array_like appear in the figure as text labels.", + ], locationmode=[ "(string, one of 'ISO-3', 'USA-states', 'country names')", "Determines the set of locations used to match entries in `locations` to regions on the map.", ], facet_row=[ colref, - "Values from this column are used to assign marks to facetted subplots in the vertical direction.", + "Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.", ], facet_col=[ colref, - "Values from this column are used to assign marks to facetted subplots in the horizontal direction.", + "Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.", ], animation_frame=[ colref, - "Values from this column are used to assign marks to animation frames.", + "Values from this column or array_like are used to assign marks to animation frames.", ], animation_group=[ colref, - "Values from this column are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", + "Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.", ], symbol_sequence=[ "(list of strings defining plotly.js symbols)", diff --git 
a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py new file mode 100644 index 00000000000..08bb1a9cc95 --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py @@ -0,0 +1,290 @@ +import plotly.express as px +import numpy as np +import pandas as pd +import pytest +import plotly.graph_objects as go +import plotly +from plotly.express._core import build_dataframe +from pandas.util.testing import assert_frame_equal + +attrables = ( + ["x", "y", "z", "a", "b", "c", "r", "theta", "size", "dimensions"] + + ["custom_data", "hover_name", "hover_data", "text"] + + ["error_x", "error_x_minus"] + + ["error_y", "error_y_minus", "error_z", "error_z_minus"] + + ["lat", "lon", "locations", "animation_group"] +) +array_attrables = ["dimensions", "custom_data", "hover_data"] +group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"] + +all_attrables = attrables + group_attrables + ["color"] + + +def test_numpy(): + fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) + assert np.all(fig.data[0].x == np.array([1, 2, 3])) + assert np.all(fig.data[0].y == np.array([2, 3, 4])) + assert np.all(fig.data[0].marker.color == np.array([1, 3, 9])) + + +def test_numpy_labels(): + fig = px.scatter( + x=[1, 2, 3], y=[2, 3, 4], labels={"x": "time"} + ) # other labels will be kw arguments + assert fig.data[0]["hovertemplate"] == "time=%{x}
y=%{y}" + + +def test_with_index(): + tips = px.data.tips() + fig = px.scatter(tips, x=tips.index, y="total_bill") + assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + fig = px.scatter(tips, x=tips.index, y=tips.total_bill) + assert fig.data[0]["hovertemplate"] == "index=%{x}
total_bill=%{y}" + fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"}) + assert fig.data[0]["hovertemplate"] == "number=%{x}
total_bill=%{y}" + # We do not allow "x=index" + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(tips, x="index", y="total_bill") + assert "To use the index, pass it in directly as `df.index`." in str( + err_msg.value + ) + tips = px.data.tips() + tips.index.name = "item" + fig = px.scatter(tips, x=tips.index, y="total_bill") + assert fig.data[0]["hovertemplate"] == "item=%{x}
total_bill=%{y}" + + +def test_pandas_series(): + tips = px.data.tips() + before_tip = tips.total_bill - tips.tip + fig = px.bar(tips, x="day", y=before_tip) + assert fig.data[0].hovertemplate == "day=%{x}
y=%{y}" + fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) + assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" + + +def test_several_dataframes(): + df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8])) + df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200])) + fig = px.scatter(df, x="z", y=df2.money, size="x") + assert fig.data[0].hovertemplate == "z=%{x}
y=%{y}
x=%{marker.size}" + fig = px.scatter(df2, x=df.z, y=df2.money, size=df.z) + assert fig.data[0].hovertemplate == "x=%{x}
money=%{y}
size=%{marker.size}" + # Name conflict + with pytest.raises(NameError) as err_msg: + fig = px.scatter(df, x="z", y=df2.money, size="y") + assert "A name conflict was encountered for argument y" in str(err_msg.value) + with pytest.raises(NameError) as err_msg: + fig = px.scatter(df, x="z", y=df2.money, size=df.y) + assert "A name conflict was encountered for argument y" in str(err_msg.value) + + # No conflict when the dataframe is not given, fields are used + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + fig = px.scatter(x=df.y, y=df2.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}" + + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + df3 = pd.DataFrame(dict(y=[0.1, 0.2])) + fig = px.scatter(x=df.y, y=df2.y, size=df3.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}
size=%{marker.size}" + + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) + df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) + df3 = pd.DataFrame(dict(y=[0.1, 0.2])) + fig = px.scatter(x=df.y, y=df2.y, hover_data=[df3.y]) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([23, 24])) + assert ( + fig.data[0].hovertemplate == "x=%{x}
y=%{y}
hover_data_0=%{customdata[0]}" + ) + + +def test_name_heuristics(): + df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])) + fig = px.scatter(df, x=df.y, y=df.x, size=df.y) + assert np.all(fig.data[0].x == np.array([3, 4])) + assert np.all(fig.data[0].y == np.array([0, 1])) + assert fig.data[0].hovertemplate == "y=%{marker.size}
x=%{y}" + + +def test_repeated_name(): + iris = px.data.iris() + fig = px.scatter( + iris, + x="sepal_width", + y="sepal_length", + hover_data=["petal_length", "petal_width", "species_id"], + custom_data=["species_id", "species"], + ) + assert fig.data[0].customdata.shape[1] == 4 + + +def test_arrayattrable_numpy(): + tips = px.data.tips() + fig = px.scatter( + tips, x="total_bill", y="tip", hover_data=[np.random.random(tips.shape[0])] + ) + assert ( + fig.data[0]["hovertemplate"] + == "total_bill=%{x}
tip=%{y}
hover_data_0=%{customdata[0]}" + ) + tips = px.data.tips() + fig = px.scatter( + tips, + x="total_bill", + y="tip", + hover_data=[np.random.random(tips.shape[0])], + labels={"hover_data_0": "suppl"}, + ) + assert ( + fig.data[0]["hovertemplate"] + == "total_bill=%{x}
tip=%{y}
suppl=%{customdata[0]}" + ) + + +def test_wrong_column_name(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(px.data.tips(), x="bla", y="wrong") + assert "Value of 'x' is not the name of a column in 'data_frame'" in str( + err_msg.value + ) + + +def test_missing_data_frame(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(x="arg1", y="arg2") + assert "String or int arguments are only possible" in str(err_msg.value) + + +def test_wrong_dimensions_of_array(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5]) + assert "All arguments should have the same length." in str(err_msg.value) + + +def test_wrong_dimensions_mixed_case(): + with pytest.raises(ValueError) as err_msg: + df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) + fig = px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5]) + assert "All arguments should have the same length." in str(err_msg.value) + + +def test_wrong_dimensions(): + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(px.data.tips(), x="tip", y=[1, 2, 3]) + assert "All arguments should have the same length." in str(err_msg.value) + # the order matters + with pytest.raises(ValueError) as err_msg: + fig = px.scatter(px.data.tips(), x=[1, 2, 3], y="tip") + assert "All arguments should have the same length." in str(err_msg.value) + with pytest.raises(ValueError): + fig = px.scatter(px.data.tips(), x=px.data.iris().index, y="tip") + # assert "All arguments should have the same length." 
in str(err_msg.value) + + +def test_multiindex_raise_error(): + index = pd.MultiIndex.from_product( + [[1, 2, 3], ["a", "b"]], names=["first", "second"] + ) + df = pd.DataFrame(np.random.random((6, 3)), index=index, columns=["A", "B", "C"]) + # This is ok + fig = px.scatter(df, x="A", y="B") + with pytest.raises(TypeError) as err_msg: + fig = px.scatter(df, x=df.index, y="B") + assert "pandas MultiIndex is not supported by plotly express" in str( + err_msg.value + ) + + +def test_build_df_from_lists(): + # Just lists + args = dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]) + output = {key: key for key in args} + df = pd.DataFrame(args) + args["data_frame"] = None + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) + out.pop("data_frame") + assert out == output + + # Arrays + args = dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9]) + output = {key: key for key in args} + df = pd.DataFrame(args) + args["data_frame"] = None + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) + out.pop("data_frame") + assert out == output + + +def test_build_df_with_index(): + tips = px.data.tips() + args = dict(data_frame=tips, x=tips.index, y="total_bill") + out = build_dataframe(args, all_attrables, array_attrables) + assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) + + +def test_splom_case(): + iris = px.data.iris() + fig = px.scatter_matrix(iris) + assert len(fig.data[0].dimensions) == len(iris.columns) + dic = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]} + fig = px.scatter_matrix(dic) + assert np.all(fig.data[0].dimensions[0].values == np.array(dic["a"])) + ar = np.arange(9).reshape((3, 3)) + fig = px.scatter_matrix(ar) + assert np.all(fig.data[0].dimensions[0].values == ar[:, 0]) + + +def test_int_col_names(): + # DataFrame with int column 
names + lengths = pd.DataFrame(np.random.random(100)) + fig = px.histogram(lengths, x=0) + assert np.all(np.array(lengths).flatten() == fig.data[0].x) + # Numpy array + ar = np.arange(100).reshape((10, 10)) + fig = px.scatter(ar, x=2, y=8) + assert np.all(fig.data[0].x == ar[:, 2]) + + +def test_data_frame_from_dict(): + fig = px.scatter({"time": [0, 1], "money": [1, 2]}, x="time", y="money") + assert fig.data[0].hovertemplate == "time=%{x}
money=%{y}" + assert np.all(fig.data[0].x == [0, 1]) + + +def test_arguments_not_modified(): + iris = px.data.iris() + petal_length = iris.petal_length + hover_data = [iris.sepal_length] + fig = px.scatter(iris, x=petal_length, y="petal_width", hover_data=hover_data) + assert iris.petal_length.equals(petal_length) + assert iris.sepal_length.equals(hover_data[0]) + + +def test_pass_df_columns(): + tips = px.data.tips() + fig = px.histogram( + tips, + x="total_bill", + y="tip", + color="sex", + marginal="rug", + hover_data=tips.columns, + ) + assert fig.data[1].hovertemplate.count("customdata") == len(tips.columns) + tips_copy = px.data.tips() + assert tips_copy.columns.equals(tips.columns) + + +def test_size_column(): + df = px.data.tips() + fig = px.scatter(df, x=df["size"], y=df.tip) + assert fig.data[0].hovertemplate == "size=%{x}
tip=%{y}"