diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5aca477f453..f6e9fdc3702 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ This project adheres to [Semantic Versioning](http://semver.org/).
- Fixed special cases with `px.sunburst` and `px.treemap` with `path` input ([#2524](https://github.com/plotly/plotly.py/pull/2524))
- Fixed bug in `hover_data` argument of `px` functions, when the column name is changed with labels and `hover_data` is a dictionary setting up a specific format for the hover data ([#2544](https://github.com/plotly/plotly.py/pull/2544)).
+- Made the Plotly Express `trendline` argument more robust and made it work with datetime `x` values ([#2554](https://github.com/plotly/plotly.py/pull/2554))
## [4.8.1] - 2020-05-28
diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py
index 5d6388ae59d..d89794a5a49 100644
--- a/packages/python/plotly/plotly/express/_core.py
+++ b/packages/python/plotly/plotly/express/_core.py
@@ -277,17 +277,35 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
attr_value in ["ols", "lowess"]
and args["x"]
and args["y"]
- and len(trace_data) > 1
+ and len(trace_data[[args["x"], args["y"]]].dropna()) > 1
):
import statsmodels.api as sm
# sorting is bad but trace_specs with "trendline" have no other attrs
sorted_trace_data = trace_data.sort_values(by=args["x"])
- y = sorted_trace_data[args["y"]]
- x = sorted_trace_data[args["x"]]
+ y = sorted_trace_data[args["y"]].values
+ x = sorted_trace_data[args["x"]].values
+ x_is_date = False
if x.dtype.type == np.datetime64:
x = x.astype(int) / 10 ** 9 # convert to unix epoch seconds
+ x_is_date = True
+ elif x.dtype.type == np.object_:
+ try:
+ x = x.astype(np.float64)
+ except ValueError:
+ raise ValueError(
+ "Could not convert value of 'x' ('%s') into a numeric type. "
+ "If 'x' contains stringified dates, please convert to a datetime column."
+ % args["x"]
+ )
+ if y.dtype.type == np.object_:
+ try:
+ y = y.astype(np.float64)
+ except ValueError:
+ raise ValueError(
+ "Could not convert value of 'y' into a numeric type."
+ )
if attr_value == "lowess":
# missing ='drop' is the default value for lowess but not for OLS (None)
@@ -298,25 +316,32 @@ def make_trace_kwargs(args, trace_spec, trace_data, mapping_labels, sizeref):
hover_header = "LOWESS trendline
"
elif attr_value == "ols":
fit_results = sm.OLS(
- y.values, sm.add_constant(x.values), missing="drop"
+ y, sm.add_constant(x), missing="drop"
).fit()
trace_patch["y"] = fit_results.predict()
trace_patch["x"] = x[
np.logical_not(np.logical_or(np.isnan(y), np.isnan(x)))
]
hover_header = "OLS trendline
"
- hover_header += "%s = %g * %s + %g
" % (
- args["y"],
- fit_results.params[1],
- args["x"],
- fit_results.params[0],
- )
+ if len(fit_results.params) == 2:
+ hover_header += "%s = %g * %s + %g
" % (
+ args["y"],
+ fit_results.params[1],
+ args["x"],
+ fit_results.params[0],
+ )
+ else:
+ hover_header += "%s = %g
" % (
+ args["y"],
+ fit_results.params[0],
+ )
hover_header += (
"R2=%f
" % fit_results.rsquared
)
+ if x_is_date:
+ trace_patch["x"] = pd.to_datetime(trace_patch["x"] * 10 ** 9)
mapping_labels[get_label(args, args["x"])] = "%{x}"
mapping_labels[get_label(args, args["y"])] = "%{y} (trend)"
-
elif attr_name.startswith("error"):
error_xy = attr_name[:7]
arr = "arrayminus" if attr_name.endswith("minus") else "array"
diff --git a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py
index 4c151148c12..e908d7dee12 100644
--- a/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py
+++ b/packages/python/plotly/plotly/tests/test_core/test_px/test_trendline.py
@@ -1,14 +1,110 @@
import plotly.express as px
import numpy as np
+import pandas as pd
+import pytest
+from datetime import datetime
-def test_trendline_nan_values():
+@pytest.mark.parametrize("mode", ["ols", "lowess"])
+def test_trendline_results_passthrough(mode):
+    """Check trendline traces and the table from ``get_trendline_results``.
+
+    The figure is expected to hold four traces in which scatter traces
+    (even positions) and trendline traces (odd positions) alternate.  For
+    "ols" the results table must contain two rows, one per country, each
+    carrying a fit result; for "lowess" the table must be empty.
+    """
+    df = px.data.gapminder().query("continent == 'Oceania'")
+    fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
+    assert len(fig.data) == 4
+    # Plain scatter traces must not advertise a trendline in their hover text.
+    for trace in fig["data"][0::2]:
+        assert "trendline" not in trace.hovertemplate
+    for trendline in fig["data"][1::2]:
+        assert "trendline" in trendline.hovertemplate
+        if mode == "ols":
+            assert "R2" in trendline.hovertemplate
+    results = px.get_trendline_results(fig)
+    if mode == "ols":
+        assert len(results) == 2
+        # Check both rows; previously the first row was asserted twice and
+        # the second row was never looked at.
+        assert results["country"].values[0] == "Australia"
+        assert results["country"].values[1] == "New Zealand"
+        au_result = results["px_fit_results"].values[0]
+        assert len(au_result.params) == 2
+    else:
+        assert len(results) == 0
+
+
+@pytest.mark.parametrize("mode", ["ols", "lowess"])
+def test_trendline_enough_values(mode):
+    """Check when a trendline is drawn, depending on usable (x, y) pairs.
+
+    Every scenario must yield two traces (scatter + trendline).  The
+    trendline trace has ``x`` set only when enough complete pairs exist;
+    otherwise its ``x`` stays ``None``.
+    """
+    nan = np.nan
+    # (x, y, expected number of trendline points, or None for "no trendline")
+    scenarios = [
+        ([0, 1], [0, 1], 2),
+        ([0], [0], None),
+        ([0, 1], [0, None], None),
+        ([0, 1], np.array([0, nan]), None),
+        ([0, 1, None], [0, None, 1], None),
+        (np.array([0, 1, nan]), np.array([0, nan, 1]), None),
+        ([0, 1, None, 2], [1, None, 1, 2], 2),
+        (np.array([0, 1, nan, 2]), np.array([1, nan, 1, 2]), 2),
+    ]
+    for xs, ys, expected_points in scenarios:
+        fig = px.scatter(x=xs, y=ys, trendline=mode)
+        assert len(fig.data) == 2
+        trend_x = fig.data[1].x
+        if expected_points is None:
+            assert trend_x is None
+        else:
+            assert len(trend_x) == expected_points
+
+
+@pytest.mark.parametrize("mode", ["ols", "lowess"])
+def test_trendline_nan_values(mode):
+    """Trendlines must ignore rows whose ``pop`` value was replaced by NaN.
+
+    All ``pop`` values before ``start_date`` are blanked out, so every
+    trendline trace has to start at or after ``start_date`` and keep ``x``
+    and ``y`` the same length.
+    """
     df = px.data.gapminder().query("continent == 'Oceania'")
     start_date = 1970
-    df["pop"][df["year"] < start_date] = np.nan
-    modes = ["ols", "lowess"]
-    for mode in modes:
-        fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
-        for trendline in fig["data"][1::2]:
-            assert trendline.x[0] >= start_date
-            assert len(trendline.x) == len(trendline.y)
+    # Chained assignment (df["pop"][mask] = ...) may write to a temporary
+    # copy and leave ``df`` untouched; ``where`` keeps the values from
+    # ``start_date`` onward and turns everything else into NaN.
+    df["pop"] = df["pop"].where(df["year"] >= start_date)
+    fig = px.scatter(df, x="year", y="pop", color="country", trendline=mode)
+    for trendline in fig["data"][1::2]:
+        assert trendline.x[0] >= start_date
+        assert len(trendline.x) == len(trendline.y)
+
+
+def test_no_slope_ols_trendline():
+    """Check the OLS hover equation and fit parameters for tiny inputs.
+
+    Each case lists the x/y data, a snippet that must occur in the
+    trendline hover template, and (where checked) the expected fit
+    parameters.
+    """
+    # (x, y, expected hover snippet, expected params or None to skip)
+    cases = [
+        # Hover reads "y = 1 * x + <intercept>"; only the prefix is checked
+        # because the printed intercept may be a tiny non-zero number.
+        ([0, 1], [0, 1], "y = 1", [0, 1]),
+        ([1, 1], [0, 0], "y = 0", [0]),
+        ([1, 2], [0, 0], "y = 0", None),
+        ([0, 0], [1, 1], "y = 0 * x + 1", None),
+        ([0, 0], [1, 2], "y = 0 * x + 1.5", None),
+    ]
+    for xs, ys, snippet, expected_params in cases:
+        fig = px.scatter(x=xs, y=ys, trendline="ols")
+        assert snippet in fig.data[1].hovertemplate
+        if expected_params is None:
+            continue
+        fit_table = px.get_trendline_results(fig)
+        fitted = fit_table["px_fit_results"].iloc[0].params
+        assert np.isclose(fitted, expected_params).all()
+
+
+@pytest.mark.parametrize("mode", ["ols", "lowess"])
+def test_trendline_on_timeseries(mode):
+    """Check trendlines on a date axis.
+
+    Before the ``date`` column is converted with ``pd.to_datetime`` the call
+    must raise a ``ValueError`` naming the column; after conversion the
+    trendline trace must carry the same datetime ``x`` values as the scatter
+    trace.
+    """
+    df = px.data.stocks()
+
+    with pytest.raises(ValueError) as err_msg:
+        px.scatter(df, x="date", y="GOOG", trendline=mode)
+    # Checked after the ``with`` block on purpose: a statement placed after
+    # the raising call inside the block would never execute.
+    assert "Could not convert value of 'x' ('date') into a numeric type." in str(
+        err_msg.value
+    )
+
+    df["date"] = pd.to_datetime(df["date"])
+    fig = px.scatter(df, x="date", y="GOOG", trendline=mode)
+    assert len(fig.data) == 2
+    assert len(fig.data[0].x) == len(fig.data[1].x)
+    # Exact-type check (not isinstance): subclasses of datetime are rejected.
+    assert type(fig.data[0].x[0]) is datetime
+    assert type(fig.data[1].x[0]) is datetime
+    assert np.all(fig.data[0].x == fig.data[1].x)