diff --git a/CHANGELOG.md b/CHANGELOG.md index 5aca477f453..f6e9fdc3702 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/). - Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524)) - Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)). +- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554)) ## [4.8.1] - 2020-05-28 diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 5d6388ae59d..d89794a5a49 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -277,17 +277,35 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): attr_value in ["ols", "lowess"] and args["x"] and args["y"] - and len(trace_data) > 1 + and len(trace_data[[args["x"], args["y"]]].dropna()) > 1 ): import statsmodels.api as sm # sorting is bad but trace_specs with "trendline" have no other attrs sorted_trace_data = trace_data.sort_values(by=args["x"]) - y = sorted_trace_data[args["y"]] - x = sorted_trace_data[args["x"]] + y = sorted_trace_data[args["y"]].values + x = sorted_trace_data[args["x"]].values + x_is_date = False if x.dtype.type == np.datetime64: x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds + x_is_date = True + elif x.dtype.type == np.object_: + try: + x = x.astype(np.float64) + except ValueError: + raise ValueError( + "Could not convert value of 'x' ('%s') into a numeric type. " + "If 'x' contains stringified dates, please convert to a datetime column." + % args["x"] + ) + if y.dtype.type == np.object_: + try: + y = y.astype(np.float64) + except ValueError: + raise ValueError( + "Could not convert value of 'y' into a numeric type." + ) if attr_value == "lowess": # missing ='drop' is the default value for lowess but not for OLS (None) @@ -298,25 +316,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref): hover_header = "LOWESS trendline

" elif attr_value == "ols": fit_results = sm.OLS( - y.values, sm.add_constant(x.values), missing="drop" + y, sm.add_constant(x), missing="drop" ).fit() trace_patch["y"] = fit_results.predict() trace_patch["x"] = x[ np.logical_not(np.logical_or(np.isnan(y), np.isnan(x))) ] hover_header = "OLS trendline
" - hover_header += "%s = %g * %s + %g
" % ( - args["y"], - fit_results.params[1], - args["x"], - fit_results.params[0], - ) + if len(fit_results.params) == 2: + hover_header += "%s = %g * %s + %g
" % ( + args["y"], + fit_results.params[1], + args["x"], + fit_results.params[0], + ) + else: + hover_header += "%s = %g
" % ( + args["y"], + fit_results.params[0], + ) hover_header += ( "R2=%f

" % fit_results.rsquared ) + if x_is_date: + trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9) mapping_labels[get_label(args, args["x"])] = "%{x}" mapping_labels[get_label(args, args["y"])] = "%{y} (trend)" - elif attr_name.startswith("error"): error_xy = attr_name[:7] arr = "arrayminus" if attr_name.endswith("minus") else "array" diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py index 4c151148c12..e908d7dee12 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py +++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py @@ -1,14 +1,110 @@ import plotly.express as px import numpy as np +import pandas as pd +import pytest +from datetime import datetime -def test_trendline_nan_values(): +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_results_passthrough(mode): + df = px.data.gapminder().query("continent == 'Oceania'") + fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + assert len(fig.data) == 4 + for trace in fig["data"][0::2]: + assert "trendline" not in trace.hovertemplate + for trendline in fig["data"][1::2]: + assert "trendline" in trendline.hovertemplate + if mode == "ols": + assert "R2" in trendline.hovertemplate + results = px.get_trendline_results(fig) + if mode == "ols": + assert len(results) == 2 + assert results["country"].values[0] == "Australia" + assert results["country"].values[0] == "Australia" + au_result = results["px_fit_results"].values[0] + assert len(au_result.params) == 2 + else: + assert len(results) == 0 + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_enough_values(mode): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 + fig = px.scatter(x=[0], y=[0], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1], y=[0, None], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1], y=np.array([0, np.nan]), trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1, None], y=[0, None, 1], trendline=mode) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter( + x=np.array([0, 1, np.nan]), y=np.array([0, np.nan, 1]), trendline=mode + ) + assert len(fig.data) == 2 + assert fig.data[1].x is None + fig = px.scatter(x=[0, 1, None, 2], y=[1, None, 1, 2], trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 + fig = px.scatter( + x=np.array([0, 1, np.nan, 2]), y=np.array([1, np.nan, 1, 2]), trendline=mode + ) + assert len(fig.data) == 2 + assert len(fig.data[1].x) == 2 + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_nan_values(mode): df = px.data.gapminder().query("continent == 'Oceania'") start_date = 1970 df["pop"][df["year"] < start_date] = np.nan - modes = ["ols", "lowess"] - for mode in modes: - fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) - for trendline in fig["data"][1::2]: - assert trendline.x[0] >= start_date - assert len(trendline.x) == len(trendline.y) + fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode) + for trendline in fig["data"][1::2]: + assert trendline.x[0] >= start_date + assert len(trendline.x) == len(trendline.y) + + +def test_no_slope_ols_trendline(): + fig = px.scatter(x=[0, 1], y=[0, 1], trendline="ols") + assert "y = 1" in fig.data[1].hovertemplate # then + x*(some small number) + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0, 1])) + + fig = px.scatter(x=[1, 1], y=[0, 0], trendline="ols") + assert "y = 0" in fig.data[1].hovertemplate + results = px.get_trendline_results(fig) + params = results["px_fit_results"].iloc[0].params + assert np.all(np.isclose(params, [0])) + + fig = px.scatter(x=[1, 2], y=[0, 0], trendline="ols") + assert "y = 0" in fig.data[1].hovertemplate + fig = px.scatter(x=[0, 0], y=[1, 1], trendline="ols") + assert "y = 0 * x + 1" in fig.data[1].hovertemplate + fig = px.scatter(x=[0, 0], y=[1, 2], trendline="ols") + assert "y = 0 * x + 1.5" in fig.data[1].hovertemplate + + +@pytest.mark.parametrize("mode", ["ols", "lowess"]) +def test_trendline_on_timeseries(mode): + df = px.data.stocks() + + with pytest.raises(ValueError) as err_msg: + px.scatter(df, x="date", y="GOOG", trendline=mode) + assert "Could not convert value of 'x' ('date') into a numeric type." in str( + err_msg.value + ) + + df["date"] = pd.to_datetime(df["date"]) + fig = px.scatter(df, x="date", y="GOOG", trendline=mode) + assert len(fig.data) == 2 + assert len(fig.data[0].x) == len(fig.data[1].x) + assert type(fig.data[0].x[0]) == datetime + assert type(fig.data[1].x[0]) == datetime + assert np.all(fig.data[0].x == fig.data[1].x)