Skip to content

Commit 8c75004

Browse files
authored
Merge pull request #4470 from plotly/pass-b64
Use plotly.js `base64` API to store and pass typed arrays declared by numpy, pandas, etc.
2 parents 7c24d87 + f481af7 commit 8c75004

File tree

6 files changed

+200
-36
lines changed

6 files changed

+200
-36
lines changed

Diff for: CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
All notable changes to this project will be documented in this file.
33
This project adheres to [Semantic Versioning](http://semver.org/).
44

5+
### Updated
6+
7+
- Updated plotly.py to use base64 encoding of arrays in plotly JSON to improve performance.
8+
59
## [5.24.1] - 2024-09-12
610

711
### Updated

Diff for: packages/python/plotly/_plotly_utils/utils.py

+105-1
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,115 @@
1+
import base64
12
import decimal
23
import json as _json
34
import sys
45
import re
56
from functools import reduce
67

78
from _plotly_utils.optional_imports import get_module
8-
from _plotly_utils.basevalidators import ImageUriValidator
9+
from _plotly_utils.basevalidators import (
10+
ImageUriValidator,
11+
copy_to_readonly_numpy_array,
12+
is_homogeneous_array,
13+
)
14+
15+
16+
# Inclusive bounds of the signed integer widths that 64-bit integer arrays
# may be narrowed to before being encoded.
int8min, int8max = -(2**7), 2**7 - 1
int16min, int16max = -(2**15), 2**15 - 1
int32min, int32max = -(2**31), 2**31 - 1

# Upper bounds of the matching unsigned widths (the lower bound is always 0).
uint8max = 2**8 - 1
uint16max = 2**16 - 1
uint32max = 2**32 - 1

# Maps a numpy dtype name to the short type code used in the "dtype" field of
# a plotly.js typed array spec. 64-bit integer dtypes are deliberately absent.
plotlyjsShortTypes = dict(
    int8="i1",
    uint8="u1",
    int16="i2",
    uint16="u2",
    int32="i4",
    uint32="u4",
    float32="f4",
    float64="f8",
)
37+
38+
39+
def to_typed_array_spec(v):
    """
    Convert numpy array to plotly.js typed array spec
    If not possible return the value unconverted

    Parameters
    ----------
    v
        Value to convert. It is first passed through
        copy_to_readonly_numpy_array, so whatever is returned unconverted is
        the result of that call rather than the caller's object.

    Returns
    -------
    dict or object
        On success, a dict with the keys:
          - "dtype": the plotly.js short type code from plotlyjsShortTypes
          - "bdata": the array buffer, base64-encoded, as an ASCII str
          - "shape": comma-separated dimensions, present only when the
            array has more than one dimension
        Otherwise the unconverted value. That happens when get_module does
        not return numpy, the value is not an ndarray, the array is empty,
        a 64-bit integer array holds values outside the 32-bit range, or the
        dtype has no entry in plotlyjsShortTypes.
    """
    v = copy_to_readonly_numpy_array(v)

    # NOTE(review): should_load=False presumably means "only return numpy if
    # it is already imported" - confirm against optional_imports.get_module.
    np = get_module("numpy", should_load=False)

    # Empty arrays are returned as is: max()/min() below raise ValueError on
    # zero-size arrays, and there is no payload worth encoding anyway.
    if not np or not isinstance(v, np.ndarray) or v.size == 0:
        return v

    dtype = str(v.dtype)

    # convert default Big Ints until we could support them in plotly.js
    if dtype == "int64":
        largest = v.max()
        smallest = v.min()
        if largest <= int8max and smallest >= int8min:
            v = v.astype("int8")
        elif largest <= int16max and smallest >= int16min:
            v = v.astype("int16")
        elif largest <= int32max and smallest >= int32min:
            v = v.astype("int32")
        else:
            # Does not fit in 32 bits: leave it for the regular encoder.
            return v

    elif dtype == "uint64":
        # Unsigned values are never negative, so only the maximum matters.
        largest = v.max()
        if largest <= uint8max:
            v = v.astype("uint8")
        elif largest <= uint16max:
            v = v.astype("uint16")
        elif largest <= uint32max:
            v = v.astype("uint32")
        else:
            return v

    # Re-read the dtype: the array may have been narrowed above.
    dtype = str(v.dtype)

    if dtype in plotlyjsShortTypes:
        arrObj = {
            "dtype": plotlyjsShortTypes[dtype],
            "bdata": base64.b64encode(v).decode("ascii"),
        }

        if v.ndim > 1:
            # e.g. shape (2, 3) is written as "2, 3"
            arrObj["shape"] = ", ".join(str(dim) for dim in v.shape)

        return arrObj

    return v
91+
92+
93+
def is_skipped_key(key):
    """
    Return whether the key is skipped for conversion to the typed array spec

    Parameters
    ----------
    key
        A dict key encountered while walking a figure dict.

    Returns
    -------
    bool
        True when key equals one of "geojson", "layer", "layers" or "range".
    """
    return key in ("geojson", "layer", "layers", "range")
99+
100+
101+
def convert_to_base64(obj):
    """
    Recursively replace homogeneous arrays in obj with typed array specs

    Nested dicts, lists and tuples are walked. Replacement happens in place
    and only for dict values: a value for which is_homogeneous_array is true
    is rebound to the result of to_typed_array_spec. Everything stored under
    a key for which is_skipped_key is true is left untouched, including its
    whole subtree. Items that are not a dict, list or tuple are ignored.

    Parameters
    ----------
    obj
        The (possibly nested) structure to convert. Dicts inside it are
        mutated.

    Returns
    -------
    None
    """
    if isinstance(obj, dict):
        for key, value in obj.items():
            if is_skipped_key(key):
                continue
            if is_homogeneous_array(value):
                # Rebinding an existing key does not resize the dict, so it
                # is safe to do while iterating over its items.
                obj[key] = to_typed_array_spec(value)
            else:
                convert_to_base64(value)
    elif isinstance(obj, (list, tuple)):
        for value in obj:
            convert_to_base64(value)
9113

10114

11115
def cumsum(x):

Diff for: packages/python/plotly/plotly/basedatatypes.py

+4
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
display_string_positions,
1616
chomp_empty_strings,
1717
find_closest_string,
18+
convert_to_base64,
1819
)
1920
from _plotly_utils.exceptions import PlotlyKeyError
2021
from .optional_imports import get_module
@@ -3310,6 +3311,9 @@ def to_dict(self):
33103311
if frames:
33113312
res["frames"] = frames
33123313

3314+
# Add base64 conversion before sending to the front-end
3315+
convert_to_base64(res)
3316+
33133317
return res
33143318

33153319
def to_plotly_json(self):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
from unittest import TestCase
3+
import numpy as np
4+
from plotly.tests.test_optional.optional_utils import NumpyTestUtilsMixin
5+
import plotly.graph_objs as go
6+
7+
8+
class TestShouldNotUseBase64InUnsupportedKeys(NumpyTestUtilsMixin, TestCase):
    """
    Check that numpy arrays placed under "geojson", "layers" and "range"
    come back as plain lists after a JSON round trip of the figure.
    """

    def test_np_geojson(self):
        """Polygon coordinates nested inside a trace's geojson stay a list."""
        normal_coordinates = [[[-87, 35], [-87, 30], [-85, 30], [-85, 35]]]
        numpy_coordinates = np.array(normal_coordinates)

        geometry = {"type": "Polygon", "coordinates": numpy_coordinates}
        trace = {
            "type": "choropleth",
            "locations": ["AL"],
            "featureidkey": "properties.id",
            "z": np.array([10]),
            "geojson": {
                "type": "Feature",
                "properties": {"id": "AL"},
                "geometry": geometry,
            },
        }

        fig = go.Figure(data=[trace])

        decoded = json.loads(fig.to_json())
        decoded_geometry = decoded["data"][0]["geojson"]["geometry"]
        assert decoded_geometry["coordinates"] == normal_coordinates

    def test_np_layers(self):
        """Arrays inside layout.mapbox.layers stay untouched."""
        feature = {
            "type": "Feature",
            "geometry": {
                "type": "LineString",
                "coordinates": np.array([[0.25, 52], [0.75, 50]]),
            },
        }
        layer = {
            "sourcetype": "geojson",
            "type": "line",
            "line": {"dash": np.array([2.5, 1])},
            "source": {"type": "FeatureCollection", "features": [feature]},
        }
        layout = {
            "mapbox": {
                "layers": [layer],
                "center": {"lon": 0.5, "lat": 51},
            },
        }
        data = [{"type": "scattermapbox"}]

        fig = go.Figure(data=data, layout=layout)

        # The figure object itself still holds the dash values.
        dash = fig.layout["mapbox"]["layers"][0]["line"]["dash"]
        assert (dash == (2.5, 1)).all()

        decoded = json.loads(fig.to_json())
        decoded_source = decoded["layout"]["mapbox"]["layers"][0]["source"]
        decoded_geometry = decoded_source["features"][0]["geometry"]
        assert decoded_geometry["coordinates"] == [[0.25, 52], [0.75, 50]]

    def test_np_range(self):
        """An axis range given as a numpy array serializes to a list."""
        layout = {"xaxis": {"range": np.array([0, 1])}}

        fig = go.Figure(data=[{"type": "scatter"}], layout=layout)

        decoded = json.loads(fig.to_json())
        assert decoded["layout"]["xaxis"]["range"] == [0, 1]

Diff for: packages/python/plotly/plotly/tests/test_optional/test_px/test_px_functions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,15 @@ def _compare_figures(go_trace, px_fig):
2525
def test_pie_like_px():
2626
# Pie
2727
labels = ["Oxygen", "Hydrogen", "Carbon_Dioxide", "Nitrogen"]
28-
values = [4500, 2500, 1053, 500]
28+
values = np.array([4500, 2500, 1053, 500])
2929

3030
fig = px.pie(names=labels, values=values)
3131
trace = go.Pie(labels=labels, values=values)
3232
_compare_figures(trace, fig)
3333

3434
labels = ["Eve", "Cain", "Seth", "Enos", "Noam", "Abel", "Awan", "Enoch", "Azura"]
3535
parents = ["", "Eve", "Eve", "Seth", "Seth", "Eve", "Eve", "Awan", "Eve"]
36-
values = [10, 14, 12, 10, 2, 6, 6, 4, 4]
36+
values = np.array([10, 14, 12, 10, 2, 6, 6, 4, 4])
3737
# Sunburst
3838
fig = px.sunburst(names=labels, parents=parents, values=values)
3939
trace = go.Sunburst(labels=labels, parents=parents, values=values)
@@ -45,7 +45,7 @@ def test_pie_like_px():
4545

4646
# Funnel
4747
x = ["A", "B", "C"]
48-
y = [3, 2, 1]
48+
y = np.array([3, 2, 1])
4949
fig = px.funnel(y=y, x=x)
5050
trace = go.Funnel(y=y, x=x)
5151
_compare_figures(trace, fig)

Diff for: packages/python/plotly/plotly/tests/test_optional/test_utils/test_utils.py

-32
Original file line numberDiff line numberDiff line change
@@ -372,38 +372,6 @@ def test_invalid_encode_exception(self):
372372
with self.assertRaises(TypeError):
373373
_json.dumps({"a": {1}}, cls=utils.PlotlyJSONEncoder)
374374

375-
def test_fast_track_finite_arrays(self):
376-
# if NaN or Infinity is found in the json dump
377-
# of a figure, it is decoded and re-encoded to replace these values
378-
# with null. This test checks that NaN and Infinity values are
379-
# indeed converted to null, and that the encoding of figures
380-
# without inf or nan is faster (because we can avoid decoding
381-
# and reencoding).
382-
z = np.random.randn(100, 100)
383-
x = np.arange(100.0)
384-
fig_1 = go.Figure(go.Heatmap(z=z, x=x))
385-
t1 = time()
386-
json_str_1 = _json.dumps(fig_1, cls=utils.PlotlyJSONEncoder)
387-
t2 = time()
388-
x[0] = np.nan
389-
x[1] = np.inf
390-
fig_2 = go.Figure(go.Heatmap(z=z, x=x))
391-
t3 = time()
392-
json_str_2 = _json.dumps(fig_2, cls=utils.PlotlyJSONEncoder)
393-
t4 = time()
394-
assert t2 - t1 < t4 - t3
395-
assert "null" in json_str_2
396-
assert "NaN" not in json_str_2
397-
assert "Infinity" not in json_str_2
398-
x = np.arange(100.0)
399-
fig_3 = go.Figure(go.Heatmap(z=z, x=x))
400-
fig_3.update_layout(title_text="Infinity")
401-
t5 = time()
402-
json_str_3 = _json.dumps(fig_3, cls=utils.PlotlyJSONEncoder)
403-
t6 = time()
404-
assert t2 - t1 < t6 - t5
405-
assert "Infinity" in json_str_3
406-
407375

408376
class TestNumpyIntegerBaseType(TestCase):
409377
def test_numpy_integer_import(self):

0 commit comments

Comments
 (0)