diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py
index 7a08d2e6ac4..c17c5a657e5 100644
--- a/packages/python/plotly/plotly/express/_chart_types.py
+++ b/packages/python/plotly/plotly/express/_chart_types.py
@@ -4,7 +4,7 @@
def scatter(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -57,7 +57,7 @@ def scatter(
def density_contour(
- data_frame,
+ data_frame=None,
x=None,
y=None,
z=None,
@@ -113,7 +113,7 @@ def density_contour(
def density_heatmap(
- data_frame,
+ data_frame=None,
x=None,
y=None,
z=None,
@@ -167,7 +167,7 @@ def density_heatmap(
def line(
- data_frame,
+ data_frame=None,
x=None,
y=None,
line_group=None,
@@ -212,7 +212,7 @@ def line(
def area(
- data_frame,
+ data_frame=None,
x=None,
y=None,
line_group=None,
@@ -257,7 +257,7 @@ def area(
def bar(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -307,7 +307,7 @@ def bar(
def histogram(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -364,7 +364,7 @@ def histogram(
def violin(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -415,7 +415,7 @@ def violin(
def box(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -461,7 +461,7 @@ def box(
def strip(
- data_frame,
+ data_frame=None,
x=None,
y=None,
color=None,
@@ -511,7 +511,7 @@ def strip(
def scatter_3d(
- data_frame,
+ data_frame=None,
x=None,
y=None,
z=None,
@@ -562,7 +562,7 @@ def scatter_3d(
def line_3d(
- data_frame,
+ data_frame=None,
x=None,
y=None,
z=None,
@@ -608,7 +608,7 @@ def line_3d(
def scatter_ternary(
- data_frame,
+ data_frame=None,
a=None,
b=None,
c=None,
@@ -647,7 +647,7 @@ def scatter_ternary(
def line_ternary(
- data_frame,
+ data_frame=None,
a=None,
b=None,
c=None,
@@ -682,7 +682,7 @@ def line_ternary(
def scatter_polar(
- data_frame,
+ data_frame=None,
r=None,
theta=None,
color=None,
@@ -726,7 +726,7 @@ def scatter_polar(
def line_polar(
- data_frame,
+ data_frame=None,
r=None,
theta=None,
color=None,
@@ -766,7 +766,7 @@ def line_polar(
def bar_polar(
- data_frame,
+ data_frame=None,
r=None,
theta=None,
color=None,
@@ -804,7 +804,7 @@ def bar_polar(
def choropleth(
- data_frame,
+ data_frame=None,
lat=None,
lon=None,
locations=None,
@@ -844,7 +844,7 @@ def choropleth(
def scatter_geo(
- data_frame,
+ data_frame=None,
lat=None,
lon=None,
locations=None,
@@ -888,7 +888,7 @@ def scatter_geo(
def line_geo(
- data_frame,
+ data_frame=None,
lat=None,
lon=None,
locations=None,
@@ -930,7 +930,7 @@ def line_geo(
def scatter_mapbox(
- data_frame,
+ data_frame=None,
lat=None,
lon=None,
color=None,
@@ -966,7 +966,7 @@ def scatter_mapbox(
def line_mapbox(
- data_frame,
+ data_frame=None,
lat=None,
lon=None,
color=None,
@@ -997,7 +997,7 @@ def line_mapbox(
def scatter_matrix(
- data_frame,
+ data_frame=None,
dimensions=None,
color=None,
symbol=None,
@@ -1035,7 +1035,7 @@ def scatter_matrix(
def parallel_coordinates(
- data_frame,
+ data_frame=None,
dimensions=None,
color=None,
labels={},
@@ -1059,7 +1059,7 @@ def parallel_coordinates(
def parallel_categories(
- data_frame,
+ data_frame=None,
dimensions=None,
color=None,
labels={},
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 3d55d3004da..e5846be0668 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -5,7 +5,7 @@
from _plotly_utils.basevalidators import ColorscaleValidator
from .colors import qualitative, sequential
import math
-import pandas
+import pandas as pd
import numpy as np
from plotly.subplots import (
@@ -754,6 +754,209 @@ def apply_default_cascade(args):
args["marginal_x"] = None
+def _check_name_not_reserved(field_name, reserved_names):
+    """
+    Return `field_name` unchanged unless it collides with a reserved name.
+
+    Parameters
+    ----------
+    field_name : str
+        name that is about to be used for a new column
+    reserved_names : container
+        names that are already taken (tested with `in`)
+
+    Raises
+    ------
+    NameError
+        if `field_name` is found in `reserved_names`
+    """
+    if field_name in reserved_names:
+        raise NameError(
+            "A name conflict was encountered for argument %s. "
+            "A column with name %s is already used." % (field_name, field_name)
+        )
+    return field_name
+
+
+def _get_reserved_col_names(args, attrables, array_attrables):
+    """
+    This function builds the set of columns of the data_frame argument used
+    as arguments, either as str arguments or given as columns
+    (pandas series type).
+
+    Parameters
+    ----------
+    args : dict
+        arguments passed to the px function; `args["data_frame"]` must be a
+        pandas DataFrame
+    attrables : list
+        keys of `args` to inspect; every other key is ignored
+    array_attrables : list
+        keys of `args` whose value is an iterable of arguments rather than a
+        single argument
+
+    Returns
+    -------
+    set
+        every str argument, plus the name of every pandas Series argument
+        that is the very object returned by `data_frame[name]`
+    """
+    df = args["data_frame"]
+    reserved_names = set()
+    for field in args:
+        if field not in attrables:
+            continue
+        names = args[field] if field in array_attrables else [args[field]]
+        if names is None:
+            continue
+        for arg in names:
+            if arg is None:
+                continue
+            elif isinstance(arg, str):  # no need to add ints since kw arg are not ints
+                reserved_names.add(arg)
+            elif isinstance(arg, pd.Series):
+                arg_name = arg.name
+                # Test column membership rather than hasattr(): hasattr() raises
+                # TypeError for non-str names and is also true for DataFrame
+                # attributes such as "shape", for which df[arg_name] would then
+                # raise KeyError when no such column exists.
+                if arg_name is not None and arg_name in df and arg is df[arg_name]:
+                    reserved_names.add(arg_name)
+
+    return reserved_names
+
+
+def build_dataframe(args, attrables, array_attrables):
+    """
+    Constructs a dataframe and modifies `args` in-place.
+
+    The argument values in `args` can be either strings corresponding to
+    existing columns of a dataframe, or data arrays (lists, numpy arrays,
+    pandas columns, series).
+
+    Parameters
+    ----------
+    args : OrderedDict
+        arguments passed to the px function and subsequently modified
+    attrables : list
+        list of keys into `args`, all of whose corresponding values are
+        converted into columns of a dataframe.
+    array_attrables : list
+        argument names corresponding to iterables, such as `hover_data`, ...
+
+    Returns
+    -------
+    The same `args` mapping: `args["data_frame"]` is the newly built
+    DataFrame and every processed argument has been replaced by the name
+    (as a str) of its column in that DataFrame.
+
+    Raises
+    ------
+    ValueError
+        if a function with a `dimensions` argument receives neither
+        `data_frame` nor `dimensions`, if a str/int argument is given without
+        `data_frame` or does not name one of its columns, or if the arguments
+        do not all have the same length.
+    TypeError
+        if an argument is a pandas MultiIndex.
+    NameError
+        if the keyword name needed for a new column is already reserved
+        (see `_check_name_not_reserved`).
+    """
+    # Copy array-like arguments: their entries are replaced by column names
+    # below, and the objects passed by the caller must not be altered
+    for field in args:
+        if field in array_attrables and args[field] is not None:
+            args[field] = (
+                dict(args[field])
+                if isinstance(args[field], dict)
+                else list(args[field])
+            )
+    # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.)
+    df_provided = args["data_frame"] is not None
+    if df_provided and not isinstance(args["data_frame"], pd.DataFrame):
+        args["data_frame"] = pd.DataFrame(args["data_frame"])
+    df_input = args["data_frame"]
+
+    # We start from an empty DataFrame
+    df_output = pd.DataFrame()
+
+    # Initialize set of column names
+    # These are reserved names
+    if df_provided:
+        reserved_names = _get_reserved_col_names(args, attrables, array_attrables)
+    else:
+        reserved_names = set()
+
+    # Case of functions with a "dimensions" kw: scatter_matrix, parcats, parcoords
+    if "dimensions" in args and args["dimensions"] is None:
+        if not df_provided:
+            raise ValueError(
+                "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument."
+            )
+        else:
+            df_output[df_input.columns] = df_input[df_input.columns]
+
+    # Loop over possible arguments
+    for field_name in attrables:
+        # Massaging variables
+        argument_list = (
+            [args.get(field_name)]
+            if field_name not in array_attrables
+            else args.get(field_name)
+        )
+        # argument not specified, continue
+        # (a single argument left to None gives [None]; that entry is skipped
+        # in the inner loop, so no identity test against a list is needed)
+        if argument_list is None:
+            continue
+        # Argument name: field_name if the argument is not a list
+        # Else we give names like ["hover_data_0, hover_data_1"] etc.
+        field_list = (
+            [field_name]
+            if field_name not in array_attrables
+            else [field_name + "_" + str(i) for i in range(len(argument_list))]
+        )
+        # argument_list and field_list ready, iterate over them
+        # Core of the loop starts here
+        for i, (argument, field) in enumerate(zip(argument_list, field_list)):
+            length = len(df_output)
+            if argument is None:
+                continue
+            # Case of multiindex
+            if isinstance(argument, pd.MultiIndex):
+                raise TypeError(
+                    "Argument '%s' is a pandas MultiIndex. "
+                    "pandas MultiIndex is not supported by plotly express "
+                    "at the moment." % field
+                )
+            # ----------------- argument is a col name ----------------------
+            if isinstance(argument, (str, int)):  # just a column name given as str or int
+                if not df_provided:
+                    raise ValueError(
+                        "String or int arguments are only possible when a "
+                        "DataFrame or an array is provided in the `data_frame` "
+                        "argument. No DataFrame was provided, but argument "
+                        "'%s' is of type str or int." % field
+                    )
+                # Check validity of column name
+                if argument not in df_input.columns:
+                    err_msg = (
+                        "Value of '%s' is not the name of a column in 'data_frame'. "
+                        "Expected one of %s but received: %s"
+                        % (field, str(list(df_input.columns)), argument)
+                    )
+                    if argument == "index":
+                        err_msg += (
+                            "\n To use the index, pass it in directly as `df.index`."
+                        )
+                    raise ValueError(err_msg)
+                if length and len(df_input[argument]) != length:
+                    raise ValueError(
+                        "All arguments should have the same length. "
+                        "The length of column argument `df[%s]` is %d, whereas the "
+                        "length of previous arguments %s is %d"
+                        % (
+                            field,
+                            len(df_input[argument]),
+                            str(list(df_output.columns)),
+                            length,
+                        )
+                    )
+                col_name = str(argument)
+                df_output[col_name] = df_input[argument]
+            # ----------------- argument is a column / array / list.... -------
+            else:
+                is_index = isinstance(argument, pd.RangeIndex)
+                # First pandas
+                # pandas series have a name but it's None
+                if (
+                    hasattr(argument, "name") and argument.name is not None
+                ) or is_index:
+                    col_name = argument.name  # pandas df
+                    if col_name is None and is_index:
+                        col_name = "index"
+                    if not df_provided:
+                        col_name = field
+                    else:
+                        # The pandas name is kept only when the argument is the
+                        # very object held by data_frame; otherwise the keyword
+                        # name is used, provided it is not already reserved
+                        if is_index:
+                            keep_name = argument is df_input.index
+                        else:
+                            keep_name = (
+                                col_name in df_input and argument is df_input[col_name]
+                            )
+                        col_name = (
+                            col_name
+                            if keep_name
+                            else _check_name_not_reserved(field, reserved_names)
+                        )
+                else:  # numpy array, list...
+                    col_name = _check_name_not_reserved(field, reserved_names)
+                if length and len(argument) != length:
+                    raise ValueError(
+                        "All arguments should have the same length. "
+                        "The length of argument `%s` is %d, whereas the "
+                        "length of previous arguments %s is %d"
+                        % (field, len(argument), str(list(df_output.columns)), length)
+                    )
+                df_output[str(col_name)] = argument
+
+            # Finally, update argument with column name now that column exists
+            if field_name not in array_attrables:
+                args[field_name] = str(col_name)
+            else:
+                args[field_name][i] = str(col_name)
+
+    args["data_frame"] = df_output
+    return args
+
+
def infer_config(args, constructor, trace_patch):
# Declare all supported attributes, across all plot types
attrables = (
@@ -765,28 +968,13 @@ def infer_config(args, constructor, trace_patch):
)
array_attrables = ["dimensions", "custom_data", "hover_data"]
group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
+ all_attrables = attrables + group_attrables + ["color"]
+ group_attrs = ["symbol", "line_dash"]
+ for group_attr in group_attrs:
+ if group_attr in args:
+ all_attrables += [group_attr]
- # Validate that the strings provided as attribute values reference columns
- # in the provided data_frame
- df_columns = args["data_frame"].columns
-
- for attr in attrables + group_attrables + ["color"]:
- if attr in args and args[attr] is not None:
- maybe_col_list = [args[attr]] if attr not in array_attrables else args[attr]
- for maybe_col in maybe_col_list:
- try:
- in_cols = maybe_col in df_columns
- except TypeError:
- in_cols = False
- if not in_cols:
- value_str = (
- "Element of value" if attr in array_attrables else "Value"
- )
- raise ValueError(
- "%s of '%s' is not the name of a column in 'data_frame'. "
- "Expected one of %s but received: %s"
- % (value_str, attr, str(list(df_columns)), str(maybe_col))
- )
+ args = build_dataframe(args, all_attrables, array_attrables)
attrs = [k for k in attrables if k in args]
grouped_attrs = []
@@ -864,7 +1052,7 @@ def infer_config(args, constructor, trace_patch):
# Create trace specs
trace_specs = make_trace_spec(args, constructor, attrs, trace_patch)
- return trace_specs, grouped_mappings, sizeref, show_colorbar
+ return args, trace_specs, grouped_mappings, sizeref, show_colorbar
def get_orderings(args, grouper, grouped):
@@ -902,7 +1090,7 @@ def get_orderings(args, grouper, grouped):
def make_figure(args, constructor, trace_patch={}, layout_patch={}):
apply_default_cascade(args)
- trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
+ args, trace_specs, grouped_mappings, sizeref, show_colorbar = infer_config(
args, constructor, trace_patch
)
grouper = [x.grouper or one_group for x in grouped_mappings] or [one_group]
@@ -1095,7 +1283,7 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}):
fig.layout.update(layout_patch)
fig.frames = frame_list if len(frames) > 1 else []
- fig._px_trendlines = pandas.DataFrame(trendline_rows)
+ fig._px_trendlines = pd.DataFrame(trendline_rows)
configure_axes(args, constructor, fig, orders)
configure_animation_controls(args, constructor, fig)
diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py
index fbefe4e3860..7df598dccb3 100644
--- a/packages/python/plotly/plotly/express/_doc.py
+++ b/packages/python/plotly/plotly/express/_doc.py
@@ -1,7 +1,7 @@
import inspect
-colref = "(string: name of column in `data_frame`)"
-colref_list = "(list of string: names of columns in `data_frame`)"
+colref = "(string or int: name of column in `data_frame`, or pandas Series, or array_like object)"
+colref_list = "(list of string or int: names of columns in `data_frame`, or pandas Series, or array_like objects)"
# TODO contents of columns
# TODO explain categorical
@@ -12,53 +12,57 @@
# TODO standardize positioning and casing of 'default'
docs = dict(
- data_frame=["A 'tidy' `pandas.DataFrame`"],
+ data_frame=[
+ "A `pandas.DataFrame`, or a `NumPy` array or a dictionary",
+ "which are transformed internally to `pandas.DataFrame`. This argument needs "
+ "to be passed for column names (and not keyword names) to be used.",
+ ],
x=[
colref,
- "Values from this column are used to position marks along the x axis in cartesian coordinates.",
+ "Values from this column or array_like are used to position marks along the x axis in cartesian coordinates.",
"For horizontal `histogram`s, these values are used as inputs to `histfunc`.",
],
y=[
colref,
- "Values from this column are used to position marks along the y axis in cartesian coordinates.",
+ "Values from this column or array_like are used to position marks along the y axis in cartesian coordinates.",
"For vertical `histogram`s, these values are used as inputs to `histfunc`.",
],
z=[
colref,
- "Values from this column are used to position marks along the z axis in cartesian coordinates.",
+ "Values from this column or array_like are used to position marks along the z axis in cartesian coordinates.",
"For `density_heatmap` and `density_contour` these values are used as the inputs to `histfunc`.",
],
a=[
colref,
- "Values from this column are used to position marks along the a axis in ternary coordinates.",
+ "Values from this column or array_like are used to position marks along the a axis in ternary coordinates.",
],
b=[
colref,
- "Values from this column are used to position marks along the b axis in ternary coordinates.",
+ "Values from this column or array_like are used to position marks along the b axis in ternary coordinates.",
],
c=[
colref,
- "Values from this column are used to position marks along the c axis in ternary coordinates.",
+ "Values from this column or array_like are used to position marks along the c axis in ternary coordinates.",
],
r=[
colref,
- "Values from this column are used to position marks along the radial axis in polar coordinates.",
+ "Values from this column or array_like are used to position marks along the radial axis in polar coordinates.",
],
theta=[
colref,
- "Values from this column are used to position marks along the angular axis in polar coordinates.",
+ "Values from this column or array_like are used to position marks along the angular axis in polar coordinates.",
],
lat=[
colref,
- "Values from this column are used to position marks according to latitude on a map.",
+ "Values from this column or array_like are used to position marks according to latitude on a map.",
],
lon=[
colref,
- "Values from this column are used to position marks according to longitude on a map.",
+ "Values from this column or array_like are used to position marks according to longitude on a map.",
],
locations=[
colref,
- "Values from this column are be interpreted according to `locationmode` and mapped to longitude/latitude.",
+ "Values from this column or array_like are to be interpreted according to `locationmode` and mapped to longitude/latitude.",
],
dimensions=[
"(list of strings, names of columns in `data_frame`)",
@@ -66,47 +70,59 @@
],
error_x=[
colref,
- "Values from this column are used to size x-axis error bars.",
+ "Values from this column or array_like are used to size x-axis error bars.",
"If `error_x_minus` is `None`, error bars will be symmetrical, otherwise `error_x` is used for the positive direction only.",
],
error_x_minus=[
colref,
- "Values from this column are used to size x-axis error bars in the negative direction.",
+ "Values from this column or array_like are used to size x-axis error bars in the negative direction.",
"Ignored if `error_x` is `None`.",
],
error_y=[
colref,
- "Values from this column are used to size y-axis error bars.",
+ "Values from this column or array_like are used to size y-axis error bars.",
"If `error_y_minus` is `None`, error bars will be symmetrical, otherwise `error_y` is used for the positive direction only.",
],
error_y_minus=[
colref,
- "Values from this column are used to size y-axis error bars in the negative direction.",
+ "Values from this column or array_like are used to size y-axis error bars in the negative direction.",
"Ignored if `error_y` is `None`.",
],
error_z=[
colref,
- "Values from this column are used to size z-axis error bars.",
+ "Values from this column or array_like are used to size z-axis error bars.",
"If `error_z_minus` is `None`, error bars will be symmetrical, otherwise `error_z` is used for the positive direction only.",
],
error_z_minus=[
colref,
- "Values from this column are used to size z-axis error bars in the negative direction.",
+ "Values from this column or array_like are used to size z-axis error bars in the negative direction.",
"Ignored if `error_z` is `None`.",
],
- color=[colref, "Values from this column are used to assign color to marks."],
+ color=[
+ colref,
+ "Values from this column or array_like are used to assign color to marks.",
+ ],
opacity=["(number, between 0 and 1) Sets the opacity for markers."],
line_dash=[
colref,
- "Values from this column are used to assign dash-patterns to lines.",
+ "Values from this column or array_like are used to assign dash-patterns to lines.",
],
line_group=[
colref,
- "Values from this column are used to group rows of `data_frame` into lines.",
+ "Values from this column or array_like are used to group rows of `data_frame` into lines.",
+ ],
+ symbol=[
+ colref,
+ "Values from this column or array_like are used to assign symbols to marks.",
+ ],
+ size=[
+ colref,
+ "Values from this column or array_like are used to assign mark sizes.",
+ ],
+ hover_name=[
+ colref,
+ "Values from this column or array_like appear in bold in the hover tooltip.",
],
- symbol=[colref, "Values from this column are used to assign symbols to marks."],
- size=[colref, "Values from this column are used to assign mark sizes."],
- hover_name=[colref, "Values from this column appear in bold in the hover tooltip."],
hover_data=[
colref_list,
"Values from these columns appear as extra data in the hover tooltip.",
@@ -115,26 +131,29 @@
colref_list,
"Values from these columns are extra data, to be used in widgets or Dash callbacks for example. This data is not user-visible but is included in events emitted by the figure (lasso selection etc.)",
],
- text=[colref, "Values from this column appear in the figure as text labels."],
+ text=[
+ colref,
+ "Values from this column or array_like appear in the figure as text labels.",
+ ],
locationmode=[
"(string, one of 'ISO-3', 'USA-states', 'country names')",
"Determines the set of locations used to match entries in `locations` to regions on the map.",
],
facet_row=[
colref,
- "Values from this column are used to assign marks to facetted subplots in the vertical direction.",
+ "Values from this column or array_like are used to assign marks to facetted subplots in the vertical direction.",
],
facet_col=[
colref,
- "Values from this column are used to assign marks to facetted subplots in the horizontal direction.",
+ "Values from this column or array_like are used to assign marks to facetted subplots in the horizontal direction.",
],
animation_frame=[
colref,
- "Values from this column are used to assign marks to animation frames.",
+ "Values from this column or array_like are used to assign marks to animation frames.",
],
animation_group=[
colref,
- "Values from this column are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.",
+ "Values from this column or array_like are used to provide object-constancy across animation frames: rows with matching `animation_group`s will be treated as if they describe the same object in each frame.",
],
symbol_sequence=[
"(list of strings defining plotly.js symbols)",
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
new file mode 100644
index 00000000000..08bb1a9cc95
--- /dev/null
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_px_input.py
@@ -0,0 +1,290 @@
+import plotly.express as px
+import numpy as np
+import pandas as pd
+import pytest
+import plotly.graph_objects as go
+import plotly
+from plotly.express._core import build_dataframe
+from pandas.util.testing import assert_frame_equal
+
+# Argument names mirrored from `infer_config`, passed to `build_dataframe`
+# when it is called directly in the tests below.
+array_attrables = ["dimensions", "custom_data", "hover_data"]
+group_attrables = ["animation_frame", "facet_row", "facet_col", "line_group"]
+attrables = [
+    "x",
+    "y",
+    "z",
+    "a",
+    "b",
+    "c",
+    "r",
+    "theta",
+    "size",
+    "dimensions",
+    "custom_data",
+    "hover_name",
+    "hover_data",
+    "text",
+    "error_x",
+    "error_x_minus",
+    "error_y",
+    "error_y_minus",
+    "error_z",
+    "error_z_minus",
+    "lat",
+    "lon",
+    "locations",
+    "animation_group",
+]
+
+all_attrables = attrables + group_attrables + ["color"]
+
+
+def test_numpy():
+    """Lists passed as keyword arguments are found unchanged in the trace."""
+    xs, ys, colors = [1, 2, 3], [2, 3, 4], [1, 3, 9]
+    trace = px.scatter(x=xs, y=ys, color=colors).data[0]
+    assert np.all(trace.x == np.array([1, 2, 3]))
+    assert np.all(trace.y == np.array([2, 3, 4]))
+    assert np.all(trace.marker.color == np.array([1, 3, 9]))
+
+
+def test_numpy_labels():
+    """A `labels` entry renames the matching keyword name in the hover template."""
+    fig = px.scatter(
+        x=[1, 2, 3], y=[2, 3, 4], labels={"x": "time"}
+    )  # other labels will be kw arguments
+    assert fig.data[0]["hovertemplate"] == "time=%{x}<br>y=%{y}"
+
+
+def test_with_index():
+    """The index can be passed as `df.index`, but not as the string "index"."""
+    tips = px.data.tips()
+    fig = px.scatter(tips, x=tips.index, y="total_bill")
+    assert fig.data[0]["hovertemplate"] == "index=%{x}<br>total_bill=%{y}"
+    fig = px.scatter(tips, x=tips.index, y=tips.total_bill)
+    assert fig.data[0]["hovertemplate"] == "index=%{x}<br>total_bill=%{y}"
+    fig = px.scatter(tips, x=tips.index, y=tips.total_bill, labels={"index": "number"})
+    assert fig.data[0]["hovertemplate"] == "number=%{x}<br>total_bill=%{y}"
+    # We do not allow "x=index"
+    with pytest.raises(ValueError) as err_msg:
+        fig = px.scatter(tips, x="index", y="total_bill")
+    assert "To use the index, pass it in directly as `df.index`." in str(
+        err_msg.value
+    )
+    # A named index keeps its own name
+    tips = px.data.tips()
+    tips.index.name = "item"
+    fig = px.scatter(tips, x=tips.index, y="total_bill")
+    assert fig.data[0]["hovertemplate"] == "item=%{x}<br>total_bill=%{y}"
+
+
+def test_pandas_series():
+    """An unnamed Series is labelled with the keyword name, or with its `labels` entry."""
+    tips = px.data.tips()
+    before_tip = tips.total_bill - tips.tip
+    fig = px.bar(tips, x="day", y=before_tip)
+    assert fig.data[0].hovertemplate == "day=%{x}<br>y=%{y}"
+    fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"})
+    assert fig.data[0].hovertemplate == "day=%{x}<br>bill=%{y}"
+
+
+def test_several_dataframes():
+    """Columns coming from several DataFrames can be mixed in a single call."""
+    df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8]))
+    df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200]))
+    fig = px.scatter(df, x="z", y=df2.money, size="x")
+    assert fig.data[0].hovertemplate == "z=%{x}<br>y=%{y}<br>x=%{marker.size}"
+    fig = px.scatter(df2, x=df.z, y=df2.money, size=df.z)
+    assert fig.data[0].hovertemplate == "x=%{x}<br>money=%{y}<br>size=%{marker.size}"
+    # Name conflict
+    with pytest.raises(NameError) as err_msg:
+        fig = px.scatter(df, x="z", y=df2.money, size="y")
+    assert "A name conflict was encountered for argument y" in str(err_msg.value)
+    with pytest.raises(NameError) as err_msg:
+        fig = px.scatter(df, x="z", y=df2.money, size=df.y)
+    assert "A name conflict was encountered for argument y" in str(err_msg.value)
+
+    # No conflict when the dataframe is not given, fields are used
+    df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+    df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+    fig = px.scatter(x=df.y, y=df2.y)
+    assert np.all(fig.data[0].x == np.array([3, 4]))
+    assert np.all(fig.data[0].y == np.array([23, 24]))
+    assert fig.data[0].hovertemplate == "x=%{x}<br>y=%{y}"
+
+    df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+    df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+    df3 = pd.DataFrame(dict(y=[0.1, 0.2]))
+    fig = px.scatter(x=df.y, y=df2.y, size=df3.y)
+    assert np.all(fig.data[0].x == np.array([3, 4]))
+    assert np.all(fig.data[0].y == np.array([23, 24]))
+    assert fig.data[0].hovertemplate == "x=%{x}<br>y=%{y}<br>size=%{marker.size}"
+
+    df = pd.DataFrame(dict(x=[0, 1], y=[3, 4]))
+    df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24]))
+    df3 = pd.DataFrame(dict(y=[0.1, 0.2]))
+    fig = px.scatter(x=df.y, y=df2.y, hover_data=[df3.y])
+    assert np.all(fig.data[0].x == np.array([3, 4]))
+    assert np.all(fig.data[0].y == np.array([23, 24]))
+    assert (
+        fig.data[0].hovertemplate
+        == "x=%{x}<br>y=%{y}<br>hover_data_0=%{customdata[0]}"
+    )
+
+
+def test_name_heuristics():
+    """Columns of `data_frame` passed as Series keep their own column names."""
+    df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2]))
+    fig = px.scatter(df, x=df.y, y=df.x, size=df.y)
+    assert np.all(fig.data[0].x == np.array([3, 4]))
+    assert np.all(fig.data[0].y == np.array([0, 1]))
+    assert fig.data[0].hovertemplate == "y=%{marker.size}<br>x=%{y}"
+
+
+def test_repeated_name():
+    """Five names, one of them listed twice, give four `customdata` columns."""
+    hover_columns = ["petal_length", "petal_width", "species_id"]
+    custom_columns = ["species_id", "species"]
+    fig = px.scatter(
+        px.data.iris(),
+        x="sepal_width",
+        y="sepal_length",
+        hover_data=hover_columns,
+        custom_data=custom_columns,
+    )
+    n_customdata_columns = fig.data[0].customdata.shape[1]
+    assert n_customdata_columns == 4
+
+
+def test_arrayattrable_numpy():
+    """An array inside `hover_data` is named `hover_data_<i>` unless relabelled."""
+    tips = px.data.tips()
+    fig = px.scatter(
+        tips, x="total_bill", y="tip", hover_data=[np.random.random(tips.shape[0])]
+    )
+    assert (
+        fig.data[0]["hovertemplate"]
+        == "total_bill=%{x}<br>tip=%{y}<br>hover_data_0=%{customdata[0]}"
+    )
+    tips = px.data.tips()
+    fig = px.scatter(
+        tips,
+        x="total_bill",
+        y="tip",
+        hover_data=[np.random.random(tips.shape[0])],
+        labels={"hover_data_0": "suppl"},
+    )
+    assert (
+        fig.data[0]["hovertemplate"]
+        == "total_bill=%{x}<br>tip=%{y}<br>suppl=%{customdata[0]}"
+    )
+
+
+def test_wrong_column_name():
+    """An unknown column name raises a ValueError naming the argument."""
+    expected = "Value of 'x' is not the name of a column in 'data_frame'"
+    with pytest.raises(ValueError) as excinfo:
+        px.scatter(px.data.tips(), x="bla", y="wrong")
+    assert expected in str(excinfo.value)
+
+
+def test_missing_data_frame():
+    """String arguments without a `data_frame` raise a ValueError."""
+    expected = "String or int arguments are only possible"
+    with pytest.raises(ValueError) as excinfo:
+        px.scatter(x="arg1", y="arg2")
+    assert expected in str(excinfo.value)
+
+
+def test_wrong_dimensions_of_array():
+    """Two lists of different lengths raise a ValueError."""
+    expected = "All arguments should have the same length."
+    with pytest.raises(ValueError) as excinfo:
+        px.scatter(x=[1, 2, 3], y=[2, 3, 4, 5])
+    assert expected in str(excinfo.value)
+
+
+def test_wrong_dimensions_mixed_case():
+    """A list longer than the DataFrame columns raises a ValueError."""
+    expected = "All arguments should have the same length."
+    with pytest.raises(ValueError) as excinfo:
+        frame = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25]))
+        px.scatter(frame, x="time", y="temperature", color=[1, 3, 9, 5])
+    assert expected in str(excinfo.value)
+
+
+def test_wrong_dimensions():
+    """Length mismatches are detected whichever argument comes first."""
+    expected = "All arguments should have the same length."
+    with pytest.raises(ValueError) as excinfo:
+        px.scatter(px.data.tips(), x="tip", y=[1, 2, 3])
+    assert expected in str(excinfo.value)
+    # the order matters
+    with pytest.raises(ValueError) as excinfo:
+        px.scatter(px.data.tips(), x=[1, 2, 3], y="tip")
+    assert expected in str(excinfo.value)
+    with pytest.raises(ValueError):
+        px.scatter(px.data.tips(), x=px.data.iris().index, y="tip")
+    # assert "All arguments should have the same length." in str(err_msg.value)
+
+
+def test_multiindex_raise_error():
+    """A MultiIndex passed as an argument raises a TypeError."""
+    levels = [[1, 2, 3], ["a", "b"]]
+    multi_index = pd.MultiIndex.from_product(levels, names=["first", "second"])
+    frame = pd.DataFrame(
+        np.random.random((6, 3)), index=multi_index, columns=["A", "B", "C"]
+    )
+    # This is ok
+    px.scatter(frame, x="A", y="B")
+    with pytest.raises(TypeError) as excinfo:
+        px.scatter(frame, x=frame.index, y="B")
+    message = str(excinfo.value)
+    assert "pandas MultiIndex is not supported by plotly express" in message
+
+
+def test_build_df_from_lists():
+    """`build_dataframe` turns lists and arrays into columns named after the keywords."""
+    cases = [
+        # Just lists
+        dict(x=[1, 2, 3], y=[2, 3, 4], color=[1, 3, 9]),
+        # Arrays
+        dict(x=np.array([1, 2, 3]), y=np.array([2, 3, 4]), color=[1, 3, 9]),
+    ]
+    for args in cases:
+        expected_args = {key: key for key in args}
+        expected_df = pd.DataFrame(args)
+        args["data_frame"] = None
+        out = build_dataframe(args, all_attrables, array_attrables)
+        assert_frame_equal(
+            expected_df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)
+        )
+        out.pop("data_frame")
+        assert out == expected_args
+
+
+def test_build_df_with_index():
+    """The index passed as `x` becomes a regular column of the built DataFrame."""
+    tips = px.data.tips()
+    call_args = dict(data_frame=tips, x=tips.index, y="total_bill")
+    out = build_dataframe(call_args, all_attrables, array_attrables)
+    built = out["data_frame"]
+    expected = tips.reset_index()[built.columns]
+    assert_frame_equal(expected, built)
+
+
+def test_splom_case():
+    """`scatter_matrix` accepts a DataFrame, a dict or a 2D array without `dimensions`."""
+    iris = px.data.iris()
+    splom = px.scatter_matrix(iris).data[0]
+    assert len(splom.dimensions) == len(iris.columns)
+    as_dict = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}
+    first_dimension = px.scatter_matrix(as_dict).data[0].dimensions[0]
+    assert np.all(first_dimension.values == np.array(as_dict["a"]))
+    as_array = np.arange(9).reshape((3, 3))
+    first_dimension = px.scatter_matrix(as_array).data[0].dimensions[0]
+    assert np.all(first_dimension.values == as_array[:, 0])
+
+
+def test_int_col_names():
+    """Integer column names can be used as arguments."""
+    # DataFrame with int column names
+    lengths = pd.DataFrame(np.random.random(100))
+    histogram_trace = px.histogram(lengths, x=0).data[0]
+    assert np.all(np.array(lengths).flatten() == histogram_trace.x)
+    # Numpy array
+    grid = np.arange(100).reshape((10, 10))
+    scatter_trace = px.scatter(grid, x=2, y=8).data[0]
+    assert np.all(scatter_trace.x == grid[:, 2])
+
+
+def test_data_frame_from_dict():
+    """A dict passed as `data_frame` can be addressed by its keys."""
+    fig = px.scatter({"time": [0, 1], "money": [1, 2]}, x="time", y="money")
+    assert fig.data[0].hovertemplate == "time=%{x}<br>money=%{y}"
+    assert np.all(fig.data[0].x == [0, 1])
+
+
+def test_arguments_not_modified():
+    """Objects passed as arguments are left as they were by the px call."""
+    iris = px.data.iris()
+    x_column = iris.petal_length
+    hover_columns = [iris.sepal_length]
+    px.scatter(iris, x=x_column, y="petal_width", hover_data=hover_columns)
+    assert iris.petal_length.equals(x_column)
+    assert iris.sepal_length.equals(hover_columns[0])
+
+
+def test_pass_df_columns():
+    """`df.columns` can be given as `hover_data` and is left unchanged."""
+    tips = px.data.tips()
+    fig = px.histogram(
+        tips,
+        x="total_bill",
+        y="tip",
+        color="sex",
+        marginal="rug",
+        hover_data=tips.columns,
+    )
+    n_customdata = fig.data[1].hovertemplate.count("customdata")
+    assert n_customdata == len(tips.columns)
+    fresh_tips = px.data.tips()
+    assert fresh_tips.columns.equals(tips.columns)
+
+
+def test_size_column():
+    """A column named "size" (also a DataFrame attribute) keeps its column name."""
+    df = px.data.tips()
+    fig = px.scatter(df, x=df["size"], y=df.tip)
+    assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}"