diff --git a/CHANGELOG.md b/CHANGELOG.md index 9d9b6ba1b8c..862eacfc32b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Updated -- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance. +- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance. - Add `subtitle` attribute to all Plotly Express traces +- Make plotly-express dataframe agnostic via Narwhals [#4790](https://github.com/plotly/plotly.py/pull/4790) ## [5.24.1] - 2024-09-12 diff --git a/packages/python/plotly/_plotly_utils/basevalidators.py b/packages/python/plotly/_plotly_utils/basevalidators.py index 21731afad43..73fe92e3511 100644 --- a/packages/python/plotly/_plotly_utils/basevalidators.py +++ b/packages/python/plotly/_plotly_utils/basevalidators.py @@ -8,6 +8,7 @@ import re import sys import warnings +import narwhals.stable.v1 as nw from _plotly_utils.optional_imports import get_module @@ -72,8 +73,6 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False): """ np = get_module("numpy") - # Don't force pandas to be loaded, we only want to know if it's already loaded - pd = get_module("pandas", should_load=False) assert np is not None # ### Process kind ### @@ -93,34 +92,26 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False): "O": "object", } - # Handle pandas Series and Index objects - if pd and isinstance(v, (pd.Series, pd.Index)): - if v.dtype.kind in numeric_kinds: - # Get the numeric numpy array so we use fast path below - v = v.values - elif v.dtype.kind == "M": - # Convert datetime Series/Index to numpy array of datetimes - if isinstance(v, pd.Series): - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - # Series.dt.to_pydatetime will return Index[object] - # https://github.com/pandas-dev/pandas/pull/52459 - v = np.array(v.dt.to_pydatetime()) - else: - # DatetimeIndex - v = v.to_pydatetime() - elif pd and isinstance(v, pd.DataFrame) and len(set(v.dtypes)) == 1: - dtype = v.dtypes.tolist()[0] - if dtype.kind in numeric_kinds: - v = v.values - elif dtype.kind == "M": - with warnings.catch_warnings(): - warnings.simplefilter("ignore", FutureWarning) - # Series.dt.to_pydatetime will return Index[object] - # https://github.com/pandas-dev/pandas/pull/52459 - v = [ - np.array(row.dt.to_pydatetime()).tolist() for i, row in v.iterrows() - ] + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals DataFrame or Series. 
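The `pass_through=True` behaviour referenced in the comment above is what lets this validator accept both dataframe-library objects and plain sequences. A minimal sketch of the contract, assuming only that `narwhals` and `pandas` are installed (the example inputs are illustrative, not from the patch):

```python
import narwhals.stable.v1 as nw
import pandas as pd

# A pandas Series is wrapped as a backend-agnostic Narwhals Series ...
wrapped = nw.from_native(pd.Series([1, 2, 3]), allow_series=True, pass_through=True)
assert isinstance(wrapped, nw.Series)

# ... while an object Narwhals cannot convert (a plain list) is returned
# unchanged instead of raising, so the numpy handling below still applies.
untouched = nw.from_native([1, 2, 3], allow_series=True, pass_through=True)
assert untouched == [1, 2, 3]
```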
+ v = nw.from_native(v, allow_series=True, pass_through=True) + + if isinstance(v, nw.Series): + if v.dtype == nw.Datetime and v.dtype.time_zone is not None: + # Remove time zone so that local time is displayed + v = v.dt.replace_time_zone(None).to_numpy() + else: + v = v.to_numpy() + elif isinstance(v, nw.DataFrame): + schema = v.schema + overrides = {} + for key, val in schema.items(): + if val == nw.Datetime and val.time_zone is not None: + # Remove time zone so that local time is displayed + overrides[key] = nw.col(key).dt.replace_time_zone(None) + if overrides: + v = v.with_columns(**overrides) + v = v.to_numpy() if not isinstance(v, np.ndarray): # v has its own logic on how to convert itself into a numpy array @@ -193,6 +184,7 @@ def is_homogeneous_array(v): np and isinstance(v, np.ndarray) or (pd and isinstance(v, (pd.Series, pd.Index))) + or (isinstance(v, nw.Series)) ): return True if is_numpy_convertable(v): diff --git a/packages/python/plotly/_plotly_utils/tests/validators/test_pandas_series_input.py b/packages/python/plotly/_plotly_utils/tests/validators/test_pandas_series_input.py index ef8818181db..8bb50d1808b 100644 --- a/packages/python/plotly/_plotly_utils/tests/validators/test_pandas_series_input.py +++ b/packages/python/plotly/_plotly_utils/tests/validators/test_pandas_series_input.py @@ -73,12 +73,13 @@ def color_categorical_pandas(request, pandas_type): def dates_array(request): return np.array( [ - datetime(year=2013, month=10, day=10), - datetime(year=2013, month=11, day=10), - datetime(year=2013, month=12, day=10), - datetime(year=2014, month=1, day=10), - datetime(year=2014, month=2, day=10), - ] + "2013-10-10", + "2013-11-10", + "2013-12-10", + "2014-01-10", + "2014-02-10", + ], + dtype="datetime64[ns]", ) @@ -183,7 +184,7 @@ def test_data_array_validator_dates_series( assert isinstance(res, np.ndarray) # Check dtype - assert res.dtype == "object" + assert res.dtype == "=1.13.3 ## scipy deps for some FigureFactory functions ## scipy diff --git a/packages/python/plotly/plotly/data/__init__.py b/packages/python/plotly/plotly/data/__init__.py index 7669c4588c3..ef6bcdd2f06 100644 --- a/packages/python/plotly/plotly/data/__init__.py +++ b/packages/python/plotly/plotly/data/__init__.py @@ -1,32 +1,73 @@ """ Built-in datasets for demonstration, educational and test purposes. """ +import os +from importlib import import_module +import narwhals.stable.v1 as nw -def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False): +AVAILABLE_BACKENDS = {"pandas", "polars", "pyarrow", "modin", "cudf"} +BACKENDS_WITH_INDEX_SUPPORT = {"pandas", "modin", "cudf"} + + +def gapminder( + datetimes=False, + centroids=False, + year=None, + pretty_names=False, + return_type="pandas", +): """ Each row represents a country on a given year. https://www.gapminder.org/data/ - Returns: - A `pandas.DataFrame` with 1704 rows and the following columns: + Parameters + ---------- + datetimes: bool + Whether or not 'year' column will converted to datetime type + + centroids: bool + If True, ['centroid_lat', 'centroid_lon'] columns are added + + year: int | None + If provided, the dataset will be filtered for that year + + pretty_names: bool + If True, prettifies the column names + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 1704 rows and the following columns: `['country', 'continent', 'year', 'lifeExp', 'pop', 'gdpPercap', 'iso_alpha', 'iso_num']`. 
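As a usage sketch of the new `return_type` parameter introduced above (any backend other than pandas must be installed separately):

```python
import plotly.express as px

df_default = px.data.gapminder()                      # pandas DataFrame, as before
df_polars = px.data.gapminder(return_type="polars")   # polars.DataFrame, if installed
df_arrow = px.data.gapminder(return_type="pyarrow")   # pyarrow.Table, if installed
```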
+ If `datetimes` is True, the 'year' column will be a datetime column If `centroids` is True, two new columns are added: ['centroid_lat', 'centroid_lon'] If `year` is an integer, the dataset will be filtered for that year """ - df = _get_dataset("gapminder") + df = nw.from_native( + _get_dataset("gapminder", return_type=return_type), eager_only=True + ) if year: - df = df[df["year"] == year] + df = df.filter(nw.col("year") == year) if datetimes: - df["year"] = (df["year"].astype(str) + "-01-01").astype("datetime64[ns]") + df = df.with_columns( + # Concatenate the year value with the literal "-01-01" so that it can be + # casted to datetime from "%Y-%m-%d" format + nw.concat_str( + [nw.col("year").cast(nw.String()), nw.lit("-01-01")] + ).str.to_datetime(format="%Y-%m-%d") + ) if not centroids: - df = df.drop(["centroid_lat", "centroid_lon"], axis=1) + df = df.drop("centroid_lat", "centroid_lon") if pretty_names: - df.rename( - mapper=dict( + df = df.rename( + dict( country="Country", continent="Continent", year="Year", @@ -37,27 +78,36 @@ def gapminder(datetimes=False, centroids=False, year=None, pretty_names=False): iso_num="ISO Numeric Country Code", centroid_lat="Centroid Latitude", centroid_lon="Centroid Longitude", - ), - axis="columns", - inplace=True, + ) ) - return df + return df.to_native() -def tips(pretty_names=False): +def tips(pretty_names=False, return_type="pandas"): """ Each row represents a restaurant bill. https://vincentarelbundock.github.io/Rdatasets/doc/reshape2/tips.html - Returns: - A `pandas.DataFrame` with 244 rows and the following columns: - `['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`.""" + Parameters + ---------- + pretty_names: bool + If True, prettifies the column names + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 244 rows and the following columns: + `['total_bill', 'tip', 'sex', 'smoker', 'day', 'time', 'size']`. + """ - df = _get_dataset("tips") + df = nw.from_native(_get_dataset("tips", return_type=return_type), eager_only=True) if pretty_names: - df.rename( - mapper=dict( + df = df.rename( + dict( total_bill="Total Bill", tip="Tip", sex="Payer Gender", @@ -65,54 +115,78 @@ def tips(pretty_names=False): day="Day of Week", time="Meal", size="Party Size", - ), - axis="columns", - inplace=True, + ) ) - return df + return df.to_native() -def iris(): +def iris(return_type="pandas"): """ Each row represents a flower. https://en.wikipedia.org/wiki/Iris_flower_data_set - Returns: - A `pandas.DataFrame` with 150 rows and the following columns: - `['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', 'species_id']`.""" - return _get_dataset("iris") + Parameters + ---------- + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 150 rows and the following columns: + `['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'species', 'species_id']`. + """ + return _get_dataset("iris", return_type=return_type) -def wind(): +def wind(return_type="pandas"): """ Each row represents a level of wind intensity in a cardinal direction, and its frequency. 
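The `datetimes=True` conversion in the `gapminder` hunk above, as a standalone sketch (pandas assumed as the backing frame):

```python
import narwhals.stable.v1 as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"year": [1952, 2007]}), eager_only=True)
df = df.with_columns(
    # "1952" + "-01-01" -> datetime(1952, 1, 1); the expression keeps the
    # name of its first input, so the "year" column is replaced in place
    nw.concat_str([nw.col("year").cast(nw.String()), nw.lit("-01-01")])
    .str.to_datetime(format="%Y-%m-%d")
)
print(df.to_native().dtypes)  # year    datetime64[ns]
```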
- Returns: - A `pandas.DataFrame` with 128 rows and the following columns: - `['direction', 'strength', 'frequency']`.""" - return _get_dataset("wind") + Parameters + ---------- + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + Returns + ------- + Dataframe of `return_type` type + Dataframe with 128 rows and the following columns: + `['direction', 'strength', 'frequency']`. + """ + return _get_dataset("wind", return_type=return_type) -def election(): + +def election(return_type="pandas"): """ Each row represents voting results for an electoral district in the 2013 Montreal mayoral election. - Returns: - A `pandas.DataFrame` with 58 rows and the following columns: - `['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result', 'district_id']`.""" - return _get_dataset("election") + Parameters + ---------- + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 58 rows and the following columns: + `['district', 'Coderre', 'Bergeron', 'Joly', 'total', 'winner', 'result', 'district_id']`. + """ + return _get_dataset("election", return_type=return_type) def election_geojson(): """ Each feature represents an electoral district in the 2013 Montreal mayoral election. - Returns: + Returns + ------- A GeoJSON-formatted `dict` with 58 polygon or multi-polygon features whose `id` is an electoral district numerical ID and whose `district` property is the ID and - district name.""" + district name. + """ import gzip import json import os @@ -128,95 +202,228 @@ def election_geojson(): return result -def carshare(): +def carshare(return_type="pandas"): """ Each row represents the availability of car-sharing services near the centroid of a zone in Montreal over a month-long period. - Returns: - A `pandas.DataFrame` with 249 rows and the following columns: - `['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`.""" - return _get_dataset("carshare") + Parameters + ---------- + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe` with 249 rows and the following columns: + `['centroid_lat', 'centroid_lon', 'car_hours', 'peak_hour']`. + """ + return _get_dataset("carshare", return_type=return_type) -def stocks(indexed=False, datetimes=False): +def stocks(indexed=False, datetimes=False, return_type="pandas"): """ Each row in this wide dataset represents closing prices from 6 tech stocks in 2018/2019. - Returns: - A `pandas.DataFrame` with 100 rows and the following columns: + Parameters + ---------- + indexed: bool + Whether or not the 'date' column is used as the index and the column index + is named 'company'. Applicable only if `return_type='pandas'` + + datetimes: bool + Whether or not the 'date' column will be of datetime type + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 100 rows and the following columns: `['date', 'GOOG', 'AAPL', 'AMZN', 'FB', 'NFLX', 'MSFT']`. 
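The `stocks` loader below guards `indexed=True` for backends without a real index, raising for anything outside `BACKENDS_WITH_INDEX_SUPPORT`. A usage sketch of the new behaviour:

```python
import plotly.express as px

px.data.stocks(indexed=True)  # fine: pandas supports setting an index

try:
    px.data.stocks(indexed=True, return_type="polars")
except NotImplementedError as exc:
    print(exc)  # Backend 'polars' does not support setting index
```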
If `indexed` is True, the 'date' column is used as the index and the column index + is named 'company' If `datetimes` is True, the 'date' column will be a datetime column - is named 'company'""" - df = _get_dataset("stocks") + """ + if indexed and return_type not in BACKENDS_WITH_INDEX_SUPPORT: + msg = f"Backend '{return_type}' does not support setting index" + raise NotImplementedError(msg) + + df = nw.from_native( + _get_dataset("stocks", return_type=return_type), eager_only=True + ).with_columns(nw.col("date").cast(nw.String())) + if datetimes: - df["date"] = df["date"].astype("datetime64[ns]") - if indexed: - df = df.set_index("date") + df = df.with_columns(nw.col("date").str.to_datetime()) + + if indexed: # then it must be pandas + df = df.to_native().set_index("date") df.columns.name = "company" - return df + return df + + return df.to_native() -def experiment(indexed=False): +def experiment(indexed=False, return_type="pandas"): """ Each row in this wide dataset represents the results of 100 simulated participants on three hypothetical experiments, along with their gender and control/treatment group. + Parameters + ---------- + indexed: bool + If True, then the index is named "participant". + Applicable only if `return_type='pandas'` - Returns: - A `pandas.DataFrame` with 100 rows and the following columns: + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 100 rows and the following columns: `['experiment_1', 'experiment_2', 'experiment_3', 'gender', 'group']`. - If `indexed` is True, the data frame index is named "participant" """ - df = _get_dataset("experiment") - if indexed: + If `indexed` is True, the data frame index is named "participant" + """ + + if indexed and return_type not in BACKENDS_WITH_INDEX_SUPPORT: + msg = f"Backend '{return_type}' does not support setting index" + raise NotImplementedError(msg) + + df = nw.from_native( + _get_dataset("experiment", return_type=return_type), eager_only=True + ) + if indexed: # then it must be pandas + df = df.to_native() df.index.name = "participant" - return df + return df + return df.to_native() -def medals_wide(indexed=False): +def medals_wide(indexed=False, return_type="pandas"): """ This dataset represents the medal table for Olympic Short Track Speed Skating for the top three nations as of 2020. - Returns: - A `pandas.DataFrame` with 3 rows and the following columns: + Parameters + ---------- + indexed: bool + Whether or not the 'nation' column is used as the index and the column index + is named 'medal'. Applicable only if `return_type='pandas'` + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 3 rows and the following columns: `['nation', 'gold', 'silver', 'bronze']`. 
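`medals_long` below swaps `pandas.melt` for the backend-agnostic `unpivot`; the same reshape in isolation (pandas assumed here, but any Narwhals-supported backend works):

```python
import narwhals.stable.v1 as nw
import pandas as pd

wide = nw.from_native(
    pd.DataFrame({"nation": ["KOR"], "gold": [24], "silver": [13], "bronze": [11]}),
    eager_only=True,
)
long = wide.unpivot(index=["nation"], variable_name="medal", value_name="count")
print(long.to_native())
#   nation   medal  count
# 0    KOR    gold     24
# 1    KOR  silver     13
# 2    KOR  bronze     11
```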
If `indexed` is True, the 'nation' column is used as the index and the column index - is named 'medal'""" - df = _get_dataset("medals") - if indexed: - df = df.set_index("nation") + is named 'medal' + """ + + if indexed and return_type not in BACKENDS_WITH_INDEX_SUPPORT: + msg = f"Backend '{return_type}' does not support setting index" + raise NotImplementedError(msg) + + df = nw.from_native( + _get_dataset("medals", return_type=return_type), eager_only=True + ) + if indexed: # then it must be pandas + df = df.to_native().set_index("nation") df.columns.name = "medal" - return df + return df + return df.to_native() -def medals_long(indexed=False): +def medals_long(indexed=False, return_type="pandas"): """ This dataset represents the medal table for Olympic Short Track Speed Skating for the top three nations as of 2020. - Returns: - A `pandas.DataFrame` with 9 rows and the following columns: - `['nation', 'medal', 'count']`. - If `indexed` is True, the 'nation' column is used as the index.""" - df = _get_dataset("medals").melt( - id_vars=["nation"], value_name="count", var_name="medal" + Parameters + ---------- + indexed: bool + Whether or not the 'nation' column is used as the index. + Applicable only if `return_type='pandas'` + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + Dataframe with 9 rows and the following columns: `['nation', 'medal', 'count']`. + If `indexed` is True, the 'nation' column is used as the index. + """ + + if indexed and return_type not in BACKENDS_WITH_INDEX_SUPPORT: + msg = f"Backend '{return_type}' does not support setting index" + raise NotImplementedError(msg) + + df = nw.from_native( + _get_dataset("medals", return_type=return_type), eager_only=True + ).unpivot( + index=["nation"], + value_name="count", + variable_name="medal", ) if indexed: - df = df.set_index("nation") - return df + df = nw.maybe_set_index(df, "nation") + return df.to_native() -def _get_dataset(d): - import pandas - import os +def _get_dataset(d, return_type): + """ + Loads the dataset using the specified backend. - return pandas.read_csv( - os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "package_data", - "datasets", - d + ".csv.gz", - ) + Notice that the available backends are 'pandas', 'polars', 'pyarrow' and they all have + a `read_csv` function (pyarrow has it via pyarrow.csv). Therefore we can dynamically + load the library using `importlib.import_module` and then call + `backend.read_csv(filepath)`. + + Parameters + ---------- + d: str + Name of the dataset to load. + + return_type: {'pandas', 'polars', 'pyarrow', 'modin', 'cudf'} + Type of the resulting dataframe + + Returns + ------- + Dataframe of `return_type` type + """ + filepath = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "package_data", + "datasets", + d + ".csv.gz", ) + + if return_type not in AVAILABLE_BACKENDS: + msg = ( + f"Unsupported return_type. 
Found {return_type}, expected one " + f"of {AVAILABLE_BACKENDS}" + ) + raise NotImplementedError(msg) + + try: + if return_type == "pyarrow": + module_to_load = "pyarrow.csv" + elif return_type == "modin": + module_to_load = "modin.pandas" + else: + module_to_load = return_type + backend = import_module(module_to_load) + except ModuleNotFoundError: + msg = f"return_type={return_type}, but {return_type} is not installed" + raise ModuleNotFoundError(msg) + + try: + return backend.read_csv(filepath) + except Exception as e: + msg = f"Unable to read '{d}' dataset due to: {e}" + raise Exception(msg).with_traceback(e.__traceback__) diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 935d6745788..33db532cd4d 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -4,11 +4,11 @@ """ from plotly import optional_imports -pd = optional_imports.get_module("pandas") -if pd is None: +np = optional_imports.get_module("numpy") +if np is None: raise ImportError( """\ -Plotly express requires pandas to be installed.""" +Plotly express requires numpy to be installed.""" ) from ._imshow import imshow diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 0e71f41becf..b3bcd096d34 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -7,9 +7,6 @@ from _plotly_utils.basevalidators import ColorscaleValidator from plotly.colors import qualitative, sequential import math -from packaging import version -import pandas as pd -import numpy as np from plotly._subplots import ( make_subplots, @@ -17,9 +14,16 @@ _subplot_type_for_trace_type, ) -pandas_2_2_0 = version.parse(pd.__version__) >= version.parse("2.2.0") +import narwhals.stable.v1 as nw + +# The reason to use narwhals.stable.v1 is to have a stable and perfectly +# backwards-compatible API, hence the confidence to not pin the Narwhals version exactly, +# allowing for multiple major libraries to have Narwhals as a dependency without +# forbidding users to install them all together due to dependency conflicts. NO_COLOR = "px_no_color_constant" + + trendline_functions = dict( lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols ) @@ -154,8 +158,57 @@ def invert_label(args, column): return column -def _is_continuous(df, col_name): - return df[col_name].dtype.kind in "ifc" +def _is_continuous(df: nw.DataFrame, col_name: str) -> bool: + if nw.dependencies.is_pandas_like_dataframe(df_native := df.to_native()): + # fastpath for pandas: Narwhals' Series.dtype has a bit of overhead, as it + # tries to distinguish between true "object" columns, and "string" columns + # disguised as "object". But here, we deal with neither. 
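Both branches of `_is_continuous` (the pandas fastpath commented just above, whose return follows below, and the generic Narwhals path) answer the same question; a side-by-side sketch assuming pandas is installed:

```python
import narwhals.stable.v1 as nw
import pandas as pd

df_native = pd.DataFrame({"a": [1.0, 2.0], "b": ["x", "y"]})
df = nw.from_native(df_native, eager_only=True)

# pandas fastpath: numeric dtypes have kind "i" (int), "f" (float) or "c" (complex)
assert df_native["a"].dtype.kind in "ifc"
# generic path: Narwhals dtypes expose is_numeric() for every backend
assert df.get_column("a").dtype.is_numeric()
assert not df.get_column("b").dtype.is_numeric()
```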
+ return df_native[col_name].dtype.kind in "ifc" + return df.get_column(col_name).dtype.is_numeric() + + +def _to_unix_epoch_seconds(s: nw.Series) -> nw.Series: + dtype = s.dtype + if dtype == nw.Date: + return s.dt.timestamp("ms") / 1_000 + if dtype == nw.Datetime: + if dtype.time_unit in ("s", "ms"): + return s.dt.timestamp("ms") / 1_000 + elif dtype.time_unit == "us": + return s.dt.timestamp("us") / 1_000_000 + elif dtype.time_unit == "ns": + return s.dt.timestamp("ns") / 1_000_000_000 + else: + msg = "Unexpected dtype, please report a bug" + raise ValueError(msg) + else: + msg = f"Expected Date or Datetime, got {dtype}" + raise TypeError(msg) + + +def _generate_temporary_column_name(n_bytes, columns) -> str: + """Wraps of Narwhals generate_temporary_column_name to generate a token + which is guaranteed to not be in columns, nor in [col + token for col in columns] + """ + counter = 0 + while True: + # This is guaranteed to not be in columns by Narwhals + token = nw.generate_temporary_column_name(n_bytes, columns=columns) + + # Now check that it is not in the [col + token for col in columns] list + if token not in {f"{c}{token}" for c in columns}: + return token + + counter += 1 + if counter > 100: + msg = ( + "Internal Error: Plotly was not able to generate a column name with " + f"{n_bytes=} and not in {columns}.\n" + "Please report this to " + "https://github.com/plotly/plotly.py/issues/new and we will try to " + "replicate and fix it." + ) + raise AssertionError(msg) def get_decorated_label(args, column, role): @@ -270,8 +323,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): fit_results : dict fit information to be used for trendlines """ + trace_data: nw.DataFrame + df: nw.DataFrame = args["data_frame"] + if "line_close" in args and args["line_close"]: - trace_data = pd.concat([trace_data, trace_data.iloc[:1]]) + trace_data = nw.concat([trace_data, trace_data.head(1)], how="vertical") + trace_patch = trace_spec.trace_patch.copy() or {} fit_results = None hover_header = "" @@ -280,17 +337,14 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): attr_label = get_decorated_label(args, attr_value, attr_name) if attr_name == "dimensions": dims = [ - (name, column) - for (name, column) in trace_data.items() + (name, trace_data.get_column(name)) + for name in trace_data.columns if ((not attr_value) or (name in attr_value)) - and ( - trace_spec.constructor != go.Parcoords - or _is_continuous(args["data_frame"], name) - ) + and (trace_spec.constructor != go.Parcoords or _is_continuous(df, name)) and ( trace_spec.constructor != go.Parcats or (attr_value is not None and name in attr_value) - or len(args["data_frame"][name].unique()) + or nw.to_py_scalar(df.get_column(name).n_unique()) <= args["dimensions_max_cardinality"] ) ] @@ -308,7 +362,7 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if attr_name == "size": if "marker" not in trace_patch: trace_patch["marker"] = dict() - trace_patch["marker"]["size"] = trace_data[attr_value] + trace_patch["marker"]["size"] = trace_data.get_column(attr_value) trace_patch["marker"]["sizemode"] = "area" trace_patch["marker"]["sizeref"] = sizeref mapping_labels[attr_label] = "%{marker.size}" @@ -322,28 +376,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if ( args["x"] and args["y"] - and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 + and len( + trace_data.select(nw.col(args["x"], args["y"])).drop_nulls() + ) + > 
1 ): # sorting is bad but trace_specs with "trendline" have no other attrs - sorted_trace_data = trace_data.sort_values(by=args["x"]) - y = sorted_trace_data[args["y"]].values - x = sorted_trace_data[args["x"]].values + sorted_trace_data = trace_data.sort(by=args["x"], nulls_last=True) + y = sorted_trace_data.get_column(args["y"]) + x = sorted_trace_data.get_column(args["x"]) - if x.dtype.type == np.datetime64: + if x.dtype == nw.Datetime or x.dtype == nw.Date: # convert to unix epoch seconds - x = x.astype(np.int64) / 10**9 - elif x.dtype.type == np.object_: + x = _to_unix_epoch_seconds(x) + elif not x.dtype.is_numeric(): try: - x = x.astype(np.float64) + x = x.cast(nw.Float64()) except ValueError: raise ValueError( "Could not convert value of 'x' ('%s') into a numeric type. " "If 'x' contains stringified dates, please convert to a datetime column." % args["x"] ) - if y.dtype.type == np.object_: + + if not y.dtype.is_numeric(): try: - y = y.astype(np.float64) + y = y.cast(nw.Float64()) except ValueError: raise ValueError( "Could not convert value of 'y' into a numeric type." @@ -353,19 +411,30 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): # otherwise numpy/pandas can mess with the timezones # NB this means trendline functions must output one-to-one with the input series # i.e. we can't do resampling, because then the X values might not line up! - non_missing = np.logical_not( - np.logical_or(np.isnan(y), np.isnan(x)) + non_missing = ~(x.is_null() | y.is_null()) + trace_patch["x"] = sorted_trace_data.filter(non_missing).get_column( + args["x"] ) - trace_patch["x"] = sorted_trace_data[args["x"]][non_missing] + if ( + trace_patch["x"].dtype == nw.Datetime + and trace_patch["x"].dtype.time_zone is not None + ): + # Remove time zone so that local time is displayed + trace_patch["x"] = ( + trace_patch["x"].dt.replace_time_zone(None).to_numpy() + ) + else: + trace_patch["x"] = trace_patch["x"].to_numpy() + trendline_function = trendline_functions[attr_value] y_out, hover_header, fit_results = trendline_function( args["trendline_options"], - sorted_trace_data[args["x"]], - x, - y, + sorted_trace_data.get_column(args["x"]), # narwhals series + x.to_numpy(), # numpy array + y.to_numpy(), # numpy array args["x"], args["y"], - non_missing, + non_missing.to_numpy(), # numpy array ) assert len(y_out) == len( trace_patch["x"] @@ -378,19 +447,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): arr = "arrayminus" if attr_name.endswith("minus") else "array" if error_xy not in trace_patch: trace_patch[error_xy] = {} - trace_patch[error_xy][arr] = trace_data[attr_value] + trace_patch[error_xy][arr] = trace_data.get_column(attr_value) elif attr_name == "custom_data": if len(attr_value) > 0: # here we store a data frame in customdata, and it's serialized # as a list of row lists, which is what we want - trace_patch["customdata"] = trace_data[attr_value] + trace_patch["customdata"] = trace_data.select(nw.col(attr_value)) elif attr_name == "hover_name": if trace_spec.constructor not in [ go.Histogram, go.Histogram2d, go.Histogram2dContour, ]: - trace_patch["hovertext"] = trace_data[attr_value] + trace_patch["hovertext"] = trace_data.get_column(attr_value) if hover_header == "": hover_header = "%{hovertext}
<br><br>
" elif attr_name == "hover_data": @@ -424,14 +493,19 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): if len(customdata_cols) > 0: # here we store a data frame in customdata, and it's serialized # as a list of row lists, which is what we want - trace_patch["customdata"] = trace_data[customdata_cols] + + # dict.fromkeys(customdata_cols) allows to deduplicate column + # names, yet maintaining the original order. + trace_patch["customdata"] = trace_data.select( + *[nw.col(c) for c in dict.fromkeys(customdata_cols)] + ) elif attr_name == "color": if trace_spec.constructor in [ go.Choropleth, go.Choroplethmap, go.Choroplethmapbox, ]: - trace_patch["z"] = trace_data[attr_value] + trace_patch["z"] = trace_data.get_column(attr_value) trace_patch["coloraxis"] = "coloraxis1" mapping_labels[attr_label] = "%{z}" elif trace_spec.constructor in [ @@ -445,7 +519,9 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): trace_patch["marker"] = dict() if args.get("color_is_continuous"): - trace_patch["marker"]["colors"] = trace_data[attr_value] + trace_patch["marker"]["colors"] = trace_data.get_column( + attr_value + ) trace_patch["marker"]["coloraxis"] = "coloraxis1" mapping_labels[attr_label] = "%{color}" else: @@ -454,7 +530,12 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): mapping = args["color_discrete_map"].copy() else: mapping = {} - for cat in trace_data[attr_value]: + for cat in trace_data.get_column(attr_value).to_list(): + # although trace_data.get_column(attr_value) is a Narwhals + # Series, which is an iterable, explicitly calling a to_list() + # makes sure that the elements we loop over are python objects + # in all cases, since depending on the backend this may not be + # the case (e.g. 
PyArrow) if mapping.get(cat) is None: mapping[cat] = args["color_discrete_sequence"][ len(mapping) % len(args["color_discrete_sequence"]) @@ -466,24 +547,24 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): colorable = "line" if colorable not in trace_patch: trace_patch[colorable] = dict() - trace_patch[colorable]["color"] = trace_data[attr_value] + trace_patch[colorable]["color"] = trace_data.get_column(attr_value) trace_patch[colorable]["coloraxis"] = "coloraxis1" mapping_labels[attr_label] = "%%{%s.color}" % colorable elif attr_name == "animation_group": - trace_patch["ids"] = trace_data[attr_value] + trace_patch["ids"] = trace_data.get_column(attr_value) elif attr_name == "locations": - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) mapping_labels[attr_label] = "%{location}" elif attr_name == "values": - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) _label = "value" if attr_label == "values" else attr_label mapping_labels[_label] = "%{value}" elif attr_name == "parents": - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) _label = "parent" if attr_label == "parents" else attr_label mapping_labels[_label] = "%{parent}" elif attr_name == "ids": - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) _label = "id" if attr_label == "ids" else attr_label mapping_labels[_label] = "%{id}" elif attr_name == "names": @@ -494,13 +575,13 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): go.Pie, go.Funnelarea, ]: - trace_patch["labels"] = trace_data[attr_value] + trace_patch["labels"] = trace_data.get_column(attr_value) _label = "label" if attr_label == "names" else attr_label mapping_labels[_label] = "%{label}" else: - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) else: - trace_patch[attr_name] = trace_data[attr_value] + trace_patch[attr_name] = trace_data.get_column(attr_value) mapping_labels[attr_label] = "%%{%s}" % attr_name elif (trace_spec.constructor == go.Histogram and attr_name in ["x", "y"]) or ( trace_spec.constructor in [go.Histogram2d, go.Histogram2dContour] @@ -1015,7 +1096,7 @@ def _get_reserved_col_names(args): as arguments, either as str/int arguments or given as columns (pandas series type). 
""" - df = args["data_frame"] + df: nw.DataFrame = args["data_frame"] reserved_names = set() for field in args: if field not in all_attrables: @@ -1028,26 +1109,27 @@ def _get_reserved_col_names(args): continue elif isinstance(arg, str): # no need to add ints since kw arg are not ints reserved_names.add(arg) - elif isinstance(arg, pd.Series): - arg_name = arg.name - if arg_name and hasattr(df, arg_name): - in_df = arg is df[arg_name] + elif nw.dependencies.is_into_series(arg): + arg_series = nw.from_native(arg, series_only=True) + arg_name = arg_series.name + if arg_name and arg_name in df.columns: + in_df = (arg_series == df.get_column(arg_name)).all() if in_df: reserved_names.add(arg_name) - elif arg is df.index and arg.name is not None: + elif arg is nw.maybe_get_index(df) and arg.name is not None: reserved_names.add(arg.name) return reserved_names -def _is_col_list(columns, arg): +def _is_col_list(columns, arg, is_pd_like, native_namespace): """Returns True if arg looks like it's a list of columns or references to columns in df_input, and False otherwise (in which case it's assumed to be a single column or reference to a column). """ if arg is None or isinstance(arg, str) or isinstance(arg, int): return False - if isinstance(arg, pd.MultiIndex): + if is_pd_like and isinstance(arg, native_namespace.MultiIndex): return False # just to keep existing behaviour for now try: iter(arg) @@ -1082,22 +1164,38 @@ def _isinstance_listlike(x): def _escape_col_name(columns, col_name, extra): - while columns is not None and (col_name in columns or col_name in extra): + if columns is None: + return col_name + while col_name in columns or col_name in extra: col_name = "_" + col_name return col_name -def to_unindexed_series(x, name=None): - """ - assuming x is list-like or even an existing pd.Series, return a new pd.Series with - no index, without extracting the data from an existing Series via numpy, which - seems to mangle datetime columns. Stripping the index from existing pd.Series is - required to get things to match up right in the new DataFrame we're building +def to_unindexed_series(x, name=None, native_namespace=None): + """Assuming x is list-like or even an existing Series, returns a new Series (with + its index reset if pandas-like). Stripping the index from existing pd.Series is + required to get things to match up right in the new DataFrame we're building. """ - return pd.Series(x, name=name).reset_index(drop=True) + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals Series. + x = nw.from_native(x, series_only=True, pass_through=True) + if isinstance(x, nw.Series): + return nw.maybe_reset_index(x).rename(name) + elif native_namespace is not None: + return nw.new_series(name=name, values=x, native_namespace=native_namespace) + else: + try: + import pandas as pd + + return nw.new_series(name=name, values=x, native_namespace=pd) + except ImportError: + msg = "Pandas installation is required if no dataframe is provided." + raise NotImplementedError(msg) -def process_args_into_dataframe(args, wide_mode, var_name, value_name): +def process_args_into_dataframe( + args, wide_mode, var_name, value_name, is_pd_like, native_namespace +): """ After this function runs, the `all_attrables` keys of `args` all contain only references to columns of `df_output`. This function handles the extraction of data @@ -1106,7 +1204,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): reference. 
""" - df_input = args["data_frame"] + df_input: nw.DataFrame | None = args["data_frame"] df_provided = df_input is not None # we use a dict instead of a dataframe directly so that it doesn't cause @@ -1125,7 +1223,7 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): "No data were provided. Please provide data either with the `data_frame` or with the `dimensions` argument." ) else: - df_output = {col: series for col, series in df_input.items()} + df_output = {col: df_input.get_column(col) for col in df_input.columns} # hover_data is a dict hover_data_is_dict = ( @@ -1171,14 +1269,14 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): continue col_name = None # Case of multiindex - if isinstance(argument, pd.MultiIndex): + if is_pd_like and isinstance(argument, native_namespace.MultiIndex): raise TypeError( - "Argument '%s' is a pandas MultiIndex. " - "pandas MultiIndex is not supported by plotly express " - "at the moment." % field + f"Argument '{field}' is a {native_namespace.__name__} MultiIndex. " + f"{native_namespace.__name__} MultiIndex is not supported by plotly " + "express at the moment." ) # ----------------- argument is a special value ---------------------- - if isinstance(argument, Constant) or isinstance(argument, Range): + if isinstance(argument, (Constant, Range)): col_name = _check_name_not_reserved( str(argument.label) if argument.label is not None else field, reserved_names, @@ -1199,19 +1297,21 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): col_name = str(argument) real_argument = args["hover_data"][col_name][1] - if length and len(real_argument) != length: + if length and (real_length := len(real_argument)) != length: raise ValueError( "All arguments should have the same length. " "The length of hover_data key `%s` is %d, whereas the " "length of previously-processed arguments %s is %d" % ( argument, - len(real_argument), + real_length, str(list(df_output.keys())), length, ) ) - df_output[col_name] = to_unindexed_series(real_argument, col_name) + df_output[col_name] = to_unindexed_series( + real_argument, col_name, native_namespace + ) elif not df_provided: raise ValueError( "String or int arguments are only possible when a " @@ -1232,14 +1332,14 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): if argument == "index": err_msg += "\n To use the index, pass it in directly as `df.index`." raise ValueError(err_msg) - elif length and len(df_input[argument]) != length: + elif length and (actual_len := len(df_input)) != length: raise ValueError( "All arguments should have the same length. " "The length of column argument `df[%s]` is %d, whereas the " "length of previously-processed arguments %s is %d" % ( field, - len(df_input[argument]), + actual_len, str(list(df_output.keys())), length, ) @@ -1247,37 +1347,47 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): else: col_name = str(argument) df_output[col_name] = to_unindexed_series( - df_input[argument], col_name + df_input.get_column(argument), col_name ) # ----------------- argument is likely a column / array / list.... 
------- else: if df_provided and hasattr(argument, "name"): - if argument is df_input.index: - if argument.name is None or argument.name in df_input: + if is_pd_like and argument is nw.maybe_get_index(df_input): + if argument.name is None or argument.name in df_input.columns: col_name = "index" else: col_name = argument.name col_name = _escape_col_name( - df_input, col_name, [var_name, value_name] + df_input.columns, col_name, [var_name, value_name] ) else: if ( argument.name is not None - and argument.name in df_input - and argument is df_input[argument.name] + and argument.name in df_input.columns + and ( + to_unindexed_series( + argument, argument.name, native_namespace + ) + == df_input.get_column(argument.name) + ).all() ): col_name = argument.name if col_name is None: # numpy array, list... col_name = _check_name_not_reserved(field, reserved_names) - if length and len(argument) != length: + if length and (len_arg := len(argument)) != length: raise ValueError( "All arguments should have the same length. " "The length of argument `%s` is %d, whereas the " "length of previously-processed arguments %s is %d" - % (field, len(argument), str(list(df_output.keys())), length) + % (field, len_arg, str(list(df_output.keys())), length) ) - df_output[str(col_name)] = to_unindexed_series(argument, str(col_name)) + + df_output[str(col_name)] = to_unindexed_series( + x=argument, + name=str(col_name), + native_namespace=native_namespace, + ) # Finally, update argument with column name now that column exists assert col_name is not None, ( @@ -1296,18 +1406,49 @@ def process_args_into_dataframe(args, wide_mode, var_name, value_name): wide_id_vars.add(str(col_name)) length = len(df_output[next(iter(df_output))]) if len(df_output) else 0 - df_output.update( - {col_name: to_unindexed_series(range(length), col_name) for col_name in ranges} - ) + + if native_namespace is None: + try: + import pandas as pd + + native_namespace = pd + except ImportError: + msg = "Pandas installation is required if no dataframe is provided." + raise NotImplementedError(msg) + + if ranges: + import numpy as np + + range_series = nw.new_series( + name="__placeholder__", + values=np.arange(length), + native_namespace=native_namespace, + ) + df_output.update( + {col_name: range_series.alias(col_name) for col_name in ranges} + ) + df_output.update( { - # constant is single value. repeat by len to avoid creating NaN on concating - col_name: to_unindexed_series([constants[col_name]] * length, col_name) + # constant is single value. repeat by len to avoid creating NaN on concatenating + col_name: nw.new_series( + name=col_name, + values=[constants[col_name]] * length, + native_namespace=native_namespace, + ) for col_name in constants } ) - df_output = pd.DataFrame(df_output) + if df_output: + df_output = nw.from_dict(df_output) + else: + try: + import pandas as pd + except ImportError: + msg = "Pandas installation is required." + raise NotImplementedError(msg) + df_output = nw.from_native(pd.DataFrame({}), eager_only=True) return df_output, wide_id_vars @@ -1341,46 +1482,131 @@ def build_dataframe(args, constructor): # Cast data_frame argument to DataFrame (it could be a numpy array, dict etc.) df_provided = args["data_frame"] is not None + + # Flag that indicates if the resulting data_frame after parsing is pandas-like + # (in terms of resulting Narwhals DataFrame). + # True if pandas, modin.pandas or cudf DataFrame/Series instance, or converted from + # PySpark to pandas. 
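A sketch of the detection helpers this branch relies on (pandas and polars assumed installed; the printed support level is what the later interchange branch keys off):

```python
import narwhals.stable.v1 as nw
import pandas as pd
import polars as pl

assert nw.dependencies.is_pandas_like_dataframe(pd.DataFrame({"a": [1]}))
assert not nw.dependencies.is_pandas_like_dataframe(pl.DataFrame({"a": [1]}))

# Eagerly supported backends come back at "full" support level;
# interchange-only objects (DuckDB, Ibis, Vaex, ...) report "interchange".
df = nw.from_native(pl.DataFrame({"a": [1]}), eager_or_interchange_only=True)
print(nw.get_level(df))  # "full"
```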
+ is_pd_like = False + + # Flag that indicates if data_frame requires to be converted to arrow via the + # dataframe interchange protocol. + # True if Ibis, DuckDB, Vaex or implements __dataframe__ needs_interchanging = False - if df_provided and not isinstance(args["data_frame"], pd.DataFrame): - if hasattr(args["data_frame"], "__dataframe__") and version.parse( - pd.__version__ - ) >= version.parse("2.0.2"): - import pandas.api.interchange - - df_not_pandas = args["data_frame"] - args["data_frame"] = df_not_pandas.__dataframe__() - # According interchange protocol: `def column_names(self) -> Iterable[str]:` - # so this function can return for example a generator. - # The easiest way is to convert `columns` to `pandas.Index` so that the - # type is similar to the types in other code branches. - columns = pd.Index(args["data_frame"].column_names()) - needs_interchanging = True - elif hasattr(args["data_frame"], "to_pandas"): - args["data_frame"] = args["data_frame"].to_pandas() + + # If data_frame is provided, we parse it into a narwhals DataFrame, while accounting + # for compatibility with pandas specific paths (e.g. Index/MultiIndex case). + if df_provided: + + # data_frame is pandas-like DataFrame (pandas, modin.pandas, cudf) + if nw.dependencies.is_pandas_like_dataframe(args["data_frame"]): + + columns = args["data_frame"].columns # This can be multi index + args["data_frame"] = nw.from_native(args["data_frame"], eager_only=True) + is_pd_like = True + + # data_frame is pandas-like Series (pandas, modin.pandas, cudf) + elif nw.dependencies.is_pandas_like_series(args["data_frame"]): + + args["data_frame"] = nw.from_native( + args["data_frame"], series_only=True + ).to_frame() columns = args["data_frame"].columns - elif hasattr(args["data_frame"], "toPandas"): - args["data_frame"] = args["data_frame"].toPandas() + is_pd_like = True + + # data_frame is any other DataFrame object natively supported via Narwhals. + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals DataFrame, making this condition False. + elif isinstance( + data_frame := nw.from_native( + args["data_frame"], eager_or_interchange_only=True, pass_through=True + ), + nw.DataFrame, + ): + args["data_frame"] = data_frame + needs_interchanging = nw.get_level(data_frame) == "interchange" + columns = args["data_frame"].columns + + # data_frame is any other Series object natively supported via Narwhals. + # With `pass_through=True`, the original object will be returned if unable to convert + # to a Narwhals Series, making this condition False. + elif isinstance( + series := nw.from_native( + args["data_frame"], series_only=True, pass_through=True + ), + nw.Series, + ): + args["data_frame"] = series.to_frame() columns = args["data_frame"].columns - elif hasattr(args["data_frame"], "to_pandas_df"): - args["data_frame"] = args["data_frame"].to_pandas_df() + + # data_frame is PySpark: it does not support interchange protocol and it is not + # integrated in Narwhals. We use its native method to convert it to pandas. + elif hasattr(args["data_frame"], "toPandas"): + args["data_frame"] = nw.from_native( + args["data_frame"].toPandas(), eager_only=True + ) columns = args["data_frame"].columns + is_pd_like = True + + # data_frame is some other object type (e.g. dict, list, ...) 
+ # We try to import pandas, and then try to instantiate a pandas dataframe from + # this such object else: - args["data_frame"] = pd.DataFrame(args["data_frame"]) - columns = args["data_frame"].columns - elif df_provided: - columns = args["data_frame"].columns + try: + import pandas as pd + + try: + args["data_frame"] = nw.from_native( + pd.DataFrame(args["data_frame"]) + ) + columns = args["data_frame"].columns + is_pd_like = True + except Exception: + msg = ( + f"Unable to convert data_frame of type {type(args['data_frame'])} " + "to pandas DataFrame. Please provide a supported dataframe type " + "or a type that can be passed to pd.DataFrame." + ) + + raise NotImplementedError(msg) + except ImportError: + msg = ( + f"Attempting to convert data_frame of type {type(args['data_frame'])} " + "to pandas DataFrame, but Pandas is not installed. " + "Convert it to supported dataframe type or install pandas." + ) + raise NotImplementedError(msg) + + # data_frame is not provided else: columns = None - df_input = args["data_frame"] + df_input: nw.DataFrame | None = args["data_frame"] + index = ( + nw.maybe_get_index(df_input) + if df_provided and not needs_interchanging + else None + ) + native_namespace = ( + nw.get_native_namespace(df_input) + if df_provided and not needs_interchanging + else None + ) # now we handle special cases like wide-mode or x-xor-y specification # by rearranging args to tee things up for process_args_into_dataframe to work no_x = args.get("x") is None no_y = args.get("y") is None - wide_x = False if no_x else _is_col_list(columns, args["x"]) - wide_y = False if no_y else _is_col_list(columns, args["y"]) + wide_x = ( + False + if no_x + else _is_col_list(columns, args["x"], is_pd_like, native_namespace) + ) + wide_y = ( + False + if no_y + else _is_col_list(columns, args["y"], is_pd_like, native_namespace) + ) wide_mode = False var_name = None # will likely be "variable" in wide_mode @@ -1395,14 +1621,14 @@ def build_dataframe(args, constructor): ) if df_provided and no_x and no_y: wide_mode = True - if isinstance(columns, pd.MultiIndex): + if is_pd_like and isinstance(columns, native_namespace.MultiIndex): raise TypeError( - "Data frame columns is a pandas MultiIndex. " - "pandas MultiIndex is not supported by plotly express " - "at the moment." + f"Data frame columns is a {native_namespace.__name__} MultiIndex. " + f"{native_namespace.__name__} MultiIndex is not supported by plotly " + "express at the moment." ) args["wide_variable"] = list(columns) - if isinstance(columns, pd.Index): + if is_pd_like and isinstance(columns, native_namespace.Index): var_name = columns.name else: var_name = None @@ -1417,9 +1643,9 @@ def build_dataframe(args, constructor): elif wide_x != wide_y: wide_mode = True args["wide_variable"] = args["y"] if wide_y else args["x"] - if df_provided and args["wide_variable"] is columns: + if df_provided and is_pd_like and args["wide_variable"] is columns: var_name = columns.name - if isinstance(args["wide_variable"], pd.Index): + if is_pd_like and isinstance(args["wide_variable"], native_namespace.Index): args["wide_variable"] = list(args["wide_variable"]) if var_name in [None, "value", "index"] or ( df_provided and var_name in columns @@ -1438,42 +1664,34 @@ def build_dataframe(args, constructor): value_name = _escape_col_name(columns, "value", []) var_name = _escape_col_name(columns, var_name, []) + # If the data_frame has interchange-only support levelin Narwhals, then we need to + # convert it to a full support level backend. 
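The interchange-to-PyArrow conversion this comment block describes, sketched with DuckDB as the interchange-only producer (assumes `duckdb` and `pyarrow` are installed with a Narwhals version contemporary to this patch; the query is illustrative):

```python
import duckdb
import narwhals.stable.v1 as nw

rel = duckdb.sql("SELECT 1 AS x, 'a' AS y")  # interchange-level object
df = nw.from_native(rel, eager_or_interchange_only=True)
assert nw.get_level(df) == "interchange"

# Project down to the needed columns, then materialise as eager PyArrow
eager = nw.from_native(df.select(["x"]).to_arrow(), eager_only=True)
assert nw.get_level(eager) == "full"
```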
+ # Hence we convert requires Interchange to PyArrow. if needs_interchanging: - try: - if wide_mode or not hasattr(args["data_frame"], "select_columns_by_name"): - args["data_frame"] = pd.api.interchange.from_dataframe( - args["data_frame"] - ) - else: - # Save precious resources by only interchanging columns that are - # actually going to be plotted. - necessary_columns = { - i for i in args.values() if isinstance(i, str) and i in columns - } - for field in args: - if args[field] is not None and field in array_attrables: - necessary_columns.update(i for i in args[field] if i in columns) - columns = list(necessary_columns) - args["data_frame"] = pd.api.interchange.from_dataframe( - args["data_frame"].select_columns_by_name(columns) - ) - except (ImportError, NotImplementedError) as exc: - # temporary workaround; developers of third-party libraries themselves - # should try a different implementation, if available. For example: - # def __dataframe__(self, ...): - # if not some_condition: - # self.to_pandas(...) - if hasattr(df_not_pandas, "toPandas"): - args["data_frame"] = df_not_pandas.toPandas() - elif hasattr(df_not_pandas, "to_pandas_df"): - args["data_frame"] = df_not_pandas.to_pandas_df() - elif hasattr(df_not_pandas, "to_pandas"): - args["data_frame"] = df_not_pandas.to_pandas() - else: - raise exc - - df_input = args["data_frame"] + if wide_mode: + args["data_frame"] = nw.from_native( + args["data_frame"].to_arrow(), eager_only=True + ) + else: + # Save precious resources by only interchanging columns that are + # actually going to be plotted. This is tricky to do in the general case, + # because Plotly allows calls like `px.line(df, x='x', y=['y1', df['y1']])`, + # but interchange-only objects (e.g. DuckDB) don't typically have a concept + # of self-standing Series. It's more important to perform project pushdown + # here seeing as we're materialising to an (eager) PyArrow table. + necessary_columns = { + i for i in args.values() if isinstance(i, str) and i in columns + } + for field in args: + if args[field] is not None and field in array_attrables: + necessary_columns.update(i for i in args[field] if i in columns) + columns = list(necessary_columns) + args["data_frame"] = nw.from_native( + args["data_frame"].select(columns).to_arrow(), eager_only=True + ) + import pyarrow as pa + native_namespace = pa missing_bar_dim = None if ( constructor in [go.Scatter, go.Bar, go.Funnel] + hist2d_types @@ -1482,7 +1700,13 @@ def build_dataframe(args, constructor): if not wide_mode and (no_x != no_y): for ax in ["x", "y"]: if args.get(ax) is None: - args[ax] = df_input.index if df_provided else Range() + args[ax] = ( + index + if index is not None + else Range( + label=_escape_col_name(columns, ax, [var_name, value_name]) + ) + ) if constructor == go.Bar: missing_bar_dim = ax else: @@ -1491,34 +1715,39 @@ def build_dataframe(args, constructor): if wide_mode and wide_cross_name is None: if no_x != no_y and args["orientation"] is None: args["orientation"] = "v" if no_x else "h" - if df_provided: - if isinstance(df_input.index, pd.MultiIndex): + if df_provided and is_pd_like and index is not None: + if isinstance(index, native_namespace.MultiIndex): raise TypeError( - "Data frame index is a pandas MultiIndex. " - "pandas MultiIndex is not supported by plotly express " - "at the moment." + f"Data frame index is a {native_namespace.__name__} MultiIndex. " + f"{native_namespace.__name__} MultiIndex is not supported by " + "plotly express at the moment." 
) - args["wide_cross"] = df_input.index + args["wide_cross"] = index else: args["wide_cross"] = Range( - label=_escape_col_name(df_input, "index", [var_name, value_name]) + label=_escape_col_name(columns, "index", [var_name, value_name]) ) no_color = False - if type(args.get("color")) == str and args["color"] == NO_COLOR: + if isinstance(args.get("color"), str) and args["color"] == NO_COLOR: no_color = True args["color"] = None # now that things have been prepped, we do the systematic rewriting of `args` df_output, wide_id_vars = process_args_into_dataframe( - args, wide_mode, var_name, value_name + args, + wide_mode, + var_name, + value_name, + is_pd_like, + native_namespace, ) - + df_output: nw.DataFrame # now that `df_output` exists and `args` contains only references, we complete # the special-case and wide-mode handling by further rewriting args and/or mutating # df_output - count_name = _escape_col_name(df_output, "count", [var_name, value_name]) + count_name = _escape_col_name(df_output.columns, "count", [var_name, value_name]) if not wide_mode and missing_bar_dim and constructor == go.Bar: # now that we've populated df_output, we check to see if the non-missing # dimension is categorical: if so, then setting the missing dimension to a @@ -1527,7 +1756,7 @@ def build_dataframe(args, constructor): other_dim = "x" if missing_bar_dim == "y" else "y" if not _is_continuous(df_output, args[other_dim]): args[missing_bar_dim] = count_name - df_output[count_name] = 1 + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) else: # on the other hand, if the non-missing dimension is continuous, then we # can use this information to override the normal auto-orientation code @@ -1552,18 +1781,18 @@ def build_dataframe(args, constructor): del args["wide_cross"] dtype = None for v in wide_value_vars: - v_dtype = df_output[v].dtype.kind - v_dtype = "number" if v_dtype in ["i", "f", "u"] else v_dtype + v_dtype = df_output.get_column(v).dtype + v_dtype = "number" if v_dtype.is_numeric() else str(v_dtype) if dtype is None: dtype = v_dtype elif dtype != v_dtype: raise ValueError( "Plotly Express cannot process wide-form data with columns of different type." ) - df_output = df_output.melt( - id_vars=wide_id_vars, - value_vars=wide_value_vars, - var_name=var_name, + df_output = df_output.unpivot( + index=wide_id_vars, + on=wide_value_vars, + variable_name=var_name, value_name=value_name, ) assert len(df_output.columns) == len(set(df_output.columns)), ( @@ -1572,7 +1801,7 @@ def build_dataframe(args, constructor): "https://github.com/plotly/plotly.py/issues/new and we will try to " "replicate and fix it." 
) - df_output[var_name] = df_output[var_name].astype(str) + df_output = df_output.with_columns(nw.col(var_name).cast(nw.String)) orient_v = wide_orientation == "v" if hist1d_orientation: @@ -1594,7 +1823,7 @@ def build_dataframe(args, constructor): else: args["x" if orient_v else "y"] = value_name args["y" if orient_v else "x"] = count_name - df_output[count_name] = 1 + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) args["color"] = args["color"] or var_name elif constructor in [go.Violin, go.Box]: args["x" if orient_v else "y"] = wide_cross_name or var_name @@ -1607,12 +1836,12 @@ def build_dataframe(args, constructor): args["histfunc"] = None args["orientation"] = "h" args["x"] = count_name - df_output[count_name] = 1 + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) else: args["histfunc"] = None args["orientation"] = "v" args["y"] = count_name - df_output[count_name] = 1 + df_output = df_output.with_columns(nw.lit(1).alias(count_name)) if no_color: args["color"] = None @@ -1620,26 +1849,57 @@ def build_dataframe(args, constructor): return args -def _check_dataframe_all_leaves(df): - df_sorted = df.sort_values(by=list(df.columns)) - null_mask = df_sorted.isnull() - df_sorted = df_sorted.astype(str) - null_indices = np.nonzero(null_mask.any(axis=1).values)[0] - for null_row_index in null_indices: - row = null_mask.iloc[null_row_index] - i = np.nonzero(row.values)[0][0] - if not row[i:].all(): +def _check_dataframe_all_leaves(df: nw.DataFrame) -> None: + cols = df.columns + df_sorted = df.sort(by=cols, descending=False, nulls_last=True) + null_mask = df_sorted.select(nw.all().is_null()) + df_sorted = df_sorted.select(nw.all().cast(nw.String())) + null_indices_mask = null_mask.select( + null_mask=nw.any_horizontal(nw.all()) + ).get_column("null_mask") + + for row_idx, row in zip( + null_indices_mask, null_mask.filter(null_indices_mask).iter_rows() + ): + + i = row.index(True) + + if not all(row[i:]): raise ValueError( "None entries cannot have not-None children", - df_sorted.iloc[null_row_index], + df_sorted.row(row_idx), ) - df_sorted[null_mask] = "" - row_strings = list(df_sorted.apply(lambda x: "".join(x), axis=1)) - for i, row in enumerate(row_strings[:-1]): - if row_strings[i + 1] in row and (i + 1) in null_indices: + + fill_series = nw.new_series( + name="fill_value", + values=[""] * len(df_sorted), + dtype=nw.String(), + native_namespace=nw.get_native_namespace(df_sorted), + ) + df_sorted = df_sorted.with_columns( + **{ + c: df_sorted.get_column(c).zip_with(~null_mask.get_column(c), fill_series) + for c in cols + } + ) + + # Conversion to list is due to python native vs pyarrow scalars + row_strings = ( + df_sorted.select( + row_strings=nw.concat_str(cols, separator="", ignore_nulls=False) + ) + .get_column("row_strings") + .to_list() + ) + + null_indices = set(null_indices_mask.arg_true().to_list()) + for i, (current_row, next_row) in enumerate( + zip(row_strings[:-1], row_strings[1:]), start=1 + ): + if (next_row in current_row) and (i in null_indices): raise ValueError( "Non-leaves rows are not permitted in the dataframe \n", - df_sorted.iloc[i + 1], + df_sorted.row(i), "is not a leaf.", ) @@ -1648,109 +1908,199 @@ def process_dataframe_hierarchy(args): """ Build dataframe for sunburst, treemap, or icicle when the path argument is provided. 
""" - df = args["data_frame"] + df: nw.DataFrame = args["data_frame"] path = args["path"][::-1] _check_dataframe_all_leaves(df[path[::-1]]) - discrete_color = False + discrete_color = not _is_continuous(df, args["color"]) if args["color"] else False + + df = df.lazy() - new_path = [] - for col_name in path: - new_col_name = col_name + "_path_copy" - new_path.append(new_col_name) - df[new_col_name] = df[col_name] + new_path = [col_name + "_path_copy" for col_name in path] + df = df.with_columns( + nw.col(col_name).alias(new_col_name) + for new_col_name, col_name in zip(new_path, path) + ) path = new_path # ------------ Define aggregation functions -------------------------------- - - def aggfunc_discrete(x): - uniques = x.unique() - if len(uniques) == 1: - return uniques[0] - else: - return "(?)" - agg_f = {} - aggfunc_color = None if args["values"]: try: - df[args["values"]] = pd.to_numeric(df[args["values"]]) - except ValueError: + df = df.with_columns(nw.col(args["values"]).cast(nw.Float64())) + + except Exception: # pandas, Polars and pyarrow exception types are different raise ValueError( "Column `%s` of `df` could not be converted to a numerical data type." % args["values"] ) - if args["color"]: - if args["color"] == args["values"]: - new_value_col_name = args["values"] + "_sum" - df[new_value_col_name] = df[args["values"]] - args["values"] = new_value_col_name + if args["color"] and args["color"] == args["values"]: + new_value_col_name = args["values"] + "_sum" + df = df.with_columns(nw.col(args["values"]).alias(new_value_col_name)) + args["values"] = new_value_col_name count_colname = args["values"] else: # we need a count column for the first groupby and the weighted mean of color # trick to be sure the col name is unused: take the sum of existing names + columns = df.collect_schema().names() count_colname = ( - "count" - if "count" not in df.columns - else "".join([str(el) for el in list(df.columns)]) + "count" if "count" not in columns else "".join([str(el) for el in columns]) ) # we can modify df because it's a copy of the px argument - df[count_colname] = 1 + df = df.with_columns(nw.lit(1).alias(count_colname)) args["values"] = count_colname - agg_f[count_colname] = "sum" + + # Since count_colname is always in agg_f, it can be used later to normalize color + # in the continuous case after some gymnastic + agg_f[count_colname] = nw.sum(count_colname) + + discrete_aggs = [] + continuous_aggs = [] + + n_unique_token = _generate_temporary_column_name( + n_bytes=16, columns=df.collect_schema().names() + ) + + # In theory, for discrete columns aggregation, we should have a way to do + # `.agg(nw.col(x).unique())` in group_by and successively unpack/parse it as: + # ``` + # (nw.when(nw.col(x).list.len()==1) + # .then(nw.col(x).list.first()) + # .otherwise(nw.lit("(?)")) + # ) + # ``` + # which replicates the original pandas only codebase: + # ``` + # def discrete_agg(x): + # uniques = x.unique() + # return uniques[0] if len(uniques) == 1 else "(?)" + # + # df.groupby(path[i:]).agg(...) + # ``` + # However this is not possible, therefore the following workaround is provided. 
+ # We make two aggregations for the same column: + # - take the max value + # - take the number of unique values + # Finally, after the group by statement, it is unpacked via: + # ``` + # (nw.when(nw.col(col_n_unique) == 1) + # .then(nw.col(col_max_value)) # which is the unique value + # .otherwise(nw.lit("(?)")) + # ) + # ``` if args["color"]: - if not _is_continuous(df, args["color"]): - aggfunc_color = aggfunc_discrete - discrete_color = True + if discrete_color: + + discrete_aggs.append(args["color"]) + agg_f[args["color"]] = nw.col(args["color"]).max() + agg_f[f'{args["color"]}{n_unique_token}'] = ( + nw.col(args["color"]) + .n_unique() + .alias(f'{args["color"]}{n_unique_token}') + ) else: + # This first needs to be multiplied by `count_colname` + continuous_aggs.append(args["color"]) - def aggfunc_continuous(x): - return np.average(x, weights=df.loc[x.index, count_colname]) - - aggfunc_color = aggfunc_continuous - agg_f[args["color"]] = aggfunc_color + agg_f[args["color"]] = nw.sum(args["color"]) # Other columns (for color, hover_data, custom_data etc.) - cols = list(set(df.columns).difference(path)) + cols = list(set(df.collect_schema().names()).difference(path)) + df = df.with_columns(nw.col(c).cast(nw.String()) for c in cols if c not in agg_f) + for col in cols: # for hover_data, custom_data etc. if col not in agg_f: - agg_f[col] = aggfunc_discrete + # Similar trick as above + discrete_aggs.append(col) + agg_f[col] = nw.col(col).max() + agg_f[f"{col}{n_unique_token}"] = ( + nw.col(col).n_unique().alias(f"{col}{n_unique_token}") + ) # Avoid collisions with reserved names - columns in the path have been copied already cols = list(set(cols) - set(["labels", "parent", "id"])) # ---------------------------------------------------------------------------- - df_all_trees = pd.DataFrame(columns=["labels", "parent", "id"] + cols) - # Set column type here (useful for continuous vs discrete colorscale) - for col in cols: - df_all_trees[col] = df_all_trees[col].astype(df[col].dtype) + all_trees = [] + + if args["color"] and not discrete_color: + df = df.with_columns( + (nw.col(args["color"]) * nw.col(count_colname)).alias(args["color"]) + ) + + def post_agg(dframe: nw.LazyFrame, continuous_aggs, discrete_aggs) -> nw.LazyFrame: + """ + - continuous_aggs is either [] or [args["color"]] + - discrete_aggs is either [args["color"], ] or [] + """ + return dframe.with_columns( + *[nw.col(col) / nw.col(count_colname) for col in continuous_aggs], + *[ + ( + nw.when(nw.col(f"{col}{n_unique_token}") == 1) + .then(nw.col(col)) + .otherwise(nw.lit("(?)")) + .alias(col) + ) + for col in discrete_aggs + ], + ).drop([f"{col}{n_unique_token}" for col in discrete_aggs]) + for i, level in enumerate(path): - df_tree = pd.DataFrame(columns=df_all_trees.columns) - dfg = df.groupby(path[i:]).agg(agg_f) - dfg = dfg.reset_index() + + dfg = ( + df.group_by(path[i:], drop_null_keys=True) + .agg(**agg_f) + .pipe(post_agg, continuous_aggs, discrete_aggs) + ) + # Path label massaging - df_tree["labels"] = dfg[level].copy().astype(str) - df_tree["parent"] = "" - df_tree["id"] = dfg[level].copy().astype(str) + df_tree = dfg.with_columns( + *cols, + labels=nw.col(level).cast(nw.String()), + parent=nw.lit(""), + id=nw.col(level).cast(nw.String()), + ) if i < len(path) - 1: - j = i + 1 - while j < len(path): - df_tree["parent"] = ( - dfg[path[j]].copy().astype(str) + "/" + df_tree["parent"] + _concat_str_token = _generate_temporary_column_name( + n_bytes=16, columns=[*cols, "labels", "parent", "id"] + ) + df_tree = ( + 
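+                # Join the ancestor levels (root level first) into a single
+                # "/"-separated prefix, then prepend that prefix to both
+                # 'parent' and 'id'; e.g. a deepest-level row ends up with an
+                # id like "total/North/Tech/A".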
df_tree.with_columns( + nw.concat_str( + [ + nw.col(path[j]).cast(nw.String()) + for j in range(len(path) - 1, i, -1) + ], + separator="/", + ).alias(_concat_str_token) ) - df_tree["id"] = dfg[path[j]].copy().astype(str) + "/" + df_tree["id"] - j += 1 + .with_columns( + parent=nw.concat_str( + [nw.col(_concat_str_token), nw.col("parent")], separator="/" + ), + id=nw.concat_str( + [nw.col(_concat_str_token), nw.col("id")], separator="/" + ), + ) + .drop(_concat_str_token) + ) - df_tree["parent"] = df_tree["parent"].str.rstrip("/") - if cols: - df_tree[cols] = dfg[cols] - df_all_trees = pd.concat([df_all_trees, df_tree], ignore_index=True) + # strip any leading or trailing "/"; the trailing case matches pandas' `.str.rstrip` + df_tree = df_tree.with_columns( + parent=nw.col("parent").str.replace("/?$", "").str.replace("^/?", "") + ) + + all_trees.append(df_tree.select(*["labels", "parent", "id", *cols])) + + df_all_trees = nw.maybe_reset_index(nw.concat(all_trees, how="vertical").collect()) # we want to make sure that (?) is the first color of the sequence if args["color"] and discrete_color: sort_col_name = "sort_color_if_discrete_color" while sort_col_name in df_all_trees.columns: sort_col_name += "0" - df_all_trees[sort_col_name] = df[args["color"]].astype(str) - df_all_trees = df_all_trees.sort_values(by=sort_col_name) + df_all_trees = df_all_trees.with_columns( + nw.col(args["color"]).cast(nw.String()).alias(sort_col_name) + ).sort(by=sort_col_name, nulls_last=True) # Now modify arguments args["data_frame"] = df_all_trees @@ -1778,21 +2128,25 @@ def process_dataframe_timeline(args): raise ValueError("Both x_start and x_end are required") try: - x_start = pd.to_datetime(args["data_frame"][args["x_start"]]) - x_end = pd.to_datetime(args["data_frame"][args["x_end"]]) - except (ValueError, TypeError): + df: nw.DataFrame = args["data_frame"] + df = df.with_columns( + nw.col(args["x_start"]).str.to_datetime().alias(args["x_start"]), + nw.col(args["x_end"]).str.to_datetime().alias(args["x_end"]), + ) + except Exception: raise TypeError( "Both x_start and x_end must refer to data convertible to datetimes." ) # note that we are not adding any columns to the data frame here, so no risk of overwrite - args["data_frame"][args["x_end"]] = (x_end - x_start).astype( "timedelta64[ns]" ) / np.timedelta64(1, "ms") + args["data_frame"] = df.with_columns( + (nw.col(args["x_end"]) - nw.col(args["x_start"])) + .dt.total_milliseconds() + .alias(args["x_end"]) + ) args["x"] = args["x_end"] - del args["x_end"] args["base"] = args["x_start"] - del args["x_start"] + del args["x_start"], args["x_end"] return args @@ -1803,23 +2157,37 @@ def process_dataframe_pie(args, trace_patch): order_in = args["category_orders"].get(names, {}).copy() if not order_in: return args, trace_patch - df = args["data_frame"] + df: nw.DataFrame = args["data_frame"] trace_patch["sort"] = False trace_patch["direction"] = "clockwise" - uniques = list(df[names].unique()) + uniques = df.get_column(names).unique(maintain_order=True).to_list() order = [x for x in OrderedDict.fromkeys(list(order_in) + uniques) if x in uniques] - args["data_frame"] = df.set_index(names).loc[order].reset_index() + + # Sort args['data_frame'] by the `names` column according to `order`.
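+    # A minimal sketch of the trick on toy data, where `df_` stands for any
+    # narwhals-compatible frame whose `names` column holds ["b", "c", "a"]:
+    # ```
+    # df_.with_columns(
+    #     rank=nw.col(names).replace_strict(
+    #         ["a", "b", "c"], range(3), return_dtype=nw.UInt32
+    #     )
+    # )  # rank == [1, 2, 0], so sorting on the temporary rank column reproduces `order`
+    # ```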
+ token = nw.generate_temporary_column_name(8, df.columns) + args["data_frame"] = ( + df.with_columns( + nw.col(names) + .replace_strict(order, range(len(order)), return_dtype=nw.UInt32) + .alias(token) + ) + .sort(token) + .drop(token) + ) return args, trace_patch def infer_config(args, constructor, trace_patch, layout_patch): attrs = [k for k in direct_attrables + array_attrables if k in args] grouped_attrs = [] + df: nw.DataFrame = args["data_frame"] # Compute sizeref sizeref = 0 if "size" in args and args["size"]: - sizeref = args["data_frame"][args["size"]].max() / args["size_max"] ** 2 + sizeref = ( + nw.to_py_scalar(df.get_column(args["size"]).max()) / args["size_max"] ** 2 + ) # Compute color attributes and grouping attributes if "color" in args: @@ -1827,7 +2195,7 @@ def infer_config(args, constructor, trace_patch, layout_patch): if "color_discrete_sequence" not in args: attrs.append("color") else: - if args["color"] and _is_continuous(args["data_frame"], args["color"]): + if args["color"] and _is_continuous(df, args["color"]): attrs.append("color") args["color_is_continuous"] = True elif constructor in [go.Sunburst, go.Treemap, go.Icicle]: @@ -1882,8 +2250,8 @@ args["orientation"] = "h" if args["orientation"] is None and has_x and has_y: - x_is_continuous = _is_continuous(args["data_frame"], args["x"]) - y_is_continuous = _is_continuous(args["data_frame"], args["y"]) + x_is_continuous = _is_continuous(df, args["x"]) + y_is_continuous = _is_continuous(df, args["y"]) if x_is_continuous and not y_is_continuous: args["orientation"] = "h" if y_is_continuous and not x_is_continuous: @@ -1991,7 +2359,7 @@ args[other_position] = None # Ignore facet rows and columns when data frame is empty so as to prevent nrows/ncols equaling 0 - if len(args["data_frame"]) == 0: + if df.is_empty(): args["facet_row"] = args["facet_col"] = None # If both marginals and faceting are specified, faceting wins @@ -2028,9 +2396,7 @@ args["histnorm"] = args["ecdfnorm"] # Compute applicable grouping attributes - for k in group_attrables: - if k in args: - grouped_attrs.append(k) + grouped_attrs.extend([k for k in group_attrables if k in args]) # Create grouped mappings grouped_mappings = [make_mapping(args, a) for a in grouped_attrs] @@ -2052,16 +2418,20 @@ def get_groups_and_orders(args, grouper): of a single dimension-group """ orders = {} if "category_orders" not in args else args["category_orders"].copy() - + df: nw.DataFrame = args["data_frame"] # figure out orders and what the single group name would be if there were one single_group_name = [] unique_cache = dict() - for col in grouper: + grp_to_idx = dict() + + for i, col in enumerate(grouper): if col == one_group: single_group_name.append("") else: if col not in unique_cache: - unique_cache[col] = list(args["data_frame"][col].unique()) + unique_cache[col] = ( + df.get_column(col).unique(maintain_order=True).to_list() + ) uniques = unique_cache[col] if len(uniques) == 1: single_group_name.append(uniques[0]) @@ -2069,19 +2439,16 @@ orders[col] = uniques else: orders[col] = list(OrderedDict.fromkeys(list(orders[col]) + uniques)) - df = args["data_frame"] + + grp_to_idx = {k: i for i, k in enumerate(orders)} + if len(single_group_name) == len(grouper): # we have a single group, so we can skip all group-by operations!
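        # The else-branch below materializes narwhals' group_by iterator into a
        # dict keyed by tuples of group values; a sketch on a toy frame `df_`
        # whose column "g" holds ["a", "a", "b"]:
        # ```
        # dict(df_.group_by(["g"], drop_null_keys=True).__iter__())
        # # {("a",): <rows with g == "a">, ("b",): <rows with g == "b">}
        # ```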
groups = {tuple(single_group_name): df} else: - required_grouper = [g for g in grouper if g != one_group] - grouped = df.groupby( - required_grouper, sort=False, observed=True - ) # skip one_group groupers - group_indices = grouped.indices - sorted_group_names = [ - g if len(required_grouper) != 1 else (g,) for g in group_indices - ] + required_grouper = list(orders.keys()) + grouped = dict(df.group_by(required_grouper, drop_null_keys=True).__iter__()) + sorted_group_names = list(grouped.keys()) for i, col in reversed(list(enumerate(required_grouper))): sorted_group_names = sorted( @@ -2090,22 +2457,19 @@ def get_groups_and_orders(args, grouper): ) # calculate the full group_names by inserting "" in the tuple index for one_group groups - full_sorted_group_names = [list(t) for t in sorted_group_names] - for i, col in enumerate(grouper): - if col == one_group: - for g in full_sorted_group_names: - g.insert(i, "") - full_sorted_group_names = [tuple(g) for g in full_sorted_group_names] - - groups = {} - for sf, s in zip(full_sorted_group_names, sorted_group_names): - if len(s) > 1: - groups[sf] = grouped.get_group(s) - else: - if pandas_2_2_0: - groups[sf] = grouped.get_group((s[0],)) - else: - groups[sf] = grouped.get_group(s[0]) + full_sorted_group_names = [ + tuple( + [ + "" if col == one_group else sub_group_names[grp_to_idx[col]] + for col in grouper + ] + ) + for sub_group_names in sorted_group_names + ] + + groups = { + sf: grouped[s] for sf, s in zip(full_sorted_group_names, sorted_group_names) + } return groups, orders @@ -2280,19 +2644,21 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): base = args["x"] if args["orientation"] == "v" else args["y"] var = args["x"] if args["orientation"] == "h" else args["y"] ascending = args.get("ecdfmode", "standard") != "reversed" - group = group.sort_values(by=base, ascending=ascending) - group_sum = group[var].sum() # compute here before next line mutates - group[var] = group[var].cumsum() + group = group.sort(by=base, descending=not ascending, nulls_last=True) + group_sum = group.get_column( + var + ).sum() # compute here before next line mutates + group = group.with_columns(nw.col(var).cum_sum().alias(var)) if not ascending: - group = group.sort_values(by=base, ascending=True) + group = group.sort(by=base, descending=False, nulls_last=True) if args.get("ecdfmode", "standard") == "complementary": - group[var] = group_sum - group[var] + group = group.with_columns((group_sum - nw.col(var)).alias(var)) if args["ecdfnorm"] == "probability": - group[var] = group[var] / group_sum + group = group.with_columns(nw.col(var) / group_sum) elif args["ecdfnorm"] == "percent": - group[var] = 100.0 * group[var] / group_sum + group = group.with_columns((nw.col(var) / group_sum) * 100.0) patch, fit_results = make_trace_kwargs( args, trace_spec, group, mapping_labels.copy(), sizeref @@ -2408,7 +2774,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): if fit_results is not None: trendline_rows.append(dict(px_fit_results=fit_results)) - fig._px_trendlines = pd.DataFrame(trendline_rows) + if trendline_rows: + try: + import pandas as pd + + fig._px_trendlines = pd.DataFrame(trendline_rows) + except ImportError: + msg = "Trendlines require pandas to be installed." 
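+                # pandas is optional at runtime; it is imported lazily above and
+                # only to assemble the trendline results table.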
+ raise NotImplementedError(msg) + else: + fig._px_trendlines = [] configure_axes(args, constructor, fig, orders) configure_animation_controls(args, constructor, fig) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 8e4678101fc..8754e5265b3 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -503,7 +503,7 @@ zoom=["int (default `8`)", "Between 0 and 20.", "Sets map zoom level."], orientation=[ "str, one of `'h'` for horizontal or `'v'` for vertical. ", - "(default `'v'` if `x` and `y` are provided and both continous or both categorical, ", + "(default `'v'` if `x` and `y` are provided and both continuous or both categorical, ", "otherwise `'v'`(`'h'`) if `x`(`y`) is categorical and `y`(`x`) is continuous, ", "otherwise `'v'`(`'h'`) if only `x`(`y`) is provided) ", ], diff --git a/packages/python/plotly/plotly/express/_imshow.py b/packages/python/plotly/plotly/express/_imshow.py index de0e22284b4..ce6ddb84286 100644 --- a/packages/python/plotly/plotly/express/_imshow.py +++ b/packages/python/plotly/plotly/express/_imshow.py @@ -2,7 +2,7 @@ from _plotly_utils.basevalidators import ColorscaleValidator from ._core import apply_default_cascade, init_figure, configure_animation_controls from .imshow_utils import rescale_intensity, _integer_ranges, _integer_types -import pandas as pd +import narwhals.stable.v1 as nw import numpy as np import itertools from plotly.utils import image_array_to_data_uri @@ -321,7 +321,8 @@ def imshow( aspect = "equal" # --- Set the value of binary_string (forbidden for pandas) - if isinstance(img, pd.DataFrame): + img = nw.from_native(img, pass_through=True) + if isinstance(img, nw.DataFrame): if binary_string: raise ValueError("Binary strings cannot be used with pandas arrays") is_dataframe = True diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py index 4a1faf70b28..18ff219979a 100644 --- a/packages/python/plotly/plotly/express/trendline_functions/__init__.py +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -8,9 +8,6 @@ exposed as part of the public API for documentation purposes. """ -import pandas as pd -import numpy as np - __all__ = ["ols", "lowess", "rolling", "ewm", "expanding"] @@ -32,6 +29,8 @@ def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): respect to the base 10 logarithm of the input. Note that this means no zeros can be present in the input. 
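    For example, ``px.scatter(df, x="x", y="y", trendline="ols",
    trendline_options=dict(log_x=True))`` fits the regression against the base 10
    logarithm of x (which must therefore be strictly positive).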
""" + import numpy as np + valid_options = ["add_constant", "log_x", "log_y"] for k in trendline_options.keys(): if k not in valid_options: @@ -110,11 +109,25 @@ def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): def _pandas(mode, trendline_options, x_raw, y, non_missing): + import numpy as np + + try: + import pandas as pd + except ImportError: + msg = "Trendline requires pandas to be installed" + raise ImportError(msg) + modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding") trendline_options = trendline_options.copy() function_name = trendline_options.pop("function", "mean") function_args = trendline_options.pop("function_args", dict()) - series = pd.Series(y, index=x_raw) + + series = pd.Series(np.copy(y), index=x_raw.to_pandas()) + + # TODO: Narwhals Series/DataFrame do not support rolling, ewm nor expanding, therefore + # it fallbacks to pandas Series independently of the original type. + # Plotly issue: https://github.com/plotly/plotly.py/issues/4834 + # Narwhals issue: https://github.com/narwhals-dev/narwhals/issues/1254 agg = getattr(series, mode) # e.g. series.rolling agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts) function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean diff --git a/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py b/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py index 38a81d1afc3..c76352248b0 100644 --- a/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py +++ b/packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py @@ -1,8 +1,8 @@ from plotly.express._core import build_dataframe from plotly.express._doc import make_docstring from plotly.express._chart_types import choropleth_mapbox, scatter_mapbox +import narwhals.stable.v1 as nw import numpy as np -import pandas as pd def _project_latlon_to_wgs84(lat, lon): @@ -231,6 +231,7 @@ def _compute_wgs84_hexbin( nx=None, agg_func=None, min_count=None, + native_namespace=None, ): """ Computes the lat-lon aggregation at hexagonal bin level. 
@@ -263,7 +264,7 @@ def _compute_wgs84_hexbin( Lat coordinates of each hexagon (shape M x 6) np.ndarray Lon coordinates of each hexagon (shape M x 6) - pd.Series + nw.Series Unique id for each hexagon, to be used in the geojson data (shape M) np.ndarray Aggregated value in each hexagon (shape M) @@ -288,7 +289,14 @@ def _compute_wgs84_hexbin( # Create unique feature id based on hexagon center centers = centers.astype(str) - hexagons_ids = pd.Series(centers[:, 0]) + "," + pd.Series(centers[:, 1]) + hexagons_ids = ( + nw.from_dict( + {"x1": centers[:, 0], "x2": centers[:, 1]}, + native_namespace=native_namespace, + ) + .select(hexagons_ids=nw.concat_str([nw.col("x1"), nw.col("x2")], separator=",")) + .get_column("hexagons_ids") + ) return hexagons_lats, hexagons_lons, hexagons_ids, agreggated_value @@ -344,22 +352,40 @@ def create_hexbin_mapbox( Returns a figure aggregating scattered points into connected hexagons """ args = build_dataframe(args=locals(), constructor=None) - + native_namespace = nw.get_native_namespace(args["data_frame"]) if agg_func is None: agg_func = np.mean - lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values - lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values + lat_range = ( + args["data_frame"] + .select( + nw.min(args["lat"]).name.suffix("_min"), + nw.max(args["lat"]).name.suffix("_max"), + ) + .to_numpy() + .squeeze() + ) + + lon_range = ( + args["data_frame"] + .select( + nw.min(args["lon"]).name.suffix("_min"), + nw.max(args["lon"]).name.suffix("_max"), + ) + .to_numpy() + .squeeze() + ) hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin( - lat=args["data_frame"][args["lat"]].values, - lon=args["data_frame"][args["lon"]].values, + lat=args["data_frame"].get_column(args["lat"]).to_numpy(), + lon=args["data_frame"].get_column(args["lon"]).to_numpy(), lat_range=lat_range, lon_range=lon_range, color=None, nx=nx_hexagon, agg_func=agg_func, min_count=min_count, + native_namespace=native_namespace, ) geojson = _hexagons_to_geojson(hexagons_lats, hexagons_lons, hexagons_ids) @@ -381,41 +407,47 @@ def create_hexbin_mapbox( center = dict(lat=lat_range.mean(), lon=lon_range.mean()) if args["animation_frame"] is not None: - groups = args["data_frame"].groupby(args["animation_frame"]).groups + groups = dict( + args["data_frame"] + .group_by(args["animation_frame"], drop_null_keys=True) + .__iter__() + ) else: - groups = {0: args["data_frame"].index} + groups = {(0,): args["data_frame"]} agg_data_frame_list = [] - for frame, index in groups.items(): - df = args["data_frame"].loc[index] + for key, df in groups.items(): _, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin( - lat=df[args["lat"]].values, - lon=df[args["lon"]].values, + lat=df.get_column(args["lat"]).to_numpy(), + lon=df.get_column(args["lon"]).to_numpy(), lat_range=lat_range, lon_range=lon_range, - color=df[args["color"]].values if args["color"] else None, + color=df.get_column(args["color"]).to_numpy() if args["color"] else None, nx=nx_hexagon, agg_func=agg_func, min_count=min_count, + native_namespace=native_namespace, ) agg_data_frame_list.append( - pd.DataFrame( - np.c_[hexagons_ids, aggregated_value], columns=["locations", "color"] + nw.from_dict( + { + "frame": [key[0]] * len(hexagons_ids), + "locations": hexagons_ids, + "color": aggregated_value, + }, + native_namespace=native_namespace, ) ) - agg_data_frame = ( - pd.concat(agg_data_frame_list, axis=0, keys=groups.keys()) - .rename_axis(index=("frame", "index")) - .reset_index("frame") - 
) - agg_data_frame["color"] = pd.to_numeric(agg_data_frame["color"]) + agg_data_frame = nw.concat(agg_data_frame_list, how="vertical").with_columns( + color=nw.col("color").cast(nw.Int64) + ) if range_color is None: range_color = [agg_data_frame["color"].min(), agg_data_frame["color"].max()] fig = choropleth_mapbox( - data_frame=agg_data_frame, + data_frame=agg_data_frame.to_native(), geojson=geojson, locations="locations", color="color", @@ -440,10 +472,12 @@ def create_hexbin_mapbox( if show_original_data: original_fig = scatter_mapbox( data_frame=( - args["data_frame"].sort_values(by=args["animation_frame"]) + args["data_frame"].sort( + by=args["animation_frame"], descending=False, nulls_last=True + ) if args["animation_frame"] is not None else args["data_frame"] - ), + ).to_native(), lat=args["lat"], lon=args["lon"], animation_frame=args["animation_frame"], diff --git a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py index 52c16bf09c9..077a39dcb17 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_figure_factory/test_figure_factory.py @@ -4475,7 +4475,7 @@ def test_build_dataframe(self): lon = np.random.randn(N) color = np.ones(N) frame = np.random.randint(0, n_frames, N) - df = pd.DataFrame( + df = pd.DataFrame( # TODO: Test other constructors? np.c_[lat, lon, color, frame], columns=["Latitude", "Longitude", "Metric", "Frame"], ) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/conftest.py b/packages/python/plotly/plotly/tests/test_optional/test_px/conftest.py new file mode 100644 index 00000000000..b207fb29a80 --- /dev/null +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/conftest.py @@ -0,0 +1,48 @@ +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +from narwhals.typing import IntoDataFrame +from narwhals.utils import parse_version + + +def pandas_constructor(obj) -> IntoDataFrame: + return pd.DataFrame(obj) # type: ignore[no-any-return] + + +def pandas_nullable_constructor(obj) -> IntoDataFrame: + return pd.DataFrame(obj).convert_dtypes(dtype_backend="numpy_nullable") # type: ignore[no-any-return] + + +def pandas_pyarrow_constructor(obj) -> IntoDataFrame: + return pd.DataFrame(obj).convert_dtypes(dtype_backend="pyarrow") # type: ignore[no-any-return] + + +def polars_eager_constructor(obj) -> IntoDataFrame: + return pl.DataFrame(obj) + + +def pyarrow_table_constructor(obj) -> IntoDataFrame: + return pa.table(obj) # type: ignore[no-any-return] + + +constructors = [polars_eager_constructor, pyarrow_table_constructor, pandas_constructor] + +if parse_version(pd.__version__) >= parse_version("2.0.0"): + constructors.extend( + [ + pandas_nullable_constructor, + pandas_pyarrow_constructor, + ] + ) + + +@pytest.fixture(params=constructors) +def constructor(request: pytest.FixtureRequest): + return request.param # type: ignore[no-any-return] + + +@pytest.fixture(params=["pandas", "pyarrow", "polars"]) +def backend(request: pytest.FixtureRequest) -> str: + return request.param # type: ignore[no-any-return] diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_facets.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_facets.py index c1db2afe775..593b214ed1e 100644 --- 
a/packages/python/plotly/plotly/tests/test_optional/test_px/test_facets.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_facets.py @@ -1,12 +1,12 @@ -import pandas as pd import plotly.express as px from pytest import approx import pytest import random -def test_facets(): - df = px.data.tips() +def test_facets(backend): + df = px.data.tips(return_type=backend) + fig = px.scatter(df, x="total_bill", y="tip") assert "xaxis2" not in fig.layout assert "yaxis2" not in fig.layout @@ -46,8 +46,8 @@ def test_facets(): assert fig.layout.yaxis4.domain[0] - fig.layout.yaxis.domain[1] == approx(0.08) -def test_facets_with_marginals(): - df = px.data.tips() +def test_facets_with_marginals(backend): + df = px.data.tips(return_type=backend) fig = px.histogram(df, x="total_bill", facet_col="sex", marginal="rug") assert len(fig.data) == 4 @@ -93,12 +93,11 @@ def test_facets_with_marginals(): assert len(fig.data) == 2 # ignore all marginals -@pytest.fixture -def bad_facet_spacing_df(): +def bad_facet_spacing_df(constructor_func): NROWS = 101 NDATA = 1000 categories = [n % NROWS for n in range(NDATA)] - df = pd.DataFrame( + df = constructor_func( { "x": [random.random() for _ in range(NDATA)], "y": [random.random() for _ in range(NDATA)], @@ -108,8 +107,8 @@ def bad_facet_spacing_df(): return df -def test_bad_facet_spacing_eror(bad_facet_spacing_df): - df = bad_facet_spacing_df +def test_bad_facet_spacing_error(constructor): + df = bad_facet_spacing_df(constructor_func=constructor) with pytest.raises( ValueError, match="Use the facet_row_spacing argument to adjust this spacing." ): diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_marginals.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_marginals.py index 0274068d27a..9a7ec64d123 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_marginals.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_marginals.py @@ -5,8 +5,8 @@ @pytest.mark.parametrize("px_fn", [px.scatter, px.density_heatmap, px.density_contour]) @pytest.mark.parametrize("marginal_x", [None, "histogram", "box", "violin"]) @pytest.mark.parametrize("marginal_y", [None, "rug"]) -def test_xy_marginals(px_fn, marginal_x, marginal_y): - df = px.data.tips() +def test_xy_marginals(backend, px_fn, marginal_x, marginal_y): + df = px.data.tips(return_type=backend) fig = px_fn( df, x="total_bill", y="tip", marginal_x=marginal_x, marginal_y=marginal_y @@ -17,8 +17,8 @@ def test_xy_marginals(px_fn, marginal_x, marginal_y): @pytest.mark.parametrize("px_fn", [px.histogram, px.ecdf]) @pytest.mark.parametrize("marginal", [None, "rug", "histogram", "box", "violin"]) @pytest.mark.parametrize("orientation", ["h", "v"]) -def test_single_marginals(px_fn, marginal, orientation): - df = px.data.tips() +def test_single_marginals(backend, px_fn, marginal, orientation): + df = px.data.tips(return_type=backend) fig = px_fn( df, x="total_bill", y="total_bill", marginal=marginal, orientation=orientation diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py index 8bcff763ab2..8d091df3ae2 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px.py @@ -1,54 +1,63 @@ import plotly.express as px import plotly.io as pio +import narwhals.stable.v1 as nw import numpy as np import pytest from itertools import permutations -def 
test_scatter(): - iris = px.data.iris() - fig = px.scatter(iris, x="sepal_width", y="sepal_length") +def test_scatter(backend): + iris = nw.from_native(px.data.iris(return_type=backend)) + fig = px.scatter(iris.to_native(), x="sepal_width", y="sepal_length") assert fig.data[0].type == "scatter" - assert np.all(fig.data[0].x == iris.sepal_width) - assert np.all(fig.data[0].y == iris.sepal_length) + assert np.all(fig.data[0].x == iris.get_column("sepal_width").to_numpy()) + assert np.all(fig.data[0].y == iris.get_column("sepal_length").to_numpy()) # test defaults assert fig.data[0].mode == "markers" -def test_custom_data_scatter(): - iris = px.data.iris() +def test_custom_data_scatter(backend): + iris = nw.from_native(px.data.iris(return_type=backend)) # No hover, no custom data - fig = px.scatter(iris, x="sepal_width", y="sepal_length", color="species") + fig = px.scatter( + iris.to_native(), x="sepal_width", y="sepal_length", color="species" + ) assert fig.data[0].customdata is None # Hover, no custom data fig = px.scatter( - iris, + iris.to_native(), x="sepal_width", y="sepal_length", color="species", hover_data=["petal_length", "petal_width"], ) for data in fig.data: - assert np.all(np.in1d(data.customdata[:, 1], iris.petal_width)) + assert np.all( + np.in1d(data.customdata[:, 1], iris.get_column("petal_width").to_numpy()) + ) # Hover and custom data, no repeated arguments fig = px.scatter( - iris, + iris.to_native(), x="sepal_width", y="sepal_length", hover_data=["petal_length", "petal_width"], custom_data=["species_id", "species"], ) - assert np.all(fig.data[0].customdata[:, 0] == iris.species_id) + assert np.all( + fig.data[0].customdata[:, 0] == iris.get_column("species_id").to_numpy() + ) assert fig.data[0].customdata.shape[1] == 4 # Hover and custom data, with repeated arguments fig = px.scatter( - iris, + iris.to_native(), x="sepal_width", y="sepal_length", hover_data=["petal_length", "petal_width", "species_id"], custom_data=["species_id", "species"], ) - assert np.all(fig.data[0].customdata[:, 0] == iris.species_id) + assert np.all( + fig.data[0].customdata[:, 0] == iris.get_column("species_id").to_numpy() + ) assert fig.data[0].customdata.shape[1] == 4 assert ( fig.data[0].hovertemplate @@ -56,10 +65,10 @@ def test_custom_data_scatter(): ) -def test_labels(): - tips = px.data.tips() +def test_labels(backend): + tips = nw.from_native(px.data.tips(return_type=backend)) fig = px.scatter( - tips, + tips.to_native(), x="total_bill", y="tip", facet_row="time", @@ -88,8 +97,8 @@ def test_labels(): ({"text": "continent"}, "lines+markers+text"), ], ) -def test_line_mode(extra_kwargs, expected_mode): - gapminder = px.data.gapminder() +def test_line_mode(backend, extra_kwargs, expected_mode): + gapminder = px.data.gapminder(return_type=backend) fig = px.line( gapminder, x="year", @@ -100,11 +109,11 @@ def test_line_mode(extra_kwargs, expected_mode): assert fig.data[0].mode == expected_mode -def test_px_templates(): +def test_px_templates(backend): try: import plotly.graph_objects as go - tips = px.data.tips() + tips = px.data.tips(return_type=backend) # use the normal defaults fig = px.scatter() @@ -230,11 +239,14 @@ def test_px_defaults(): pio.templates.default = "plotly" -def assert_orderings(days_order, days_check, times_order, times_check): +def assert_orderings(backend, days_order, days_check, times_order, times_check): symbol_sequence = ["circle", "diamond", "square", "cross", "circle", "diamond"] color_sequence = ["red", "blue", "red", "blue", "red", "blue", "red", "blue"] + + tips 
= nw.from_native(px.data.tips(return_type=backend)) + fig = px.scatter( - px.data.tips(), + tips.to_native(), x="total_bill", y="tip", facet_row="time", @@ -265,14 +277,16 @@ def assert_orderings(days_order, days_check, times_order, times_check): @pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "x"])) @pytest.mark.parametrize("times", permutations(["Lunch", "x"])) -def test_orthogonal_and_missing_orderings(days, times): - assert_orderings(days, list(days) + ["Thur"], times, list(times) + ["Dinner"]) +def test_orthogonal_and_missing_orderings(backend, days, times): + assert_orderings( + backend, days, list(days) + ["Thur"], times, list(times) + ["Dinner"] + ) @pytest.mark.parametrize("days", permutations(["Sun", "Sat", "Fri", "Thur"])) @pytest.mark.parametrize("times", permutations(["Lunch", "Dinner"])) -def test_orthogonal_orderings(days, times): - assert_orderings(days, days, times, times) +def test_orthogonal_orderings(backend, days, times): + assert_orderings(backend, days, days, times, times) def test_permissive_defaults(): @@ -281,8 +295,8 @@ def test_permissive_defaults(): px.defaults.should_not_work = "test" -def test_marginal_ranges(): - df = px.data.tips() +def test_marginal_ranges(backend): + df = px.data.tips(return_type=backend) fig = px.scatter( df, x="total_bill", @@ -296,23 +310,32 @@ def test_marginal_ranges(): assert fig.layout.yaxis3.range is None -def test_render_mode(): - df = px.data.gapminder() - df2007 = df.query("year == 2007") - fig = px.scatter(df2007, x="gdpPercap", y="lifeExp", trendline="ols") +def test_render_mode(backend): + df = nw.from_native(px.data.gapminder(return_type=backend)) + df2007 = df.filter(nw.col("year") == 2007) + + fig = px.scatter(df2007.to_native(), x="gdpPercap", y="lifeExp", trendline="ols") assert fig.data[0].type == "scatter" assert fig.data[1].type == "scatter" fig = px.scatter( - df2007, x="gdpPercap", y="lifeExp", trendline="ols", render_mode="webgl" + df2007.to_native(), + x="gdpPercap", + y="lifeExp", + trendline="ols", + render_mode="webgl", ) assert fig.data[0].type == "scattergl" assert fig.data[1].type == "scattergl" - fig = px.scatter(df, x="gdpPercap", y="lifeExp", trendline="ols") + fig = px.scatter(df.to_native(), x="gdpPercap", y="lifeExp", trendline="ols") assert fig.data[0].type == "scattergl" assert fig.data[1].type == "scattergl" - fig = px.scatter(df, x="gdpPercap", y="lifeExp", trendline="ols", render_mode="svg") + fig = px.scatter( + df.to_native(), x="gdpPercap", y="lifeExp", trendline="ols", render_mode="svg" + ) assert fig.data[0].type == "scatter" assert fig.data[1].type == "scatter" - fig = px.density_contour(df, x="gdpPercap", y="lifeExp", trendline="ols") + fig = px.density_contour( + df.to_native(), x="gdpPercap", y="lifeExp", trendline="ols" + ) assert fig.data[0].type == "histogram2dcontour" assert fig.data[1].type == "scatter" diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py index e34dd0d20bd..2f165db078c 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py @@ -1,8 +1,9 @@ import plotly.express as px import plotly.graph_objects as go from numpy.testing import assert_array_equal +import narwhals.stable.v1 as nw import numpy as np -import pandas as pd +from polars.exceptions import InvalidOperationError import pytest @@ -118,7 +119,7 @@ def 
test_sunburst_treemap_colorscales(): assert list(fig.layout[colorway]) == color_seq -def test_sunburst_treemap_with_path(): +def test_sunburst_treemap_with_path(constructor): vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] sectors = [ "Tech", @@ -133,7 +134,7 @@ def test_sunburst_treemap_with_path(): regions = ["North", "North", "North", "North", "South", "South", "South", "South"] values = [1, 3, 2, 4, 2, 2, 1, 4] total = ["total"] * 8 - df = pd.DataFrame( + df = constructor( dict( vendors=vendors, sectors=sectors, @@ -150,53 +151,71 @@ def test_sunburst_treemap_with_path(): fig = px.sunburst(df, path=path, values="values") assert fig.data[0].branchvalues == "total" assert fig.data[0].values[-1] == np.sum(values) - # Values passed - fig = px.sunburst(df, path=path, values="values") - assert fig.data[0].branchvalues == "total" - assert fig.data[0].values[-1] == np.sum(values) + # Error when values cannot be converted to numerical data type - df["values"] = ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"] - msg = "Column `values` of `df` could not be converted to a numerical data type." - with pytest.raises(ValueError, match=msg): - fig = px.sunburst(df, path=path, values="values") + df = nw.from_native(df) + native_namespace = nw.get_native_namespace(df) + df = df.with_columns( + values=nw.new_series( + "values", + ["1 000", "3 000", "2", "4", "2", "2", "1 000", "4 000"], + dtype=nw.String(), + native_namespace=native_namespace, + ) + ) + pd_msg = "Column `values` of `df` could not be converted to a numerical data type." + pl_msg = "conversion from `str` to `f64` failed in column 'values'" + + with pytest.raises( + (ValueError, InvalidOperationError), match=f"({pd_msg}|{pl_msg})" + ): + fig = px.sunburst(df.to_native(), path=path, values="values") # path is a mixture of column names and array-like - path = [df.total, "regions", df.sectors, "vendors"] - fig = px.sunburst(df, path=path) + path = [ + df.get_column("total").to_native(), + "regions", + df.get_column("sectors").to_native(), + "vendors", + ] + fig = px.sunburst(df.to_native(), path=path) assert fig.data[0].branchvalues == "total" # Continuous colorscale - df["values"] = 1 - fig = px.sunburst(df, path=path, values="values", color="values") + df = df.with_columns(values=nw.lit(1)) + fig = px.sunburst(df.to_native(), path=path, values="values", color="values") assert "coloraxis" in fig.data[0].marker assert np.all(np.array(fig.data[0].marker.colors) == 1) assert fig.data[0].values[-1] == 8 -def test_sunburst_treemap_with_path_and_hover(): - df = px.data.tips() +def test_sunburst_treemap_with_path_and_hover(backend): + df = px.data.tips(return_type=backend) fig = px.sunburst( df, path=["sex", "day", "time", "smoker"], color="smoker", hover_data=["smoker"] ) assert "smoker" in fig.data[0].hovertemplate - df = px.data.gapminder().query("year == 2007") + df = nw.from_native(px.data.gapminder(year=2007, return_type=backend)) fig = px.sunburst( - df, path=["continent", "country"], color="lifeExp", hover_data=df.columns + df.to_native(), + path=["continent", "country"], + color="lifeExp", + hover_data=df.columns, ) assert fig.layout.coloraxis.colorbar.title.text == "lifeExp" - df = px.data.tips() + df = px.data.tips(return_type=backend) fig = px.sunburst(df, path=["sex", "day", "time", "smoker"], hover_name="smoker") assert "smoker" not in fig.data[0].hovertemplate # represented as '%{hovertext}' assert "%{hovertext}" in fig.data[0].hovertemplate # represented as '%{hovertext}' - df = px.data.tips() + df = 
px.data.tips(return_type=backend) fig = px.sunburst(df, path=["sex", "day", "time", "smoker"], custom_data=["smoker"]) assert fig.data[0].customdata[0][0] in ["Yes", "No"] assert "smoker" not in fig.data[0].hovertemplate assert "%{hovertext}" not in fig.data[0].hovertemplate -def test_sunburst_treemap_with_path_color(): +def test_sunburst_treemap_with_path_color(constructor): vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] sectors = [ "Tech", @@ -212,50 +231,69 @@ def test_sunburst_treemap_with_path_color(): values = [1, 3, 2, 4, 2, 2, 1, 4] calls = [8, 2, 1, 3, 2, 2, 4, 1] total = ["total"] * 8 - df = pd.DataFrame( - dict( - vendors=vendors, - sectors=sectors, - regions=regions, - values=values, - total=total, - calls=calls, - ) + hover = [el.lower() for el in vendors] + + data = dict( + vendors=vendors, + sectors=sectors, + regions=regions, + values=values, + total=total, + calls=calls, ) + df = nw.from_native(constructor(data)) path = ["total", "regions", "sectors", "vendors"] - fig = px.sunburst(df, path=path, values="values", color="calls") + fig = px.sunburst(df.to_native(), path=path, values="values", color="calls") colors = fig.data[0].marker.colors - assert np.all(np.array(colors[:8]) == np.array(calls)) - fig = px.sunburst(df, path=path, color="calls") + assert np.all(np.array(np.sort(colors[:8])) == np.array(sorted(calls))) + fig = px.sunburst(df.to_native(), path=path, color="calls") colors = fig.data[0].marker.colors - assert np.all(np.array(colors[:8]) == np.array(calls)) + assert np.all(np.sort(colors[:8]) == np.array(sorted(calls))) # Hover info - df["hover"] = [el.lower() for el in vendors] - fig = px.sunburst(df, path=path, color="calls", hover_data=["hover"]) + df = df.with_columns( + hover=nw.new_series( + name="hover", + values=hover, + dtype=nw.String(), + native_namespace=nw.get_native_namespace(df), + ) + ) + fig = px.sunburst(df.to_native(), path=path, color="calls", hover_data=["hover"]) custom = fig.data[0].customdata - assert np.all(custom[:8, 0] == df["hover"]) - assert np.all(custom[8:, 0] == "(?)") - assert np.all(custom[:8, 1] == df["calls"]) + assert np.all(np.sort(custom[:8, 0]) == sorted(hover)) + assert np.all(np.sort(custom[8:, 0]) == "(?)") + assert np.all(np.sort(custom[:8, 1]) == sorted(calls)) # Discrete color - fig = px.sunburst(df, path=path, color="vendors") + fig = px.sunburst(df.to_native(), path=path, color="vendors") assert len(np.unique(fig.data[0].marker.colors)) == 9 # Discrete color and color_discrete_map cmap = {"Tech": "yellow", "Finance": "magenta", "(?)": "black"} - fig = px.sunburst(df, path=path, color="sectors", color_discrete_map=cmap) + fig = px.sunburst( + df.to_native(), path=path, color="sectors", color_discrete_map=cmap + ) assert np.all(np.in1d(fig.data[0].marker.colors, list(cmap.values()))) # Numerical column in path - df["regions"] = df["regions"].map({"North": 1, "South": 2}) + df = ( + nw.from_native(df) + .with_columns( + regions=nw.when(nw.col("regions") == "North") + .then(1) + .otherwise(2) + .cast(nw.Int64()) + ) + .to_native() + ) path = ["total", "regions", "sectors", "vendors"] fig = px.sunburst(df, path=path, values="values", color="calls") colors = fig.data[0].marker.colors - assert np.all(np.array(colors[:8]) == np.array(calls)) + assert np.all(np.sort(colors[:8]) == sorted(calls)) -def test_sunburst_treemap_column_parent(): +def test_sunburst_treemap_column_parent(constructor): vendors = ["A", "B", "C", "D", "E", "F", "G", "H"] sectors = [ "Tech", @@ -269,7 +307,7 @@ def 
test_sunburst_treemap_column_parent(): ] regions = ["North", "North", "North", "North", "South", "South", "South", "South"] values = [1, 3, 2, 4, 2, 2, 1, 4] - df = pd.DataFrame( + df = constructor( dict( id=vendors, sectors=sectors, @@ -282,7 +320,7 @@ def test_sunburst_treemap_column_parent(): px.sunburst(df, path=path, values="values") -def test_sunburst_treemap_with_path_non_rectangular(): +def test_sunburst_treemap_with_path_non_rectangular(constructor): vendors = ["A", "B", "C", "D", None, "E", "F", "G", "H", None] sectors = [ "Tech", @@ -310,7 +348,7 @@ def test_sunburst_treemap_with_path_non_rectangular(): ] values = [1, 3, 2, 4, 1, 2, 2, 1, 4, 1] total = ["total"] * 10 - df = pd.DataFrame( + df = constructor( dict( vendors=vendors, sectors=sectors, @@ -323,7 +361,18 @@ def test_sunburst_treemap_with_path_non_rectangular(): msg = "Non-leaves rows are not permitted in the dataframe" with pytest.raises(ValueError, match=msg): fig = px.sunburst(df, path=path, values="values") - df.loc[df["vendors"].isnull(), "sectors"] = "Other" + + df = ( + nw.from_native(df) + .with_columns( + sectors=( + nw.when(~nw.col("vendors").is_null()) + .then(nw.col("sectors")) + .otherwise(nw.lit("Other")) + ) + ) + .to_native() + ) fig = px.sunburst(df, path=path, values="values") assert fig.data[0].values[-1] == np.sum(values) @@ -362,8 +411,8 @@ def test_funnel(): assert len(fig.data) == 2 -def test_parcats_dimensions_max(): - df = px.data.tips() +def test_parcats_dimensions_max(backend): + df = px.data.tips(return_type=backend) # default behaviour fig = px.parallel_categories(df) @@ -396,12 +445,12 @@ def test_parcats_dimensions_max(): @pytest.mark.parametrize("histfunc,y", [(None, None), ("count", "tip")]) -def test_histfunc_hoverlabels_univariate(histfunc, y): +def test_histfunc_hoverlabels_univariate(backend, histfunc, y): def check_label(label, fig): assert fig.layout.yaxis.title.text == label assert label + "=" in fig.data[0].hovertemplate - df = px.data.tips() + df = px.data.tips(return_type=backend) # base case, just "count" (note count(tip) is same as count()) fig = px.histogram(df, x="total_bill", y=y, histfunc=histfunc) @@ -427,12 +476,12 @@ def check_label(label, fig): check_label("%s (normalized as %s)" % (histnorm, barnorm), fig) -def test_histfunc_hoverlabels_bivariate(): +def test_histfunc_hoverlabels_bivariate(backend): def check_label(label, fig): assert fig.layout.yaxis.title.text == label assert label + "=" in fig.data[0].hovertemplate - df = px.data.tips() + df = px.data.tips(return_type=backend) # with y, should be same as forcing histfunc to sum fig = px.histogram(df, x="total_bill", y="tip") @@ -487,13 +536,14 @@ def check_label(label, fig): check_label("density of max of tip", fig) -def test_timeline(): - df = pd.DataFrame( - [ - dict(Task="Job A", Start="2009-01-01", Finish="2009-02-28"), - dict(Task="Job B", Start="2009-03-05", Finish="2009-04-15"), - dict(Task="Job C", Start="2009-02-20", Finish="2009-05-30"), - ] +def test_timeline(constructor): + + df = constructor( + { + "Task": ["Job A", "Job B", "Job C"], + "Start": ["2009-01-01", "2009-03-05", "2009-02-20"], + "Finish": ["2009-02-28", "2009-04-15", "2009-05-30"], + } ) fig = px.timeline(df, x_start="Start", x_end="Finish", y="Task", color="Task") assert len(fig.data) == 3 diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py index 26ee1a26198..35a91a5c34f 100644 --- 
a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_hover.py @@ -1,12 +1,13 @@ import plotly.express as px +import narwhals.stable.v1 as nw import numpy as np import pandas as pd import pytest from collections import OrderedDict # an OrderedDict is needed for Python 2 -def test_skip_hover(): - df = px.data.iris() +def test_skip_hover(backend): + df = px.data.iris(return_type=backend) fig = px.scatter( df, x="petal_length", @@ -17,8 +18,8 @@ def test_skip_hover(): assert fig.data[0].hovertemplate == "species_id=%{marker.size}" -def test_hover_data_string_column(): - df = px.data.tips() +def test_hover_data_string_column(backend): + df = px.data.tips(return_type=backend) fig = px.scatter( df, x="tip", @@ -28,8 +29,8 @@ def test_hover_data_string_column(): assert "sex" in fig.data[0].hovertemplate -def test_composite_hover(): - df = px.data.tips() +def test_composite_hover(backend): + df = px.data.tips(return_type=backend) hover_dict = OrderedDict( {"day": False, "time": False, "sex": True, "total_bill": ":.1f"} ) @@ -87,8 +88,8 @@ def test_newdatain_hover_data(): ) -def test_formatted_hover_and_labels(): - df = px.data.tips() +def test_formatted_hover_and_labels(backend): + df = px.data.tips(return_type=backend) fig = px.scatter( df, x="tip", @@ -171,8 +172,8 @@ def test_fail_wrong_column(): ) -def test_sunburst_hoverdict_color(): - df = px.data.gapminder().query("year == 2007") +def test_sunburst_hoverdict_color(backend): + df = px.data.gapminder(year=2007, return_type=backend) fig = px.sunburst( df, path=["continent", "country"], @@ -183,8 +184,13 @@ def test_sunburst_hoverdict_color(): assert "color" in fig.data[0].hovertemplate -def test_date_in_hover(): - df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]}) - df["date"] = pd.to_datetime(df["date"]) - fig = px.scatter(df, x="value", y="value", hover_data=["date"]) - assert str(fig.data[0].customdata[0][0]) == str(df["date"][0]) +def test_date_in_hover(constructor): + df = nw.from_native( + constructor({"date": ["2015-04-04 19:31:30+01:00"], "value": [3]}) + ).with_columns(date=nw.col("date").str.to_datetime(format="%Y-%m-%d %H:%M:%S%z")) + fig = px.scatter(df.to_native(), x="value", y="value", hover_data=["date"]) + + # Check that what gets displayed is the local datetime + assert nw.to_py_scalar(fig.data[0].customdata[0][0]) == nw.to_py_scalar( + df.item(row=0, column="date") + ).replace(tzinfo=None) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py index a6bbf9b4e46..81605b59a29 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_input.py @@ -1,5 +1,7 @@ import plotly.express as px +import pyarrow as pa import plotly.graph_objects as go +import narwhals.stable.v1 as nw import numpy as np import pandas as pd import pytest @@ -11,21 +13,6 @@ import warnings -# Fixtures -# -------- -@pytest.fixture -def add_interchange_module_for_old_pandas(): - if not hasattr(pd.api, "interchange"): - with mock.patch.object(pd.api, "interchange", mock.MagicMock(), create=True): - # to make the following import work: `import pandas.api.interchange` - with mock.patch.dict( - "sys.modules", {"pandas.api.interchange": pd.api.interchange} - ): - yield - else: - yield - - def test_numpy(): fig = px.scatter(x=[1, 2, 3], y=[2, 3, 
4], color=[1, 3, 9]) assert np.all(fig.data[0].x == np.array([1, 2, 3])) @@ -64,51 +51,84 @@ def test_with_index(): assert fig.data[0]["hovertemplate"] == "item=%{x}
total_bill=%{y}" -def test_pandas_series(): - tips = px.data.tips() - before_tip = tips.total_bill - tips.tip +def test_series(request, backend): + if backend == "pyarrow": + # By converting to native, we lose the name for pyarrow chunked_array + # and the assertions fail + request.applymarker(pytest.mark.xfail) + + tips = nw.from_native(px.data.tips(return_type=backend)) + before_tip = (tips.get_column("total_bill") - tips.get_column("tip")).to_native() + day = tips.get_column("day").to_native() + tips = tips.to_native() + fig = px.bar(tips, x="day", y=before_tip) assert fig.data[0].hovertemplate == "day=%{x}
y=%{y}" fig = px.bar(tips, x="day", y=before_tip, labels={"y": "bill"}) assert fig.data[0].hovertemplate == "day=%{x}
bill=%{y}" # lock down that we can pass df.col to facet_* - fig = px.bar(tips, x="day", y="tip", facet_row=tips.day, facet_col=tips.day) + fig = px.bar(tips, x="day", y="tip", facet_row=day, facet_col=day) assert fig.data[0].hovertemplate == "day=%{x}
tip=%{y}" -def test_several_dataframes(): - df = pd.DataFrame(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8])) - df2 = pd.DataFrame(dict(time=[23, 26], money=[100, 200])) - fig = px.scatter(df, x="z", y=df2.money, size="x") +def test_several_dataframes(request, constructor): + if "pyarrow_table" in str(constructor): + # By converting to native, we lose the name for pyarrow chunked_array + # and the assertions fail + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(dict(x=[0, 1], y=[1, 10], z=[0.1, 0.8]))) + df2 = nw.from_native(constructor(dict(time=[23, 26], money=[100, 200]))) + fig = px.scatter( + df.to_native(), x="z", y=df2.get_column("money").to_native(), size="x" + ) assert ( fig.data[0].hovertemplate == "z=%{x}
y=%{y}
x=%{marker.size}" ) - fig = px.scatter(df2, x=df.z, y=df2.money, size=df.z) + fig = px.scatter( + df2.to_native(), + x=df.get_column("z").to_native(), + y=df2.get_column("money").to_native(), + size=df.get_column("z").to_native(), + ) assert ( fig.data[0].hovertemplate == "x=%{x}
money=%{y}
size=%{marker.size}" ) # Name conflict with pytest.raises(NameError) as err_msg: - fig = px.scatter(df, x="z", y=df2.money, size="y") + fig = px.scatter( + df.to_native(), x="z", y=df2.get_column("money").to_native(), size="y" + ) assert "A name conflict was encountered for argument 'y'" in str(err_msg.value) with pytest.raises(NameError) as err_msg: - fig = px.scatter(df, x="z", y=df2.money, size=df.y) + fig = px.scatter( + df.to_native(), + x="z", + y=df2.get_column("money").to_native(), + size=df.get_column("y").to_native(), + ) assert "A name conflict was encountered for argument 'y'" in str(err_msg.value) # No conflict when the dataframe is not given, fields are used - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) - df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) - fig = px.scatter(x=df.y, y=df2.y) + df = nw.from_native(constructor(dict(x=[0, 1], y=[3, 4]))) + df2 = nw.from_native(constructor(dict(x=[3, 5], y=[23, 24]))) + fig = px.scatter( + x=df.get_column("y").to_native(), y=df2.get_column("y").to_native() + ) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([23, 24])) assert fig.data[0].hovertemplate == "x=%{x}
y=%{y}" - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) - df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) - df3 = pd.DataFrame(dict(y=[0.1, 0.2])) - fig = px.scatter(x=df.y, y=df2.y, size=df3.y) + df = nw.from_native(constructor(dict(x=[0, 1], y=[3, 4]))) + df2 = nw.from_native(constructor(dict(x=[3, 5], y=[23, 24]))) + df3 = nw.from_native(constructor(dict(y=[0.1, 0.2]))) + fig = px.scatter( + x=df.get_column("y").to_native(), + y=df2.get_column("y").to_native(), + size=df3.get_column("y").to_native(), + ) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([23, 24])) assert ( @@ -116,10 +136,14 @@ def test_several_dataframes(): == "x=%{x}
y=%{y}
size=%{marker.size}" ) - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4])) - df2 = pd.DataFrame(dict(x=[3, 5], y=[23, 24])) - df3 = pd.DataFrame(dict(y=[0.1, 0.2])) - fig = px.scatter(x=df.y, y=df2.y, hover_data=[df3.y]) + df = nw.from_native(constructor(dict(x=[0, 1], y=[3, 4]))) + df2 = nw.from_native(constructor(dict(x=[3, 5], y=[23, 24]))) + df3 = nw.from_native(constructor(dict(y=[0.1, 0.2]))) + fig = px.scatter( + x=df.get_column("y").to_native(), + y=df2.get_column("y").to_native(), + hover_data=[df3.get_column("y").to_native()], + ) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([23, 24])) assert ( @@ -128,16 +152,26 @@ def test_several_dataframes(): ) -def test_name_heuristics(): - df = pd.DataFrame(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2])) - fig = px.scatter(df, x=df.y, y=df.x, size=df.y) +def test_name_heuristics(request, constructor): + if "pyarrow_table" in str(constructor): + # By converting to native, we lose the name for pyarrow chunked_array + # and the assertions fail + request.applymarker(pytest.mark.xfail) + + df = nw.from_native(constructor(dict(x=[0, 1], y=[3, 4], z=[0.1, 0.2]))) + fig = px.scatter( + df.to_native(), + x=df.get_column("y").to_native(), + y=df.get_column("x").to_native(), + size=df.get_column("y").to_native(), + ) assert np.all(fig.data[0].x == np.array([3, 4])) assert np.all(fig.data[0].y == np.array([0, 1])) assert fig.data[0].hovertemplate == "y=%{marker.size}
x=%{y}" -def test_repeated_name(): - iris = px.data.iris() +def test_repeated_name(backend): + iris = px.data.iris(return_type=backend) fig = px.scatter( iris, x="sepal_width", @@ -148,8 +182,8 @@ def test_repeated_name(): assert fig.data[0].customdata.shape[1] == 4 -def test_arrayattrable_numpy(): - tips = px.data.tips() +def test_arrayattrable_numpy(backend): + tips = px.data.tips(return_type=backend) fig = px.scatter( tips, x="total_bill", y="tip", hover_data=[np.random.random(tips.shape[0])] ) @@ -157,7 +191,6 @@ def test_arrayattrable_numpy(): fig.data[0]["hovertemplate"] == "total_bill=%{x}
tip=%{y}
hover_data_0=%{customdata[0]}" ) - tips = px.data.tips() fig = px.scatter( tips, x="total_bill", @@ -191,20 +224,21 @@ def test_wrong_dimensions_of_array(): assert "All arguments should have the same length." in str(err_msg.value) -def test_wrong_dimensions_mixed_case(): +def test_wrong_dimensions_mixed_case(constructor): with pytest.raises(ValueError) as err_msg: - df = pd.DataFrame(dict(time=[1, 2, 3], temperature=[20, 30, 25])) + df = constructor(dict(time=[1, 2, 3], temperature=[20, 30, 25])) px.scatter(df, x="time", y="temperature", color=[1, 3, 9, 5]) assert "All arguments should have the same length." in str(err_msg.value) -def test_wrong_dimensions(): +def test_wrong_dimensions(backend): + df = px.data.tips(return_type=backend) with pytest.raises(ValueError) as err_msg: - px.scatter(px.data.tips(), x="tip", y=[1, 2, 3]) + px.scatter(df, x="tip", y=[1, 2, 3]) assert "All arguments should have the same length." in str(err_msg.value) # the order matters with pytest.raises(ValueError) as err_msg: - px.scatter(px.data.tips(), x=[1, 2, 3], y="tip") + px.scatter(df, x=[1, 2, 3], y="tip") assert "All arguments should have the same length." in str(err_msg.value) with pytest.raises(ValueError): px.scatter(px.data.tips(), x=px.data.iris().index, y="tip") @@ -230,8 +264,9 @@ def test_build_df_from_lists(): df = pd.DataFrame(args) args["data_frame"] = None out = build_dataframe(args, go.Scatter) - assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) - out.pop("data_frame") + df_out = out.pop("data_frame").to_native() + + assert df_out.equals(df) assert out == output # Arrays @@ -240,8 +275,8 @@ def test_build_df_from_lists(): df = pd.DataFrame(args) args["data_frame"] = None out = build_dataframe(args, go.Scatter) - assert_frame_equal(df.sort_index(axis=1), out["data_frame"].sort_index(axis=1)) - out.pop("data_frame") + df_out = out.pop("data_frame").to_native() + assert df_out.equals(df) assert out == output @@ -249,74 +284,57 @@ def test_build_df_with_index(): tips = px.data.tips() args = dict(data_frame=tips, x=tips.index, y="total_bill") out = build_dataframe(args, go.Scatter) - assert_frame_equal(tips.reset_index()[out["data_frame"].columns], out["data_frame"]) + assert_frame_equal( + tips.reset_index()[out["data_frame"].columns], out["data_frame"].to_pandas() + ) -@pytest.mark.parametrize("column_names_as_generator", [False, True]) -def test_build_df_using_interchange_protocol_mock( - add_interchange_module_for_old_pandas, column_names_as_generator -): +def test_build_df_using_interchange_protocol_mock(): class InterchangeDataFrame: - def __init__(self, columns): - self._columns = columns + def __init__(self, df): + self._df = df - if column_names_as_generator: + def __dataframe__(self): + return self + + def column_names(self): + return list(self._df._data.keys()) + + def select_columns_by_name(self, columns): + return InterchangeDataFrame( + CustomDataFrame( + { + key: value + for key, value in self._df._data.items() + if key in columns + } + ) + ) - def column_names(self): - for col in self._columns: - yield col + class CustomDataFrame: + def __init__(self, data): + self._data = data - else: + def __dataframe__(self, allow_copy: bool = True): + return InterchangeDataFrame(self) - def column_names(self): - return self._columns + input_dataframe = CustomDataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - interchange_dataframe = InterchangeDataFrame( - ["petal_width", "sepal_length", "sepal_width"] - ) - interchange_dataframe_reduced = InterchangeDataFrame( - 
["petal_width", "sepal_length"] - ) - interchange_dataframe.select_columns_by_name = mock.MagicMock( - return_value=interchange_dataframe_reduced - ) - interchange_dataframe_reduced.select_columns_by_name = mock.MagicMock( - return_value=interchange_dataframe_reduced - ) + input_dataframe_pa = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) - class CustomDataFrame: - def __dataframe__(self): - return interchange_dataframe + args = dict(data_frame=input_dataframe, x="a", y="b") + with mock.patch( + "narwhals._interchange.dataframe.InterchangeFrame.to_arrow", + return_value=input_dataframe_pa, + ) as mock_from_dataframe: + out = build_dataframe(args, go.Scatter) - class CustomDataFrameReduced: - def __dataframe__(self): - return interchange_dataframe_reduced - - input_dataframe = CustomDataFrame() - input_dataframe_reduced = CustomDataFrameReduced() - - iris_pandas = px.data.iris() - - with mock.patch("pandas.__version__", "2.0.2"): - args = dict(data_frame=input_dataframe, x="petal_width", y="sepal_length") - with mock.patch( - "pandas.api.interchange.from_dataframe", return_value=iris_pandas - ) as mock_from_dataframe: - build_dataframe(args, go.Scatter) - mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced) - assert set(interchange_dataframe.select_columns_by_name.call_args[0][0]) == { - "petal_width", - "sepal_length", - } - - args = dict(data_frame=input_dataframe_reduced, color=None) - with mock.patch( - "pandas.api.interchange.from_dataframe", - return_value=iris_pandas[["petal_width", "sepal_length"]], - ) as mock_from_dataframe: - build_dataframe(args, go.Scatter) - mock_from_dataframe.assert_called_once_with(interchange_dataframe_reduced) - interchange_dataframe_reduced.select_columns_by_name.assert_not_called() + mock_from_dataframe.assert_called_once() + + assert_frame_equal( + input_dataframe_pa.select(out["data_frame"].columns).to_pandas(), + out["data_frame"].to_pandas(), + ) @pytest.mark.skipif( @@ -337,7 +355,8 @@ def test_build_df_from_vaex_and_polars(test_lib): args = dict(data_frame=iris_vaex, x="petal_width", y="sepal_length") out = build_dataframe(args, go.Scatter) assert_frame_equal( - iris_pandas.reset_index()[out["data_frame"].columns], out["data_frame"] + iris_pandas.reset_index()[out["data_frame"].columns], + out["data_frame"].to_pandas(), ) @@ -367,16 +386,21 @@ def test_build_df_with_hover_data_from_vaex_and_polars(test_lib, hover_data): ) out = build_dataframe(args, go.Scatter) assert_frame_equal( - iris_pandas.reset_index()[out["data_frame"].columns], out["data_frame"] + iris_pandas.reset_index()[out["data_frame"].columns], + out["data_frame"].to_pandas(), ) -def test_timezones(): - df = pd.DataFrame({"date": ["2015-04-04 19:31:30+1:00"], "value": [3]}) - df["date"] = pd.to_datetime(df["date"]) - args = dict(data_frame=df, x="date", y="value") +def test_timezones(constructor): + df = nw.from_native( + constructor({"date": ["2015-04-04 19:31:30+01:00"], "value": [3]}) + ).with_columns(nw.col("date").str.to_datetime(format="%Y-%m-%d %H:%M:%S%z")) + args = dict(data_frame=df.to_native(), x="date", y="value") out = build_dataframe(args, go.Scatter) - assert str(out["data_frame"]["date"][0]) == str(df["date"][0]) + + assert str(out["data_frame"].item(row=0, column="date")) == str( + nw.from_native(df).item(row=0, column="date") + ) def test_non_matching_index(): @@ -386,21 +410,21 @@ def test_non_matching_index(): args = dict(data_frame=df, x=df.index, y="y") out = build_dataframe(args, go.Scatter) - assert_frame_equal(expected, out["data_frame"]) + 


 def test_non_matching_index():
@@ -386,21 +410,21 @@
     args = dict(data_frame=df, x=df.index, y="y")
     out = build_dataframe(args, go.Scatter)
-    assert_frame_equal(expected, out["data_frame"])
+    assert_frame_equal(expected, out["data_frame"].to_pandas())

     expected = pd.DataFrame(dict(x=["a", "b", "c"], y=[1, 2, 3]))

     args = dict(data_frame=None, x=df.index, y=df.y)
     out = build_dataframe(args, go.Scatter)
-    assert_frame_equal(expected, out["data_frame"])
+    assert_frame_equal(expected, out["data_frame"].to_pandas())

     args = dict(data_frame=None, x=["a", "b", "c"], y=df.y)
     out = build_dataframe(args, go.Scatter)
-    assert_frame_equal(expected, out["data_frame"])
+    assert_frame_equal(expected, out["data_frame"].to_pandas())


-def test_splom_case():
-    iris = px.data.iris()
+def test_splom_case(backend):
+    iris = px.data.iris(return_type=backend)
     fig = px.scatter_matrix(iris)
     assert len(fig.data[0].dimensions) == len(iris.columns)
     dic = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9]}
@@ -411,11 +435,11 @@ def test_splom_case():
     assert np.all(fig.data[0].dimensions[0].values == ar[:, 0])


-def test_int_col_names():
+def test_int_col_names(constructor):
     # DataFrame with int column names
-    lengths = pd.DataFrame(np.random.random(100))
-    fig = px.histogram(lengths, x=0)
-    assert np.all(np.array(lengths).flatten() == fig.data[0].x)
+    lengths = constructor({"0": np.random.random(100)})
+    fig = px.histogram(lengths, x="0")
+    assert np.all(nw.from_native(lengths).to_numpy().flatten() == fig.data[0].x)
     # Numpy array
     ar = np.arange(100).reshape((10, 10))
     fig = px.scatter(ar, x=2, y=8)
@@ -428,17 +452,17 @@ def test_data_frame_from_dict():
     assert np.all(fig.data[0].x == [0, 1])


-def test_arguments_not_modified():
-    iris = px.data.iris()
-    petal_length = iris.petal_length
-    hover_data = [iris.sepal_length]
-    px.scatter(iris, x=petal_length, y="petal_width", hover_data=hover_data)
-    assert iris.petal_length.equals(petal_length)
-    assert iris.sepal_length.equals(hover_data[0])
+def test_arguments_not_modified(backend):
+    iris = nw.from_native(px.data.iris(return_type=backend))
+    petal_length = iris.get_column("petal_length").to_native()
+    hover_data = [iris.get_column("sepal_length").to_native()]
+    px.scatter(iris.to_native(), x=petal_length, y="petal_width", hover_data=hover_data)
+    assert iris.get_column("petal_length").to_native().equals(petal_length)
+    assert iris.get_column("sepal_length").to_native().equals(hover_data[0])


-def test_pass_df_columns():
-    tips = px.data.tips()
+def test_pass_df_columns(backend):
+    tips = nw.from_native(px.data.tips(return_type=backend))
     fig = px.histogram(
         tips,
         x="total_bill",
@@ -449,13 +473,21 @@ def test_pass_df_columns():
     )
     # the "- 2" is because we re-use x and y in the hovertemplate where possible
     assert fig.data[1].hovertemplate.count("customdata") == len(tips.columns) - 2
-    tips_copy = px.data.tips()
-    assert tips_copy.columns.equals(tips.columns)
+    tips_copy = nw.from_native(px.data.tips(return_type=backend))
+    assert tips_copy.columns == tips.columns


-def test_size_column():
-    df = px.data.tips()
-    fig = px.scatter(df, x=df["size"], y=df.tip)
+def test_size_column(request, backend):
+    if backend == "pyarrow":
+        # By converting to native, we lose the name for pyarrow chunked_array
+        # and the assertions fail
+        request.applymarker(pytest.mark.xfail)
+    tips = nw.from_native(px.data.tips(return_type=backend))
+    fig = px.scatter(
+        tips.to_native(),
+        x=tips.get_column("size").to_native(),
+        y=tips.get_column("tip").to_native(),
+    )
     assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}"

@@ -595,8 +627,8 @@ def test_auto_histfunc():
         ("numerical", "categorical1", "categorical1", "overlay"),
     ],
 )
-def test_auto_boxlike_overlay(fn, mode, x, y, color, result):
-    df = pd.DataFrame(
+def test_auto_boxlike_overlay(constructor, fn, mode, x, y, color, result):
+    df = constructor(
         dict(
             categorical1=["a", "a", "b", "b"],
             categorical2=["a", "a", "b", "b"],
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py
index 1aac7b70ea1..88e1fd0278b 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_px_wide.py
@@ -1,50 +1,64 @@
 import plotly.express as px
 import plotly.graph_objects as go
-import pandas as pd
+import narwhals.stable.v1 as nw
 import numpy as np
+import pandas as pd
 from plotly.express._core import build_dataframe, _is_col_list
 from pandas.testing import assert_frame_equal
 import pytest
 import warnings


-def test_is_col_list():
-    df_input = pd.DataFrame(dict(a=[1, 2], b=[1, 2]))
-    assert _is_col_list(df_input, ["a"])
-    assert _is_col_list(df_input, ["a", "b"])
-    assert _is_col_list(df_input, [[3, 4]])
-    assert _is_col_list(df_input, [[3, 4], [3, 4]])
-    assert not _is_col_list(df_input, pytest)
-    assert not _is_col_list(df_input, False)
-    assert not _is_col_list(df_input, ["a", 1])
-    assert not _is_col_list(df_input, "a")
-    assert not _is_col_list(df_input, 1)
-    assert not _is_col_list(df_input, ["a", "b", "c"])
-    assert not _is_col_list(df_input, [1, 2])
+def test_is_col_list(constructor):
+    df_input = nw.from_native(constructor(dict(a=[1, 2], b=[1, 2])))
+    native_namespace = nw.get_native_namespace(df_input)
+    columns = df_input.columns
+    df_input = df_input.to_native()
+    is_pd_like = nw.dependencies.is_pandas_like_dataframe(df_input)
+    assert _is_col_list(columns, ["a"], is_pd_like, native_namespace)
+    assert _is_col_list(columns, ["a", "b"], is_pd_like, native_namespace)
+    assert _is_col_list(columns, [[3, 4]], is_pd_like, native_namespace)
+    assert _is_col_list(columns, [[3, 4], [3, 4]], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, pytest, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, False, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, ["a", 1], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, "a", is_pd_like, native_namespace)
+    assert not _is_col_list(columns, 1, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, ["a", "b", "c"], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, [1, 2], is_pd_like, native_namespace)
+
+
+def test_is_col_list_pandas():
     df_input = pd.DataFrame([[1, 2], [1, 2]])
-    assert _is_col_list(df_input, [0])
-    assert _is_col_list(df_input, [0, 1])
-    assert _is_col_list(df_input, [[3, 4]])
-    assert _is_col_list(df_input, [[3, 4], [3, 4]])
-    assert not _is_col_list(df_input, pytest)
-    assert not _is_col_list(df_input, False)
-    assert not _is_col_list(df_input, ["a", 1])
-    assert not _is_col_list(df_input, "a")
-    assert not _is_col_list(df_input, 1)
-    assert not _is_col_list(df_input, [0, 1, 2])
-    assert not _is_col_list(df_input, ["a", "b"])
+    is_pd_like = True
+    native_namespace = pd
+    columns = list(df_input.columns)
+    assert _is_col_list(columns, [0], is_pd_like, native_namespace)
+    assert _is_col_list(columns, [0, 1], is_pd_like, native_namespace)
+    assert _is_col_list(columns, [[3, 4]], is_pd_like, native_namespace)
+    assert _is_col_list(columns, [[3, 4], [3, 4]], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, pytest, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, False, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, ["a", 1], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, "a", is_pd_like, native_namespace)
+    assert not _is_col_list(columns, 1, is_pd_like, native_namespace)
+    assert not _is_col_list(columns, [0, 1, 2], is_pd_like, native_namespace)
+    assert not _is_col_list(columns, ["a", "b"], is_pd_like, native_namespace)
+
     df_input = None
-    assert _is_col_list(df_input, [[3, 4]])
-    assert _is_col_list(df_input, [[3, 4], [3, 4]])
-    assert not _is_col_list(df_input, [0])
-    assert not _is_col_list(df_input, [0, 1])
-    assert not _is_col_list(df_input, pytest)
-    assert not _is_col_list(df_input, False)
-    assert not _is_col_list(df_input, ["a", 1])
-    assert not _is_col_list(df_input, "a")
-    assert not _is_col_list(df_input, 1)
-    assert not _is_col_list(df_input, [0, 1, 2])
-    assert not _is_col_list(df_input, ["a", "b"])
+    is_pd_like = False
+    native_namespace = None
+    assert _is_col_list(df_input, [[3, 4]], is_pd_like, native_namespace)
+    assert _is_col_list(df_input, [[3, 4], [3, 4]], is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, [0], is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, [0, 1], is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, pytest, is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, False, is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, ["a", 1], is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, "a", is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, 1, is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, [0, 1, 2], is_pd_like, native_namespace)
+    assert not _is_col_list(df_input, ["a", "b"], is_pd_like, native_namespace)
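# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the diff. The two helpers
# test_is_col_list uses to feed the new _is_col_list signature are public
# narwhals.stable.v1 APIs; the dataframe below is hypothetical.
import narwhals.stable.v1 as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
assert nw.get_native_namespace(df) is pd  # the module backing the data
assert nw.dependencies.is_pandas_like_dataframe(df.to_native())
# pandas, modin and cudf count as "pandas-like" and keep support for
# index-based (positional integer) column lookups.
# ---------------------------------------------------------------------------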
@@ -157,8 +171,8 @@ def test_wide_mode_internal(trace_type, x, y, color, orientation):
     if x == "index":
         expected["index"] = [11, 12, 13, 11, 12, 13]
     assert_frame_equal(
-        df_out.sort_index(axis=1),
-        pd.DataFrame(expected).sort_index(axis=1),
+        df_out.to_pandas(),
+        pd.DataFrame(expected)[df_out.columns],
     )
     if trace_type in [go.Histogram2dContour, go.Histogram2d]:
         if orientation is None or orientation == "v":
@@ -285,8 +299,8 @@ def test_wide_x_or_y(tt, df_in, args_in, x, y, color, df_out_exp, transpose):
         args_in["y"], args_in["x"] = args_in["x"], args_in["y"]
     args_in["data_frame"] = df_in
     args_out = build_dataframe(args_in, tt)
-    df_out = args_out.pop("data_frame").sort_index(axis=1)
-    assert_frame_equal(df_out, pd.DataFrame(df_out_exp).sort_index(axis=1))
+    df_out = args_out.pop("data_frame")
+    assert_frame_equal(df_out.to_native(), pd.DataFrame(df_out_exp)[df_out.columns])
     if transpose:
         args_exp = dict(x=y, y=x, color=color)
     else:
@@ -306,7 +320,7 @@ def test_wide_mode_internal_bar_exception(orientation):
     args_out = build_dataframe(args_in, go.Bar)
     df_out = args_out.pop("data_frame")
     assert_frame_equal(
-        df_out.sort_index(axis=1),
+        df_out.to_native(),
         pd.DataFrame(
             dict(
                 index=[11, 12, 13, 11, 12, 13],
@@ -314,7 +328,7 @@ def test_wide_mode_internal_bar_exception(orientation):
                 value=["q", "r", "s", "t", "u", "v"],
                 count=[1, 1, 1, 1, 1, 1],
             )
-        ).sort_index(axis=1),
+        )[df_out.columns],
     )
     if orientation is None or orientation == "v":
         assert args_out == dict(x="value", y="count", color="variable", orientation="v")
@@ -797,10 +811,11 @@ def test_wide_mode_internal_special_cases(df_in, args_in, args_expect, df_expect
     args_in["data_frame"] = df_in
     args_out = build_dataframe(args_in, go.Scatter)
     df_out = args_out.pop("data_frame")
+
     assert args_out == args_expect
     assert_frame_equal(
-        df_out.sort_index(axis=1),
-        df_expect.sort_index(axis=1),
+        df_out.to_pandas(),
+        df_expect[df_out.columns],
     )

@@ -828,15 +843,15 @@ def test_mixed_input_error(df):
     )


-def test_mixed_number_input():
-    df = pd.DataFrame(dict(a=[1, 2], b=[1.1, 2.1]))
+def test_mixed_number_input(constructor):
+    df = constructor(dict(a=[1, 2], b=[1.1, 2.1]))
     fig = px.line(df)
     assert len(fig.data) == 2


-def test_line_group():
-    df = pd.DataFrame(
-        data={
+def test_line_group(constructor):
+    df = constructor(
+        {
             "who": ["a", "a", "b", "b"],
             "x": [0, 1, 0, 1],
             "score": [1.0, 2, 3, 4],
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
index 12e7f4b5035..0deecd8586d 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py
@@ -1,6 +1,6 @@
 import plotly.express as px
+import narwhals.stable.v1 as nw
 import numpy as np
-import pandas as pd
 import pytest
 from datetime import datetime
 from plotly.tests.test_optional.test_utils.test_utils import np_nan
@@ -17,10 +17,12 @@
         ("ewm", dict(alpha=0.5)),
     ],
 )
-def test_trendline_results_passthrough(mode, options):
-    df = px.data.gapminder().query("continent == 'Oceania'")
+def test_trendline_results_passthrough(backend, mode, options):
+    df = nw.from_native(px.data.gapminder(return_type=backend)).filter(
+        nw.col("continent") == "Oceania"
+    )
     fig = px.scatter(
-        df,
+        df.to_native(),
         x="year",
         y="pop",
         color="country",
@@ -37,9 +39,10 @@
     results = px.get_trendline_results(fig)
     if mode == "ols":
         assert len(results) == 2
-        assert results["country"].values[0] == "Australia"
-        au_result = results["px_fit_results"].values[0]
-        assert len(au_result.params) == 2
+        # Polars does not guarantee row order in group_by, so compare as a set
+        assert set(results["country"].to_list()) == {"Australia", "New Zealand"}
+        result = results["px_fit_results"].values[0]
+        assert len(result.params) == 2
     else:
         assert len(results) == 0

@@ -110,12 +113,20 @@ def test_trendline_enough_values(mode, options):
         ("ewm", dict(alpha=0.5)),
     ],
 )
-def test_trendline_nan_values(mode, options):
-    df = px.data.gapminder().query("continent == 'Oceania'")
+def test_trendline_nan_values(backend, mode, options):
     start_date = 1970
-    df["pop"][df["year"] < start_date] = np_nan()
+    df = (
+        nw.from_native(px.data.gapminder(return_type=backend))
+        .filter(nw.col("continent") == "Oceania")
+        .with_columns(
+            pop=nw.when(nw.col("year") >= start_date)
+            .then(nw.col("pop"))
+            .otherwise(None)
+        )
+    )
+
     fig = px.scatter(
-        df,
+        df.to_native(),
         x="year",
         y="pop",
         color="country",
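# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the diff. The masking idiom
# used in test_trendline_nan_values above: narwhals has no mutating item
# assignment, so the old `df["pop"][df["year"] < start_date] = nan` becomes a
# conditional expression. nw.when/.then/.otherwise are narwhals.stable.v1
# APIs; the data is hypothetical.
import narwhals.stable.v1 as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"year": [1960, 1980], "pop": [1.0, 2.0]}))
df = df.with_columns(
    pop=nw.when(nw.col("year") >= 1970).then(nw.col("pop")).otherwise(None)
)
print(df.to_native())  # pop is null/NaN for rows before 1970
# ---------------------------------------------------------------------------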
@@ -183,28 +194,42 @@
         ("ewm", dict(alpha=0.5)),
     ],
 )
-def test_trendline_on_timeseries(mode, options):
-    df = px.data.stocks()
+def test_trendline_on_timeseries(backend, mode, options):
+
+    df = nw.from_native(px.data.stocks(return_type=backend))

-    with pytest.raises(ValueError) as err_msg:
-        px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
-    assert "Could not convert value of 'x' ('date') into a numeric type." in str(
-        err_msg.value
+    pd_err_msg = r"Could not convert value of 'x' \('date'\) into a numeric type."
+    pl_err_msg = "conversion from `str` to `f64` failed in column 'date'"
+
+    with pytest.raises(Exception, match=rf"({pd_err_msg}|{pl_err_msg})"):
+        px.scatter(
+            df.to_native(),
+            x="date",
+            y="GOOG",
+            trendline=mode,
+            trendline_options=options,
+        )
+
+    df = df.with_columns(
+        date=nw.col("date")
+        .str.to_datetime(format="%Y-%m-%d")
+        .dt.replace_time_zone("CET")
+    )
+
+    fig = px.scatter(
+        df.to_native(), x="date", y="GOOG", trendline=mode, trendline_options=options
     )
-    df["date"] = pd.to_datetime(df["date"])
-    df["date"] = df["date"].dt.tz_localize("CET")  # force a timezone
-    fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options)
     assert len(fig.data) == 2
     assert len(fig.data[0].x) == len(fig.data[1].x)
-    assert type(fig.data[0].x[0]) == datetime
-    assert type(fig.data[1].x[0]) == datetime
+    assert isinstance(fig.data[0].x[0], (datetime, np.datetime64))
+    assert isinstance(fig.data[1].x[0], (datetime, np.datetime64))
     assert np.all(fig.data[0].x == fig.data[1].x)
     assert str(fig.data[0].x[0]) == str(fig.data[1].x[0])


-def test_overall_trendline():
-    df = px.data.tips()
+def test_overall_trendline(backend):
+    df = px.data.tips(return_type=backend)
     fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols")
     assert len(fig1.data) == 2
     assert "trendline" in fig1.data[1].hovertemplate
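# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the diff. Why the trendline
# test relaxes `type(x) == datetime` to an isinstance check: depending on the
# backend, datetime values can round-trip through figure data as numpy
# datetime64 rather than datetime.datetime.
from datetime import datetime

import numpy as np

for val in (datetime(2018, 1, 1), np.datetime64("2018-01-01T00:00:00")):
    assert isinstance(val, (datetime, np.datetime64))  # accepts either form
# ---------------------------------------------------------------------------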
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_tools/test_figure_factory.py b/packages/python/plotly/plotly/tests/test_optional/test_tools/test_figure_factory.py
index 22dfa89199c..bcebfa2914d 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_tools/test_figure_factory.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_tools/test_figure_factory.py
@@ -903,7 +903,7 @@ def test_simple_annotated_heatmap(self):

     def test_annotated_heatmap_kwargs(self):
         # we should be able to create an annotated heatmap with x and y axes
-        # lables, a defined colorscale, and supplied text.
+        # labels, a defined colorscale, and supplied text.
         z = [[1, 0], [0.25, 0.75], [0.45, 0.5]]
         text = [["first", "second"], ["third", "fourth"], ["fifth", "sixth"]]
@@ -999,7 +999,7 @@ def test_annotated_heatmap_kwargs(self):

     def test_annotated_heatmap_reversescale(self):
         # we should be able to create an annotated heatmap with x and y axes
-        # lables, a defined colorscale, and supplied text.
+        # labels, a defined colorscale, and supplied text.
         z = [[1, 0], [0.25, 0.75], [0.45, 0.5]]
         text = [["first", "second"], ["third", "fourth"], ["fifth", "sixth"]]
@@ -1222,7 +1222,7 @@ def test_fontcolor_input(self):

     def test_simple_table(self):
-        # we should be able to create a striped table by suppling a text matrix
+        # we should be able to create a striped table by supplying a text matrix
         text = [
             ["Country", "Year", "Population"],
diff --git a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py
index 33284452b77..c3b117e2655 100644
--- a/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py
+++ b/packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py
@@ -109,7 +109,7 @@ def test_encode_as_plotly(self):
             utils.NotEncodable, utils.PlotlyJSONEncoder.encode_as_plotly, obj
         )

-        # should return without exception when obj has `to_plotly_josn` attr
+        # should return without exception when obj has `to_plotly_json` attr
         expected_res = "wedidit"

         class ObjWithAttr(object):
@@ -319,11 +319,11 @@ def test_encode_customdata_datetime_series(self):
         )
         self.assertTrue(
             fig_json.startswith(
-                '{"data":[{"customdata":["2010-01-01T00:00:00","2010-01-02T00:00:00"]'
+                '{"data":[{"customdata":["2010-01-01T00:00:00.000000000","2010-01-02T00:00:00.000000000"]'
             )
         )

-    def test_encode_customdata_datetime_homogenous_dataframe(self):
+    def test_encode_customdata_datetime_homogeneous_dataframe(self):
         df = pd.DataFrame(
             dict(
                 t1=pd.to_datetime(["2010-01-01", "2010-01-02"]),
@@ -341,12 +341,12 @@ def test_encode_customdata_datetime_homogenous_dataframe(self):
         self.assertTrue(
             fig_json.startswith(
                 '{"data":[{"customdata":'
-                '[["2010-01-01T00:00:00","2011-01-01T00:00:00"],'
-                '["2010-01-02T00:00:00","2011-01-02T00:00:00"]'
+                '[["2010-01-01T00:00:00.000000000","2011-01-01T00:00:00.000000000"],'
+                '["2010-01-02T00:00:00.000000000","2011-01-02T00:00:00.000000000"]'
             )
         )

-    def test_encode_customdata_datetime_inhomogenous_dataframe(self):
+    def test_encode_customdata_datetime_inhomogeneous_dataframe(self):
         df = pd.DataFrame(
             dict(
                 t=pd.to_datetime(["2010-01-01", "2010-01-02"]),
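# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the diff. The expected JSON
# above gains nanosecond precision presumably because datetime columns now
# reach the encoder as numpy datetime64[ns] values (via Narwhals' to_numpy
# conversion) rather than as python datetime objects:
import numpy as np

t = np.array(["2010-01-01"], dtype="datetime64[ns]")
print(str(t[0]))  # '2010-01-01T00:00:00.000000000' -- full ns precision
# ---------------------------------------------------------------------------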
diff --git a/packages/python/plotly/requirements.txt b/packages/python/plotly/requirements.txt
index 463fe1bbfbd..ddd5d2bf773 100644
--- a/packages/python/plotly/requirements.txt
+++ b/packages/python/plotly/requirements.txt
@@ -4,3 +4,6 @@
 ###    $ pip install -r requirements.txt    ###
 ###                                         ###
 ###############################################
+
+## dataframe agnostic layer ##
+narwhals>=1.13.3
diff --git a/packages/python/plotly/setup.py b/packages/python/plotly/setup.py
index 5432d4d765f..19c663a21bd 100644
--- a/packages/python/plotly/setup.py
+++ b/packages/python/plotly/setup.py
@@ -551,7 +551,7 @@ def run(self):
             "package_data/datasets/*",
         ],
     },
-    install_requires=["packaging"],
+    install_requires=["narwhals>=1.13.3", "packaging"],
     zip_safe=False,
     cmdclass=dict(
         build_py=js_prerelease(versioneer_cmds["build_py"]),
diff --git a/packages/python/plotly/test_requirements/requirements_310_core.txt b/packages/python/plotly/test_requirements/requirements_310_core.txt
index c3af689b055..771df9cc87a 100644
--- a/packages/python/plotly/test_requirements/requirements_310_core.txt
+++ b/packages/python/plotly/test_requirements/requirements_310_core.txt
@@ -1,2 +1,3 @@
 requests==2.25.1
 pytest==7.4.4
+narwhals>=1.13.3
diff --git a/packages/python/plotly/test_requirements/requirements_310_optional.txt b/packages/python/plotly/test_requirements/requirements_310_optional.txt
index ecb8d094eb5..295d594fb5b 100644
--- a/packages/python/plotly/test_requirements/requirements_310_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_310_optional.txt
@@ -18,4 +18,7 @@ scikit-image==0.22.0
 psutil==5.7.0
 kaleido
 orjson==3.8.12
-anywidget==0.9.13
\ No newline at end of file
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
+anywidget==0.9.13
diff --git a/packages/python/plotly/test_requirements/requirements_311_core.txt b/packages/python/plotly/test_requirements/requirements_311_core.txt
index c3af689b055..938e282b0f0 100644
--- a/packages/python/plotly/test_requirements/requirements_311_core.txt
+++ b/packages/python/plotly/test_requirements/requirements_311_core.txt
@@ -1,2 +1,3 @@
 requests==2.25.1
 pytest==7.4.4
+narwhals>=1.13.3
\ No newline at end of file
diff --git a/packages/python/plotly/test_requirements/requirements_311_optional.txt b/packages/python/plotly/test_requirements/requirements_311_optional.txt
index 6a9c190744c..e18e6536384 100644
--- a/packages/python/plotly/test_requirements/requirements_311_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_311_optional.txt
@@ -18,4 +18,7 @@ scikit-image==0.22.0
 psutil==5.7.0
 kaleido
 orjson==3.8.12
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
diff --git a/packages/python/plotly/test_requirements/requirements_312_core.txt b/packages/python/plotly/test_requirements/requirements_312_core.txt
index c3af689b055..771df9cc87a 100644
--- a/packages/python/plotly/test_requirements/requirements_312_core.txt
+++ b/packages/python/plotly/test_requirements/requirements_312_core.txt
@@ -1,2 +1,3 @@
 requests==2.25.1
 pytest==7.4.4
+narwhals>=1.13.3
diff --git a/packages/python/plotly/test_requirements/requirements_312_no_numpy_optional.txt b/packages/python/plotly/test_requirements/requirements_312_no_numpy_optional.txt
index 930a9ee951f..79cfe14a7e2 100644
--- a/packages/python/plotly/test_requirements/requirements_312_no_numpy_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_312_no_numpy_optional.txt
@@ -17,5 +17,8 @@ scikit-image==0.22.0
 psutil==5.9.7
 kaleido
 orjson==3.9.10
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
 jupyter-console==6.4.4
diff --git a/packages/python/plotly/test_requirements/requirements_312_np2_optional.txt b/packages/python/plotly/test_requirements/requirements_312_np2_optional.txt
index bfdbf203f75..94e73ac11ae 100644
--- a/packages/python/plotly/test_requirements/requirements_312_np2_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_312_np2_optional.txt
@@ -19,4 +19,7 @@ scikit-image==0.24.0
 psutil==5.9.7
 kaleido
 orjson==3.9.10
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
diff --git a/packages/python/plotly/test_requirements/requirements_312_optional.txt b/packages/python/plotly/test_requirements/requirements_312_optional.txt
index 69edf4915a9..3609e9b2725 100644
--- a/packages/python/plotly/test_requirements/requirements_312_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_312_optional.txt
@@ -18,5 +18,8 @@ scikit-image==0.22.0
 psutil==5.9.7
 kaleido
 orjson==3.9.10
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
 jupyter-console==6.4.4
diff --git a/packages/python/plotly/test_requirements/requirements_38_core.txt b/packages/python/plotly/test_requirements/requirements_38_core.txt
index 659fe1a370f..2983296113c 100644
--- a/packages/python/plotly/test_requirements/requirements_38_core.txt
+++ b/packages/python/plotly/test_requirements/requirements_38_core.txt
@@ -1,2 +1,3 @@
 requests==2.25.1
 pytest==8.1.1
+narwhals>=1.13.3
diff --git a/packages/python/plotly/test_requirements/requirements_38_optional.txt b/packages/python/plotly/test_requirements/requirements_38_optional.txt
index 7f5690e0297..dc5cb4efe3c 100644
--- a/packages/python/plotly/test_requirements/requirements_38_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_38_optional.txt
@@ -18,4 +18,7 @@ matplotlib==3.7.3
 scikit-image==0.18.1
 psutil==5.7.0
 kaleido
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
diff --git a/packages/python/plotly/test_requirements/requirements_39_core.txt b/packages/python/plotly/test_requirements/requirements_39_core.txt
index f4605b806c5..b6055de8091 100644
--- a/packages/python/plotly/test_requirements/requirements_39_core.txt
+++ b/packages/python/plotly/test_requirements/requirements_39_core.txt
@@ -1,2 +1,3 @@
 requests==2.25.1
 pytest==6.2.3
+narwhals>=1.13.3
diff --git a/packages/python/plotly/test_requirements/requirements_39_optional.txt b/packages/python/plotly/test_requirements/requirements_39_optional.txt
index d24961619d8..153065dfd46 100644
--- a/packages/python/plotly/test_requirements/requirements_39_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_39_optional.txt
@@ -19,4 +19,7 @@ scikit-image==0.18.1
 psutil==5.7.0
 kaleido
 orjson==3.8.12
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 anywidget==0.9.13
diff --git a/packages/python/plotly/test_requirements/requirements_39_pandas_2_optional.txt b/packages/python/plotly/test_requirements/requirements_39_pandas_2_optional.txt
index da208cbd476..995c8aec7cd 100644
--- a/packages/python/plotly/test_requirements/requirements_39_pandas_2_optional.txt
+++ b/packages/python/plotly/test_requirements/requirements_39_pandas_2_optional.txt
@@ -19,5 +19,8 @@ psutil==5.7.0
 kaleido
 vaex
 pydantic<=1.10.11  # for vaex, see https://github.com/vaexio/vaex/issues/2384
+polars[timezone]
+pyarrow
+narwhals>=1.13.3
 polars
 anywidget==0.9.13
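# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch, not part of the diff. End to end, this
# change set means plotly express accepts non-pandas dataframes directly
# (assuming polars is installed):
import plotly.express as px

tips = px.data.tips(return_type="polars")  # a native polars.DataFrame
fig = px.scatter(tips, x="total_bill", y="tip", color="day")
fig.show()
# ---------------------------------------------------------------------------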