Skip to content

Commit a037099

Browse files
committed
hexbin in ff
1 parent b423876 commit a037099

File tree

4 files changed

+113
-54
lines changed

4 files changed

+113
-54
lines changed

packages/python/plotly/plotly/express/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
choropleth_mapbox,
4848
density_mapbox,
4949
)
50-
from ._hexbin_mapbox import hexbin_mapbox
5150

5251

5352
from ._core import ( # noqa: F401
@@ -101,5 +100,4 @@
101100
"IdentityMap",
102101
"Constant",
103102
"Range",
104-
"hexbin_mapbox",
105103
]

packages/python/plotly/plotly/express/_doc.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -514,12 +514,6 @@
514514
"Sets the number of rendered sectors from any given `level`. Set `maxdepth` to -1 to render all the"
515515
"levels in the hierarchy.",
516516
],
517-
agg_func=[
518-
"function",
519-
"Numpy array aggregator, it must take as input a 1D array",
520-
"and output a scalar value.",
521-
],
522-
gridsize=["int", "Number of hexagons (horizontally) to be created",],
523517
)
524518

525519

packages/python/plotly/plotly/figure_factory/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,13 @@
2929

3030
if optional_imports.get_module("pandas") is not None:
3131
from plotly.figure_factory._county_choropleth import create_choropleth
32+
from plotly.figure_factory._hexbin_mapbox import create_hexbin_mapbox
3233
else:
3334

3435
def create_choropleth(*args, **kwargs):
3536
raise ImportError("Please install pandas to use `create_choropleth`")
37+
def create_hexbin_mapbox(*args, **kwargs):
38+
raise ImportError("Please install pandas to use `create_hexbin_mapbox`")
3639

3740

3841
if optional_imports.get_module("skimage") is not None:
@@ -53,6 +56,7 @@ def create_ternary_contour(*args, **kwargs):
5356
"create_distplot",
5457
"create_facet_grid",
5558
"create_gantt",
59+
"create_hexbin_mapbox",
5660
"create_ohlc",
5761
"create_quiver",
5862
"create_scatterplotmatrix",

packages/python/plotly/plotly/express/_hexbin_mapbox.py renamed to packages/python/plotly/plotly/figure_factory/_hexbin_mapbox.py

Lines changed: 109 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from ._core import make_figure, build_dataframe
2-
from ._doc import make_docstring, docs
3-
from ._chart_types import choropleth_mapbox
4-
import plotly.graph_objs as go
1+
from plotly.express._core import build_dataframe
2+
from plotly.express._doc import make_docstring
3+
from plotly.express._chart_types import choropleth_mapbox
54
import numpy as np
65
import pandas as pd
7-
import re
86

97

108
def _project_latlon_to_wgs84(lat, lon):
119
"""
12-
Projects lat and lon to WGS84 to get regular hexagons on a mapbox map
10+
Projects lat and lon to WGS84, used to get regular hexagons on a mapbox map
1311
"""
1412
x = lon * np.pi / 180
1513
y = np.arctanh(np.sin(lat * np.pi / 180))
@@ -18,7 +16,7 @@ def _project_latlon_to_wgs84(lat, lon):
1816

1917
def _project_wgs84_to_latlon(x, y):
2018
"""
21-
Projects lat and lon to WGS84 to get regular hexagons on a mapbox map
19+
Projects WGS84 to lat and lon, used to get regular hexagons on a mapbox map
2220
"""
2321
lon = x * 180 / np.pi
2422
lat = (2 * np.arctan(np.exp(y)) - np.pi / 2) * 180 / np.pi
@@ -55,16 +53,8 @@ def zoom(mapPx, worldPx, fraction):
5553

5654
return min(latZoom, lngZoom, ZOOM_MAX)
5755

58-
5956
def _compute_hexbin(
60-
lat=None,
61-
lon=None,
62-
lat_range=None,
63-
lon_range=None,
64-
color=None,
65-
nx=None,
66-
agg_func=None,
67-
min_count=None,
57+
x, y, x_range, y_range, color, nx, agg_func, min_count
6858
):
6959
"""
7060
Computes the aggregation at hexagonal bin level.
@@ -73,38 +63,36 @@ def _compute_hexbin(
7363
7464
Parameters
7565
----------
76-
lat : np.ndarray
77-
Array of latitudes
78-
lon : np.ndarray
79-
Array of longitudes
80-
lat_range : np.ndarray
81-
Min and max latitudes
82-
lon_range : np.ndarray
83-
Min and max longitudes
66+
x : np.ndarray
67+
Array of x values (shape N)
68+
y : np.ndarray
69+
Array of y values (shape N)
70+
x_range : np.ndarray
71+
Min and max x (shape 2)
72+
y_range : np.ndarray
73+
Min and max y (shape 2)
8474
color : np.ndarray
85-
Metric to aggregate at hexagon level
75+
Metric to aggregate at hexagon level (shape N)
8676
nx : int
8777
Number of hexagons horizontally
8878
agg_func : function
8979
Numpy compatible aggregator, this function must take a one-dimensional
9080
np.ndarray as input and output a scalar
91-
min_count : float
92-
Minimum value for which to display the aggregate
81+
min_count : int
82+
Minimum number of points in the hexagon for the hexagon to be displayed
9383
9484
Returns
9585
-------
86+
np.ndarray
87+
X coordinates of each hexagon (shape M x 6)
88+
np.ndarray
89+
Y coordinates of each hexagon (shape M x 6)
90+
np.ndarray
91+
Centers of the hexagons (shape M x 2)
92+
np.ndarray
93+
Aggregated value in each hexagon (shape M)
9694
9795
"""
98-
# Project to WGS 84
99-
x, y = _project_latlon_to_wgs84(lat, lon)
100-
101-
if lat_range is None:
102-
lat_range = np.array([lat.min(), lat.max()])
103-
if lon_range is None:
104-
lon_range = np.array([lon.min(), lon.max()])
105-
106-
x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
107-
10896
xmin = x_range.min()
10997
xmax = x_range.max()
11098
ymin = y_range.min()
@@ -224,6 +212,69 @@ def _compute_hexbin(
224212
hxs = np.array([hx] * m) * nx + np.vstack(centers[:, 0])
225213
hys = np.array([hy] * m) * ny + np.vstack(centers[:, 1])
226214

215+
return hxs, hys, centers, agreggated_value
216+
217+
def _compute_wgs84_hexbin(
218+
lat=None,
219+
lon=None,
220+
lat_range=None,
221+
lon_range=None,
222+
color=None,
223+
nx=None,
224+
agg_func=None,
225+
min_count=None,
226+
):
227+
"""
228+
Computes the lat-lon aggregation at hexagonal bin level.
229+
Latitude and longitude need to be projected to WGS84 before aggregating
230+
in order to display regular hexagons on the map.
231+
232+
Parameters
233+
----------
234+
lat : np.ndarray
235+
Array of latitudes (shape N)
236+
lon : np.ndarray
237+
Array of longitudes (shape N)
238+
lat_range : np.ndarray
239+
Min and max latitudes (shape 2)
240+
lon_range : np.ndarray
241+
Min and max longitudes (shape 2)
242+
color : np.ndarray
243+
Metric to aggregate at hexagon level (shape N)
244+
nx : int
245+
Number of hexagons horizontally
246+
agg_func : function
247+
Numpy compatible aggregator, this function must take a one-dimensional
248+
np.ndarray as input and output a scalar
249+
min_count : int
250+
Minimum number of points in the hexagon for the hexagon to be displayed
251+
252+
Returns
253+
-------
254+
np.ndarray
255+
Lat coordinates of each hexagon (shape M x 6)
256+
np.ndarray
257+
Lon coordinates of each hexagon (shape M x 6)
258+
pd.Series
259+
Unique id for each hexagon, to be used in the geojson data (shape M)
260+
np.ndarray
261+
Aggregated value in each hexagon (shape M)
262+
263+
"""
264+
# Project to WGS 84
265+
x, y = _project_latlon_to_wgs84(lat, lon)
266+
267+
if lat_range is None:
268+
lat_range = np.array([lat.min(), lat.max()])
269+
if lon_range is None:
270+
lon_range = np.array([lon.min(), lon.max()])
271+
272+
x_range, y_range = _project_latlon_to_wgs84(lat_range, lon_range)
273+
274+
hxs, hys, centers, agreggated_value = _compute_hexbin(
275+
x, y, x_range, y_range, color, nx, agg_func, min_count
276+
)
277+
227278
# Convert back to lat-lon
228279
hexagons_lats, hexagons_lons = _project_wgs84_to_latlon(hxs, hys)
229280

@@ -237,7 +288,7 @@ def _compute_hexbin(
237288
def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
238289
"""
239290
Creates a geojson of hexagonal features based on the outputs of
240-
_compute_hexbin
291+
_compute_wgs84_hexbin
241292
"""
242293
features = []
243294
if ids is None:
@@ -255,12 +306,12 @@ def _hexagons_to_geojson(hexagons_lats, hexagons_lons, ids=None):
255306
return dict(type="FeatureCollection", features=features)
256307

257308

258-
def hexbin_mapbox(
309+
def create_hexbin_mapbox(
259310
data_frame=None,
260311
lat=None,
261312
lon=None,
262313
color=None,
263-
gridsize=5,
314+
nx_hexagon=5,
264315
agg_func=None,
265316
animation_frame=None,
266317
color_discrete_sequence=None,
@@ -278,6 +329,9 @@ def hexbin_mapbox(
278329
width=None,
279330
height=None,
280331
):
332+
"""
333+
Returns a figure aggregating scattered points into connected hexagons
334+
"""
281335
args = build_dataframe(args=locals(), constructor=None)
282336

283337
if agg_func is None:
@@ -286,13 +340,13 @@ def hexbin_mapbox(
286340
lat_range = args["data_frame"][args["lat"]].agg(["min", "max"]).values
287341
lon_range = args["data_frame"][args["lon"]].agg(["min", "max"]).values
288342

289-
hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_hexbin(
343+
hexagons_lats, hexagons_lons, hexagons_ids, count = _compute_wgs84_hexbin(
290344
lat=args["data_frame"][args["lat"]].values,
291345
lon=args["data_frame"][args["lon"]].values,
292346
lat_range=lat_range,
293347
lon_range=lon_range,
294348
color=None,
295-
nx=gridsize,
349+
nx=nx_hexagon,
296350
agg_func=agg_func,
297351
min_count=-np.inf,
298352
)
@@ -323,13 +377,13 @@ def hexbin_mapbox(
323377
agg_data_frame_list = []
324378
for frame, index in groups.items():
325379
df = args["data_frame"].loc[index]
326-
_, _, hexagons_ids, aggregated_value = _compute_hexbin(
380+
_, _, hexagons_ids, aggregated_value = _compute_wgs84_hexbin(
327381
lat=df[args["lat"]].values,
328382
lon=df[args["lon"]].values,
329383
lat_range=lat_range,
330384
lon_range=lon_range,
331385
color=df[args["color"]].values if args["color"] else None,
332-
nx=gridsize,
386+
nx=nx_hexagon,
333387
agg_func=agg_func,
334388
min_count=None,
335389
)
@@ -372,5 +426,14 @@ def hexbin_mapbox(
372426
height=height,
373427
)
374428

375-
376-
hexbin_mapbox.__doc__ = make_docstring(hexbin_mapbox)
429+
create_hexbin_mapbox.__doc__ = make_docstring(
430+
create_hexbin_mapbox,
431+
override_dict=dict(
432+
nx_hexagon=["int", "Number of hexagons (horizontally) to be created"],
433+
agg_func=[
434+
"function",
435+
"Numpy array aggregator, it must take as input a 1D array",
436+
"and output a scalar value.",
437+
],
438+
)
439+
)

0 commit comments

Comments
 (0)