diff --git a/CHANGELOG.md b/CHANGELOG.md index 8985395f77b..9ddd0bd04c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ This project adheres to [Semantic Versioning](http://semver.org/). ### Added - Extra flags were added to the `gapminder` and `stocks` dataset to facilitate testing, documentation and demos [#3305](https://github.com/plotly/plotly.py/issues/3305) - All line-like Plotly Express functions now accept `markers` argument to display markers, and all but `line_mapbox` accept `symbol` to map a field to the symbol attribute, similar to scatter-like functions [#3326](https://github.com/plotly/plotly.py/issues/3326) + - `px.scatter` and `px.density_contours` now support new `trendline` types `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997) + - `px.scatter` and `px.density_contours` now support new `trendline_options` argument to parameterize trendlines, with support for constant control and log-scaling in `'ols'` and specification of the fraction used for `'lowess'`, as well as pass-through to Pandas for `'rolling'`, `'expanding'` and `'ewm'` [#2997](https://github.com/plotly/plotly.py/pull/2997) + - `px.scatter` and `px.density_contours` now support new `trendline_scope` argument that accepts the value `'overall'` to request a single trendline for all traces, including across facets and animation frames [#2997](https://github.com/plotly/plotly.py/pull/2997) ### Fixed - Fixed regression introduced in version 5.0.0 where pandas/numpy arrays with `dtype` of Object were being converted to `list` values when added to a Figure ([#3292](https://github.com/plotly/plotly.py/issues/3292), [#3293](https://github.com/plotly/plotly.py/pull/3293)) diff --git a/doc/apidoc/plotly.express.rst b/doc/apidoc/plotly.express.rst index cd252158cb4..bff238d6684 100644 --- a/doc/apidoc/plotly.express.rst +++ b/doc/apidoc/plotly.express.rst @@ -49,6 +49,8 @@ plotly's high-level API for rapid figure generation. 
:: density_heatmap density_mapbox imshow + set_mapbox_access_token + get_trendline_results `plotly.express` subpackages @@ -60,3 +62,4 @@ plotly's high-level API for rapid figure generation. :: generated/plotly.express.data.rst generated/plotly.express.colors.rst + generated/plotly.express.trendline_functions.rst diff --git a/doc/python/linear-fits.md b/doc/python/linear-fits.md index 7f1f1a2971f..0029be6c4be 100644 --- a/doc/python/linear-fits.md +++ b/doc/python/linear-fits.md @@ -5,8 +5,8 @@ jupyter: text_representation: extension: .md format_name: markdown - format_version: '1.1' - jupytext_version: 1.1.1 + format_version: '1.2' + jupytext_version: 1.4.2 kernelspec: display_name: Python 3 language: python @@ -20,11 +20,12 @@ jupyter: name: python nbconvert_exporter: python pygments_lexer: ipython3 - version: 3.6.8 + version: 3.7.7 plotly: description: Add linear Ordinary Least Squares (OLS) regression trendlines or non-linear Locally Weighted Scatterplot Smoothing (LOWESS) trendlines to scatterplots - in Python. + in Python. Options for moving averages (rolling means) as well as exponentially-weighted + and expanding functions. display_as: statistical language: python layout: base @@ -39,7 +40,7 @@ jupyter: [Plotly Express](/python/plotly-express/) is the easy-to-use, high-level interface to Plotly, which [operates on a variety of types of data](/python/px-arguments/) and produces [easy-to-style figures](/python/styling-plotly-express/). -Plotly Express allows you to add [Ordinary Least](https://en.wikipedia.org/wiki/Ordinary_least_squares) Squares regression trendline to scatterplots with the `trendline` argument. In order to do so, you will need to install `statsmodels` and its dependencies. Hovering over the trendline will show the equation of the line and its R-squared value. +Plotly Express allows you to add [Ordinary Least Squares](https://en.wikipedia.org/wiki/Ordinary_least_squares) regression trendline to scatterplots with the `trendline` argument. 
In order to do so, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html). Hovering over the trendline will show the equation of the line and its R-squared value. ```python import plotly.express as px @@ -66,14 +67,160 @@ print(results) results.query("sex == 'Male' and smoker == 'Yes'").px_fit_results.iloc[0].summary() ``` -### Non-Linear Trendlines +### Displaying a single trendline with multiple traces -Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines. +_new in v5.2_ + +To display a single trendline using the entire dataset, set the `trendline_scope` argument to `"overall"`. The same trendline will be overlaid on all facets and animation frames. The trendline color can be overridden with `trendline_color_override`. + +```python +import plotly.express as px + +df = px.data.tips() +fig = px.scatter(df, x="total_bill", y="tip", symbol="smoker", color="sex", trendline="ols", trendline_scope="overall") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.tips() +fig = px.scatter(df, x="total_bill", y="tip", facet_col="smoker", color="sex", + trendline="ols", trendline_scope="overall", trendline_color_override="black") +fig.show() +``` + +### OLS Parameters + +_new in v5.2_ + +OLS trendlines can be fit with log transformations to the X data, the Y data, or both, using the `trendline_options` argument, independently of whether or not the plot has [logarithmic axes](https://plotly.com/python/log-plot/).
+ + ```python +import plotly.express as px + +df = px.data.gapminder(year=2007) +fig = px.scatter(df, x="gdpPercap", y="lifeExp", + trendline="ols", trendline_options=dict(log_x=True), + title="Log-transformed fit on linear axes") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.gapminder(year=2007) +fig = px.scatter(df, x="gdpPercap", y="lifeExp", log_x=True, + trendline="ols", trendline_options=dict(log_x=True), + title="Log-scaled X axis and log-transformed fit") +fig.show() +``` + +### LOcally WEighted Scatterplot Smoothing (LOWESS) + +Plotly Express also supports non-linear [LOWESS](https://en.wikipedia.org/wiki/Local_regression) trendlines. In order to use this feature, you will need to [install `statsmodels` and its dependencies](https://www.statsmodels.org/stable/install.html). + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="lowess") +fig.show() +``` + +_new in v5.2_ + +The level of smoothing can be controlled via the `frac` trendline option, which indicates the fraction of the data that the LOWESS smoother should include. The default, `frac=0.6666`, gives a fairly smooth line; lowering this fraction gives a line that follows the data more closely. + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="lowess", trendline_options=dict(frac=0.1)) +fig.show() +``` + +### Moving Averages + +_new in v5.2_ + +Plotly Express can leverage Pandas' [`rolling`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.rolling.html), [`ewm`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.ewm.html) and [`expanding`](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.expanding.html) functions in trendlines as well, for example to display moving averages.
Entries in `trendline_options` are passed directly to the underlying Pandas function (with the exception of the `function` and `function_args` keys, see below). + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(window=5), + title="5-point moving average") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="ewm", trendline_options=dict(halflife=2), + title="Exponentially-weighted moving average (halflife of 2 points)") +fig.show() +``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", title="Expanding mean") +fig.show() +``` + +### Other Functions + +The `rolling`, `expanding` and `ewm` trendlines support functions other than the default `mean`, enabling, for example, a moving-median trendline, or an expanding-max trendline. ```python import plotly.express as px -df = px.data.gapminder().query("year == 2007") -fig = px.scatter(df, x="gdpPercap", y="lifeExp", color="continent", trendline="lowess") +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", trendline_options=dict(function="median", window=5), + title="Rolling Median") fig.show() ``` + +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="expanding", trendline_options=dict(function="max"), + title="Expanding Maximum") +fig.show() +``` + +In some cases, it is necessary to pass options into the underlying Pandas function, for example the `std` parameter must be provided if the `win_type` argument to `rolling` is `"gaussian"`. This is possible with the `function_args` trendline option.
+ +```python +import plotly.express as px + +df = px.data.stocks(datetimes=True) +fig = px.scatter(df, x="date", y="GOOG", trendline="rolling", + trendline_options=dict(window=5, win_type="gaussian", function_args=dict(std=2)), + title="Rolling Mean with Gaussian Window") +fig.show() +``` + +### Displaying only the trendlines + +In some cases, it may be desirable to show only the trendlines, by removing the scatter points. + +```python +import plotly.express as px + +df = px.data.stocks(indexed=True, datetimes=True) +fig = px.scatter(df, trendline="rolling", trendline_options=dict(window=5), + title="5-point moving average") +fig.data = [t for t in fig.data if t.mode == "lines"] +fig.update_traces(showlegend=True) #trendlines have showlegend=False by default +fig.show() +``` + +```python + +``` diff --git a/packages/python/plotly/plotly/express/__init__.py b/packages/python/plotly/plotly/express/__init__.py index 140a0fbe814..8bc5da53910 100644 --- a/packages/python/plotly/plotly/express/__init__.py +++ b/packages/python/plotly/plotly/express/__init__.py @@ -60,7 +60,7 @@ from ._special_inputs import IdentityMap, Constant, Range # noqa: F401 -from . import data, colors # noqa: F401 +from . 
import data, colors, trendline_functions # noqa: F401 __all__ = [ "scatter", @@ -100,6 +100,7 @@ "imshow", "data", "colors", + "trendline_functions", "set_mapbox_access_token", "get_trendline_results", "IdentityMap", diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 6cfb6a90367..f335e78de34 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -46,7 +46,9 @@ def scatter( marginal_x=None, marginal_y=None, trendline=None, + trendline_options=None, trendline_color_override=None, + trendline_scope="trace", log_x=False, log_y=False, range_x=None, @@ -90,7 +92,9 @@ def density_contour( marginal_x=None, marginal_y=None, trendline=None, + trendline_options=None, trendline_color_override=None, + trendline_scope="trace", log_x=False, log_y=False, range_x=None, diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index f8e391053b9..cc0e98375b2 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -2,6 +2,7 @@ import plotly.io as pio from collections import namedtuple, OrderedDict from ._special_inputs import IdentityMap, Constant, Range +from .trendline_functions import ols, lowess, rolling, expanding, ewm from _plotly_utils.basevalidators import ColorscaleValidator from plotly.colors import qualitative, sequential @@ -16,6 +17,9 @@ ) NO_COLOR = "px_no_color_constant" +trendline_functions = dict( + lowess=lowess, rolling=rolling, ewm=ewm, expanding=expanding, ols=ols +) # Declare all supported attributes, across all plot types direct_attrables = ( @@ -313,12 +317,10 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): mapping_labels["count"] = "%{x}" elif attr_name == "trendline": if ( - attr_value in ["ols", "lowess"] - and args["x"] + args["x"] and args["y"] and 
len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): - import statsmodels.api as sm # sorting is bad but trace_specs with "trendline" have no other attrs sorted_trace_data = trace_data.sort_values(by=args["x"]) @@ -345,37 +347,27 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): ) # preserve original values of "x" in case they're dates - trace_patch["x"] = sorted_trace_data[args["x"]][ - np.logical_not(np.logical_or(np.isnan(y), np.isnan(x))) - ] - - if attr_value == "lowess": - # missing ='drop' is the default value for lowess but not for OLS (None) - # we force it here in case statsmodels change their defaults - trendline = sm.nonparametric.lowess(y, x, missing="drop") - trace_patch["y"] = trendline[:, 1] - hover_header = "<b>LOWESS trendline</b><br><br>"
- elif attr_value == "ols": - fit_results = sm.OLS( - y, sm.add_constant(x), missing="drop" - ).fit() - trace_patch["y"] = fit_results.predict() - hover_header = "<b>OLS trendline</b><br>"
- if len(fit_results.params) == 2: - hover_header += "%s = %g * %s + %g<br>" % ( - args["y"], - fit_results.params[1], - args["x"], - fit_results.params[0], - ) - else: - hover_header += "%s = %g<br>" % ( - args["y"], - fit_results.params[0], - ) - hover_header += ( - "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared - )
+ # otherwise numpy/pandas can mess with the timezones + # NB this means trendline functions must output one-to-one with the input series + # i.e. we can't do resampling, because then the X values might not line up! + non_missing = np.logical_not( + np.logical_or(np.isnan(y), np.isnan(x)) + ) + trace_patch["x"] = sorted_trace_data[args["x"]][non_missing] + trendline_function = trendline_functions[attr_value] + y_out, hover_header, fit_results = trendline_function( + args["trendline_options"], + sorted_trace_data[args["x"]], + x, + y, + args["x"], + args["y"], + non_missing, + ) + assert len(y_out) == len( + trace_patch["x"] + ), "missing-data-handling failure in trendline code" + trace_patch["y"] = y_out mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} <b>(trend)</b>" elif attr_name.startswith("error"): @@ -878,21 +870,25 @@ def make_trace_spec(args, constructor, attrs, trace_patch): result.append(trace_spec) # Add trendline trace specifications - if "trendline" in args and args["trendline"]: - trace_spec = TraceSpec( - constructor=go.Scattergl if constructor == go.Scattergl else go.Scatter, - attrs=["trendline"], - trace_patch=dict(mode="lines"), - marginal=None, - ) - if args["trendline_color_override"]: - trace_spec.trace_patch["line"] = dict( - color=args["trendline_color_override"] - ) - result.append(trace_spec) + if args.get("trendline") and args.get("trendline_scope", "trace") == "trace": + result.append(make_trendline_spec(args, constructor)) return result +def make_trendline_spec(args, constructor): + trace_spec = TraceSpec( + constructor=go.Scattergl + if constructor == go.Scattergl # could be contour + else go.Scatter, + attrs=["trendline"], + trace_patch=dict(mode="lines"), + marginal=None, + ) + if args["trendline_color_override"]: + trace_spec.trace_patch["line"] = dict(color=args["trendline_color_override"]) + return trace_spec + + def one_group(x): return "" @@ -1827,6
+1823,16 @@ def infer_config(args, constructor, trace_patch, layout_patch): ): args["facet_col_wrap"] = 0 + if "trendline" in args and args["trendline"] is not None: + if args["trendline"] not in trendline_functions: + raise ValueError( + "Value '%s' for `trendline` must be one of %s" + % (args["trendline"], trendline_functions.keys()) + ) + + if "trendline_options" in args and args["trendline_options"] is None: + args["trendline_options"] = dict() + # Compute applicable grouping attributes for k in group_attrables: if k in args: @@ -2126,6 +2132,27 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None): fig.update_layout(template=args["template"], overwrite=True) fig.frames = frame_list if len(frames) > 1 else [] + if args.get("trendline") and args.get("trendline_scope", "trace") == "overall": + trendline_spec = make_trendline_spec(args, constructor) + trendline_trace = trendline_spec.constructor( + name="Overall Trendline", legendgroup="Overall Trendline", showlegend=False + ) + if "line" not in trendline_spec.trace_patch: # no color override + for m in grouped_mappings: + if m.variable == "color": + next_color = m.sequence[len(m.val_map) % len(m.sequence)] + trendline_spec.trace_patch["line"] = dict(color=next_color) + patch, fit_results = make_trace_kwargs( + args, trendline_spec, args["data_frame"], {}, sizeref + ) + trendline_trace.update(patch) + fig.add_trace( + trendline_trace, row="all", col="all", exclude_empty_subplots=True + ) + fig.update_traces(selector=-1, showlegend=True) + if fit_results is not None: + trendline_rows.append(dict(px_fit_results=fit_results)) + fig._px_trendlines = pd.DataFrame(trendline_rows) configure_axes(args, constructor, fig, orders) diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index f2f2ab0544d..65d9f0588ff 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -402,14 +402,28 @@ ], 
trendline=[ "str", - "One of `'ols'` or `'lowess'`.", + "One of `'ols'`, `'lowess'`, `'rolling'`, `'expanding'` or `'ewm'`.", "If `'ols'`, an Ordinary Least Squares regression line will be drawn for each discrete-color/symbol group.", "If `'lowess`', a Locally Weighted Scatterplot Smoothing line will be drawn for each discrete-color/symbol group.", + "If `'rolling`', a Rolling (e.g. rolling average, rolling median) line will be drawn for each discrete-color/symbol group.", + "If `'expanding`', an Expanding (e.g. expanding average, expanding sum) line will be drawn for each discrete-color/symbol group.", + "If `'ewm`', an Exponentially Weighted Moment (e.g. exponentially-weighted moving average) line will be drawn for each discrete-color/symbol group.", + "See the docstrings for the functions in `plotly.express.trendline_functions` for more details on these functions and how", + "to configure them with the `trendline_options` argument.", + ], + trendline_options=[ + "dict", + "Options passed as the first argument to the function from `plotly.express.trendline_functions` ", + "named in the `trendline` argument.", ], trendline_color_override=[ "str", "Valid CSS color.", - "If provided, and if `trendline` is set, all trendlines will be drawn in this color.", + "If provided, and if `trendline` is set, all trendlines will be drawn in this color rather than in the same color as the traces from which they draw their inputs.", + ], + trendline_scope=[ + "str (one of `'trace'` or `'overall'`, default `'trace'`)", + "If `'trace'`, then one trendline is drawn per trace (i.e. 
per color, symbol, facet, animation frame etc) and if `'overall'` then one trendline is computed for the entire dataset, and replicated across all facets.", ], render_mode=[ "str", diff --git a/packages/python/plotly/plotly/express/trendline_functions/__init__.py b/packages/python/plotly/plotly/express/trendline_functions/__init__.py new file mode 100644 index 00000000000..f0fc29cee4b --- /dev/null +++ b/packages/python/plotly/plotly/express/trendline_functions/__init__.py @@ -0,0 +1,155 @@ +""" +The `trendline_functions` module contains functions which are called by Plotly Express +when the `trendline` argument is used. Valid values for `trendline` are the names of the +functions in this module, and the value of the `trendline_options` argument to PX +functions is passed in as the first argument to these functions when called. + +Note that the functions in this module are not meant to be called directly, and are +exposed as part of the public API for documentation purposes. +""" + +import pandas as pd +import numpy as np + + +def ols(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Ordinary Least Squares (OLS) trendline function + + Requires `statsmodels` to be installed. + + This trendline function causes fit results to be stored within the figure, + accessible via the `plotly.express.get_trendline_results` function. The fit results + are the output of the `statsmodels.api.OLS` function. + + Valid keys for the `trendline_options` dict are: + + - `add_constant` (`bool`, default `True`): if `False`, the trendline passes through + the origin but if `True` a y-intercept is fitted. + + - `log_x` and `log_y` (`bool`, default `False`): if `True` the OLS is computed with + respect to the base 10 logarithm of the input. Note that this means no zeros can + be present in the input. 
+ """ + valid_options = ["add_constant", "log_x", "log_y"] + for k in trendline_options.keys(): + if k not in valid_options: + raise ValueError( + "OLS trendline_options keys must be one of [%s] but got '%s'" + % (", ".join(valid_options), k) + ) + + import statsmodels.api as sm + + add_constant = trendline_options.get("add_constant", True) + log_x = trendline_options.get("log_x", False) + log_y = trendline_options.get("log_y", False) + + if log_y: + if np.any(y <= 0): + raise ValueError( + "Can't do OLS trendline with `log_y=True` when `y` contains non-positive values." + ) + y = np.log10(y) + y_label = "log10(%s)" % y_label + if log_x: + if np.any(x <= 0): + raise ValueError( + "Can't do OLS trendline with `log_x=True` when `x` contains non-positive values." + ) + x = np.log10(x) + x_label = "log10(%s)" % x_label + if add_constant: + x = sm.add_constant(x) + fit_results = sm.OLS(y, x, missing="drop").fit() + y_out = fit_results.predict() + if log_y: + y_out = np.power(10, y_out) + hover_header = "<b>OLS trendline</b><br>"
+ if len(fit_results.params) == 2: + hover_header += "%s = %g * %s + %g<br>" % ( + y_label, + fit_results.params[1], + x_label, + fit_results.params[0], + ) + elif not add_constant: + hover_header += "%s = %g * %s<br>" % (y_label, fit_results.params[0], x_label) + else: + hover_header += "%s = %g<br>" % (y_label, fit_results.params[0]) + hover_header += "R<sup>2</sup>=%f<br><br>" % fit_results.rsquared
+ return y_out, hover_header, fit_results + + +def lowess(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """LOcally WEighted Scatterplot Smoothing (LOWESS) trendline function + + Requires `statsmodels` to be installed. + + Valid keys for the `trendline_options` dict are: + + - `frac` (`float`, default `0.6666666`): the `frac` parameter from the + `statsmodels.api.nonparametric.lowess` function + """ + + valid_options = ["frac"] + for k in trendline_options.keys(): + if k not in valid_options: + raise ValueError( + "LOWESS trendline_options keys must be one of [%s] but got '%s'" + % (", ".join(valid_options), k) + ) + + import statsmodels.api as sm + + frac = trendline_options.get("frac", 0.6666666) + y_out = sm.nonparametric.lowess(y, x, missing="drop", frac=frac)[:, 1] + hover_header = "<b>LOWESS trendline</b><br><br>"
+ return y_out, hover_header, None + + +def _pandas(mode, trendline_options, x_raw, y, non_missing): + modes = dict(rolling="Rolling", ewm="Exponentially Weighted", expanding="Expanding") + trendline_options = trendline_options.copy() + function_name = trendline_options.pop("function", "mean") + function_args = trendline_options.pop("function_args", dict()) + series = pd.Series(y, index=x_raw) + agg = getattr(series, mode) # e.g. series.rolling + agg_obj = agg(**trendline_options) # e.g. series.rolling(**opts) + function = getattr(agg_obj, function_name) # e.g. series.rolling(**opts).mean + y_out = function(**function_args) # e.g. series.rolling(**opts).mean(**opts) + y_out = y_out[non_missing] + hover_header = "<b>%s %s trendline</b><br><br>" % (modes[mode], function_name)
+ return y_out, hover_header, None + + +def rolling(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Rolling trendline function + + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.rolling` function. + """ + return _pandas("rolling", trendline_options, x_raw, y, non_missing) + + +def expanding(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Expanding trendline function + + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.expanding` function. + """ + return _pandas("expanding", trendline_options, x_raw, y, non_missing) + + +def ewm(trendline_options, x_raw, x, y, x_label, y_label, non_missing): + """Exponentially Weighted Moment (EWM) trendline function + + The value of the `function` key of the `trendline_options` dict is the function to + use (defaults to `mean`) and the value of the `function_args` key are taken to be + its arguments as a dict. The remainder of the `trendline_options` dict is passed as + keyword arguments into the `pandas.Series.ewm` function.
+ """ + return _pandas("ewm", trendline_options, x_raw, y, non_missing) diff --git a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py index 41064bd19df..66046981eff 100644 --- a/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_optional/test_px/test_trendline.py @@ -5,10 +5,27 @@ from datetime import datetime -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_results_passthrough(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("rolling", dict(window=2)), + ("expanding", None), + ("ewm", dict(alpha=0.5)), + ], +) +def test_trendline_results_passthrough(mode, options): df = px.data.gapminder().query("continent == 'Oceania'") - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + fig = px.scatter( + df, + x="year", + y="pop", + color="country", + trendline=mode, + trendline_options=options, + ) assert len(fig.data) == 4 for trace in fig["data"][0::2]: assert "trendline" not in trace.hovertemplate @@ -20,93 +37,205 @@ def test_trendline_results_passthrough(mode): if mode == "ols": assert len(results) == 2 assert results["country"].values[0] == "Australia" - assert results["country"].values[0] == "Australia" au_result = results["px_fit_results"].values[0] assert len(au_result.params) == 2 else: assert len(results) == 0 -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_enough_values(mode): - fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode) +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("rolling", dict(window=2)), + ("expanding", None), + ("ewm", dict(alpha=0.5)), + ], +) +def test_trendline_enough_values(mode, options): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode, 
trendline_options=options) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 - fig = px.scatter(x=[0], y=[0], trendline=mode) + fig = px.scatter(x=[0], y=[0], trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode) + fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode) + fig = px.scatter( + x=[0, 1], y=np.array([0, np.nan]), trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode) + fig = px.scatter( + x=[0, 1, None], y=[0, None, 1], trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert fig.data[1].x is None fig = px.scatter( - x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode + x=np.array([0, 1, np.nan]), + y=np.array([0, np.nan, 1]), + trendline=mode, + trendline_options=options, ) assert len(fig.data) == 2 assert fig.data[1].x is None - fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode) + fig = px.scatter( + x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode, trendline_options=options + ) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 fig = px.scatter( - x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode + x=np.array([0, 1, np.nan, 2]), + y=np.array([1, np.nan, 1, 2]), + trendline=mode, + trendline_options=options, ) assert len(fig.data) == 2 assert len(fig.data[1].x) == 2 -@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_nan_values(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("ols", dict(add_constant=False, log_x=True, log_y=True)), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("rolling", dict(window=2)), + ("expanding", None), + 
("ewm", dict(alpha=0.5)), + ], +) +def test_trendline_nan_values(mode, options): df = px.data.gapminder().query("continent == 'Oceania'") start_date = 1970 df["pop"][df["year"] < start_date] = np.nan - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + fig = px.scatter( + df, + x="year", + y="pop", + color="country", + trendline=mode, + trendline_options=options, + ) for trendline in fig["data"][1::2]: assert trendline.x[0] >= start_date assert len(trendline.x) == len(trendline.y) -def test_no_slope_ols_trendline(): +def test_ols_trendline_slopes(): fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols") - assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number) + # should be "y = 1 * x + 0" but sometimes is some tiny number instead + assert "y = 1 * x + " in fig.data[1].hovertemplate results = px.get_trendline_results(fig) params = results["px_fit_results"].iloc[0].params assert np.all(np.isclose(params, [0, 1])) + fig = px.scatter(x=[0, 1], y=[1, 2], trendline="ols") + assert "y = 1 * x + 1<br>" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [1, 1])) + + fig = px.scatter( + x=[0, 1], y=[1, 2], trendline="ols", trendline_options=dict(add_constant=False) + ) + assert "y = 2 * x<br>" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [2])) + + fig = px.scatter( + x=[1, 1], y=[0, 0], trendline="ols", trendline_options=dict(add_constant=False) + ) + assert "y = 0 * x<br>" in fig.data[1].hovertemplate
+ results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0])) + fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols") - assert "y = 0" in fig.data[1].hovertemplate + assert "y = 0<br>" in fig.data[1].hovertemplate
results = px.get_trendline_results(fig) params = results["px_fit_results"].iloc[0].params assert np.all(np.isclose(params, [0])) fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols") - assert "y = 0" in fig.data[1].hovertemplate + assert "y = 0 * x + 0<br>" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols") - assert "y = 0 * x + 1" in fig.data[1].hovertemplate + assert "y = 0 * x + 1<br>" in fig.data[1].hovertemplate
fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols") - assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate + assert "y = 0 * x + 1.5<br>" in fig.data[1].hovertemplate
-@pytest.mark.parametrize("mode", ["ols", "lowess"]) -def test_trendline_on_timeseries(mode): +@pytest.mark.parametrize( + "mode,options", + [ + ("ols", None), + ("lowess", None), + ("lowess", dict(frac=0.3)), + ("rolling", dict(window=2)), + ("rolling", dict(window="10d")), + ("expanding", None), + ("ewm", dict(alpha=0.5)), + ], +) +def test_trendline_on_timeseries(mode, options): df = px.data.stocks() with pytest.raises(ValueError) as err_msg: - px.scatter(df, x="date", y="GOOG", trendline=mode) + px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options) assert "Could not convert value of 'x' ('date') into a numeric type." in str( err_msg.value ) df["date"] = pd.to_datetime(df["date"]) df["date"] = df["date"].dt.tz_localize("CET") # force a timezone - fig = px.scatter(df, x="date", y="GOOG", trendline=mode) + fig = px.scatter(df, x="date", y="GOOG", trendline=mode, trendline_options=options) assert len(fig.data) == 2 assert len(fig.data[0].x) == len(fig.data[1].x) assert type(fig.data[0].x[0]) == datetime assert type(fig.data[1].x[0]) == datetime assert np.all(fig.data[0].x == fig.data[1].x) assert str(fig.data[0].x[0]) == str(fig.data[1].x[0]) + + +def test_overall_trendline(): + df = px.data.tips() + fig1 = px.scatter(df, x="total_bill", y="tip", trendline="ols") + assert len(fig1.data) == 2 + assert "trendline" in fig1.data[1].hovertemplate + results1 = px.get_trendline_results(fig1) + params1 = results1["px_fit_results"].iloc[0].params + + fig2 = px.scatter( + df, + x="total_bill", + y="tip", + color="sex", + trendline="ols", + trendline_scope="overall", + ) + assert len(fig2.data) == 3 + assert "trendline" in fig2.data[2].hovertemplate + results2 = px.get_trendline_results(fig2) + params2 = results2["px_fit_results"].iloc[0].params + + assert np.all(np.array_equal(params1, params2)) + + fig3 = px.scatter( + df, + x="total_bill", + y="tip", + facet_row="sex", + trendline="ols", + trendline_scope="overall", +
) + assert len(fig3.data) == 4 + assert "trendline" in fig3.data[3].hovertemplate + results3 = px.get_trendline_results(fig3) + params3 = results3["px_fit_results"].iloc[0].params + + assert np.all(np.array_equal(params1, params3))
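
Note on the one-to-one requirement: the new comment in `_core.py` says trendline functions must return exactly one output value per input point, because the trendline trace reuses the original `x` values. The sketch below illustrates what that means for the `rolling` and `expanding` trendlines. It is plain Python written for this note, not code from this change: `rolling_mean` and `expanding_mean` are hypothetical helper names, and the sketch assumes an integer `window`, input without missing values, and pandas' default `min_periods` behaviour (a full window for `rolling`, a single point for `expanding`).

```python
import math


def rolling_mean(values, window):
    """Mean of the trailing `window` points; NaN until a full window is available.

    Hypothetical helper: mirrors what `Series.rolling(window).mean()` is expected
    to produce for an integer window and input without missing values.
    """
    if window < 1:
        raise ValueError("window must be a positive integer")
    out = []
    running = 0.0
    for i, v in enumerate(values):
        running += v
        if i >= window:
            running -= values[i - window]  # drop the point that left the window
        out.append(running / window if i >= window - 1 else math.nan)
    return out


def expanding_mean(values):
    """Mean of all points seen so far (a cumulative average)."""
    out = []
    running = 0.0
    for count, v in enumerate(values, start=1):
        running += v
        out.append(running / count)
    return out


prices = [1.0, 2.0, 4.0, 8.0, 16.0]
rolled = rolling_mean(prices, window=2)
expanded = expanding_mean(prices)

# One output per input, so a trendline can be drawn against the original x values.
assert len(rolled) == len(expanded) == len(prices)
```

The leading NaN in the rolling result is why a rolling trendline starts a few points later than the scatter points it summarizes, while an expanding trendline starts at the first point.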