Skip to content

Commit 33701df

Browse files
bug_2410: Allowing Ints to be passed for rows/cols and refactored int… (#2546)
* bug_2410: Allowing Ints to be passed for rows/cols and refactored int checks * Update packages/python/plotly/plotly/basedatatypes.py Quick cosmetic update to the docstring Co-authored-by: Emmanuelle Gouillart <[email protected]> * Update packages/python/plotly/plotly/basedatatypes.py Quick cosmetic update to the docstring Co-authored-by: Emmanuelle Gouillart <[email protected]> Co-authored-by: Emmanuelle Gouillart <[email protected]>
1 parent 8be4915 commit 33701df

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

Diff for: packages/python/plotly/_plotly_utils/utils.py

+9
Original file line numberDiff line numberDiff line change
@@ -247,3 +247,12 @@ def key(v):
247247
return tuple(v_parts)
248248

249249
return sorted(vals, key=key, reverse=reverse)
250+
251+
252+
def _get_int_type():
253+
np = get_module("numpy", should_load=False)
254+
if np:
255+
int_type = (int, np.integer)
256+
else:
257+
int_type = (int,)
258+
return int_type

Diff for: packages/python/plotly/plotly/basedatatypes.py

+17-8
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import contextmanager
1010
from copy import deepcopy, copy
1111

12-
from _plotly_utils.utils import _natural_sort_strings
12+
from _plotly_utils.utils import _natural_sort_strings, _get_int_type
1313
from .optional_imports import get_module
1414

1515
# Create Undefined sentinel value
@@ -1560,12 +1560,7 @@ def _validate_rows_cols(name, n, vals):
15601560
if len(vals) != n:
15611561
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
15621562

1563-
try:
1564-
import numpy as np
1565-
1566-
int_type = (int, np.integer)
1567-
except ImportError:
1568-
int_type = (int,)
1563+
int_type = _get_int_type()
15691564

15701565
if [r for r in vals if not isinstance(r, int_type)]:
15711566
BaseFigure._raise_invalid_rows_cols(name=name, n=n, invalid=vals)
@@ -1677,14 +1672,19 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
16771672
- All remaining properties are passed to the constructor
16781673
of the specified trace type.
16791674
1680-
rows : None or list[int] (default None)
1675+
rows : None, list[int], or int (default None)
16811676
List of subplot row indexes (starting from 1) for the traces to be
16821677
added. Only valid if figure was created using
16831678
`plotly.tools.make_subplots`
1679+
If a single integer is passed, all traces will be added to row number
1680+
16841681
cols : None or list[int] (default None)
16851682
List of subplot column indexes (starting from 1) for the traces
16861683
to be added. Only valid if figure was created using
16871684
`plotly.tools.make_subplots`
1685+
If a single integer is passed, all traces will be added to column number
1686+
1687+
16881688
secondary_ys: None or list[boolean] (default None)
16891689
List of secondary_y booleans for traces to be added. See the
16901690
docstring for `add_trace` for more info.
@@ -1723,6 +1723,15 @@ def add_traces(self, data, rows=None, cols=None, secondary_ys=None):
17231723
for ind, new_trace in enumerate(data):
17241724
new_trace._trace_ind = ind + len(self.data)
17251725

1726+
# Allow integers as inputs to subplots
1727+
int_type = _get_int_type()
1728+
1729+
if isinstance(rows, int_type):
1730+
rows = [rows] * len(data)
1731+
1732+
if isinstance(cols, int_type):
1733+
cols = [cols] * len(data)
1734+
17261735
# Validate rows / cols
17271736
n = len(data)
17281737
BaseFigure._validate_rows_cols("rows", n, rows)

Diff for: packages/python/plotly/plotly/tests/test_core/test_figure_messages/test_add_traces.py

+30
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,33 @@ def test_add_traces(self):
6363
{"type": "histogram2dcontour", "line": {"color": "cyan"}},
6464
]
6565
)
66+
67+
68+
class TestAddTracesRowsColsDataTypes(TestCase):
69+
def test_add_traces_with_iterable(self):
70+
import plotly.express as px
71+
72+
df = px.data.tips()
73+
fig = px.scatter(df, x="total_bill", y="tip", color="day")
74+
from plotly.subplots import make_subplots
75+
76+
fig2 = make_subplots(1, 2)
77+
fig2.add_traces(fig.data, rows=[1,] * len(fig.data), cols=[1,] * len(fig.data))
78+
79+
expected_data_length = 4
80+
81+
self.assertEqual(expected_data_length, len(fig2.data))
82+
83+
def test_add_traces_with_integers(self):
84+
import plotly.express as px
85+
86+
df = px.data.tips()
87+
fig = px.scatter(df, x="total_bill", y="tip", color="day")
88+
from plotly.subplots import make_subplots
89+
90+
fig2 = make_subplots(1, 2)
91+
fig2.add_traces(fig.data, rows=1, cols=2)
92+
93+
expected_data_length = 4
94+
95+
self.assertEqual(expected_data_length, len(fig2.data))

Diff for: packages/python/plotly/plotly/tests/test_core/test_utils/test_utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,28 @@ def test_numpy_integer_import(self):
7070
value = get_by_path(fig, data_path)
7171
expected_value = (1,)
7272
self.assertEqual(value, expected_value)
73+
74+
def test_get_numpy_int_type(self):
75+
import numpy as np
76+
from _plotly_utils.utils import _get_int_type
77+
78+
int_type_tuple = _get_int_type()
79+
expected_tuple = (int, np.integer)
80+
81+
self.assertEqual(int_type_tuple, expected_tuple)
82+
83+
84+
class TestNoNumpyIntegerBaseType(TestCase):
85+
def test_no_numpy_int_type(self):
86+
import sys
87+
from _plotly_utils.utils import _get_int_type
88+
from _plotly_utils.optional_imports import get_module
89+
90+
np = get_module("numpy", should_load=False)
91+
if np:
92+
sys.modules.pop("numpy")
93+
94+
int_type_tuple = _get_int_type()
95+
expected_tuple = (int,)
96+
97+
self.assertEqual(int_type_tuple, expected_tuple)

0 commit comments

Comments
 (0)