Skip to content

Commit c7c765a

Browse files
enable faceting for geo, geojson everywhere possible, text/symbols for scatter_geo
1 parent 326932b commit c7c765a

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

packages/python/plotly/plotly/express/_chart_types.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,11 @@ def choropleth(
940940
geojson=None,
941941
featureidkey=None,
942942
color=None,
943+
facet_row=None,
944+
facet_col=None,
945+
facet_col_wrap=0,
946+
facet_row_spacing=None,
947+
facet_col_spacing=None,
943948
hover_name=None,
944949
hover_data=None,
945950
custom_data=None,
@@ -967,13 +972,7 @@ def choropleth(
967972
return make_figure(
968973
args=locals(),
969974
constructor=go.Choropleth,
970-
trace_patch=dict(
971-
locationmode=locationmode,
972-
featureidkey=featureidkey,
973-
geojson=geojson
974-
if not hasattr(geojson, "__geo_interface__") # for geopandas
975-
else geojson.__geo_interface__,
976-
),
975+
trace_patch=dict(locationmode=locationmode),
977976
)
978977

979978

@@ -986,8 +985,16 @@ def scatter_geo(
986985
lon=None,
987986
locations=None,
988987
locationmode=None,
988+
geojson=None,
989+
featureidkey=None,
989990
color=None,
990991
text=None,
992+
symbol=None,
993+
facet_row=None,
994+
facet_col=None,
995+
facet_col_wrap=0,
996+
facet_row_spacing=None,
997+
facet_col_spacing=None,
991998
hover_name=None,
992999
hover_data=None,
9931000
custom_data=None,
@@ -1001,6 +1008,8 @@ def scatter_geo(
10011008
color_continuous_scale=None,
10021009
range_color=None,
10031010
color_continuous_midpoint=None,
1011+
symbol_sequence=None,
1012+
symbol_map={},
10041013
opacity=None,
10051014
size_max=None,
10061015
projection=None,
@@ -1031,9 +1040,16 @@ def line_geo(
10311040
lon=None,
10321041
locations=None,
10331042
locationmode=None,
1043+
geojson=None,
1044+
featureidkey=None,
10341045
color=None,
10351046
line_dash=None,
10361047
text=None,
1048+
facet_row=None,
1049+
facet_col=None,
1050+
facet_col_wrap=0,
1051+
facet_row_spacing=None,
1052+
facet_col_spacing=None,
10371053
hover_name=None,
10381054
hover_data=None,
10391055
custom_data=None,
@@ -1078,6 +1094,8 @@ def scatter_mapbox(
10781094
hover_data=None,
10791095
custom_data=None,
10801096
size=None,
1097+
geojson=None,
1098+
featureidkey=None,
10811099
animation_frame=None,
10821100
animation_group=None,
10831101
category_orders={},
@@ -1138,16 +1156,7 @@ def choropleth_mapbox(
11381156
In a Mapbox choropleth map, each row of `data_frame` is represented by a
11391157
colored region on a Mapbox map.
11401158
"""
1141-
return make_figure(
1142-
args=locals(),
1143-
constructor=go.Choroplethmapbox,
1144-
trace_patch=dict(
1145-
featureidkey=featureidkey,
1146-
geojson=geojson
1147-
if not hasattr(geojson, "__geo_interface__") # for geopandas
1148-
else geojson.__geo_interface__,
1149-
),
1150-
)
1159+
return make_figure(args=locals(), constructor=go.Choroplethmapbox)
11511160

11521161

11531162
choropleth_mapbox.__doc__ = make_docstring(choropleth_mapbox)

packages/python/plotly/plotly/express/_core.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,6 +1750,14 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17501750
if "line_shape" in args:
17511751
trace_patch["line"] = dict(shape=args["line_shape"])
17521752

1753+
if "geojson" in args:
1754+
trace_patch["featureidkey"] = args["featureidkey"]
1755+
trace_patch["geojson"] = (
1756+
args["geojson"]
1757+
if not hasattr(args["geojson"], "__geo_interface__") # for geopandas
1758+
else args["geojson"].__geo_interface__
1759+
)
1760+
17531761
# Compute marginal attribute
17541762
if "marginal" in args:
17551763
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
@@ -2062,20 +2070,12 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20622070

20632071
def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels):
20642072
# Build subplot specs
2065-
specs = [[{}] * ncols for _ in range(nrows)]
2066-
for frame in frame_list:
2067-
for trace in frame["data"]:
2068-
row0 = trace._subplot_row - 1
2069-
col0 = trace._subplot_col - 1
2070-
if isinstance(trace, go.Splom):
2071-
# Splom not compatible with make_subplots, treat as domain
2072-
specs[row0][col0] = {"type": "domain"}
2073-
else:
2074-
specs[row0][col0] = {"type": trace.type}
2073+
specs = [[dict(type=subplot_type or "domain")] * ncols for _ in range(nrows)]
20752074

20762075
# Default row/column widths uniform
20772076
column_widths = [1.0] * ncols
20782077
row_heights = [1.0] * nrows
2078+
facet_col_wrap = args.get("facet_col_wrap", 0)
20792079

20802080
# Build column_widths/row_heights
20812081
if subplot_type == "xy":
@@ -2087,7 +2087,7 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
20872087

20882088
row_heights = [main_size] * (nrows - 1) + [1 - main_size]
20892089
vertical_spacing = 0.01
2090-
elif args.get("facet_col_wrap", 0):
2090+
elif facet_col_wrap:
20912091
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
20922092
else:
20932093
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
@@ -2108,10 +2108,12 @@ def init_figure(args, subplot_type, frame_list, nrows, ncols, col_labels, row_la
21082108
#
21092109
# We can customize subplot spacing per type once we enable faceting
21102110
# for all plot types
2111-
vertical_spacing = 0.1
2112-
horizontal_spacing = 0.1
2111+
if facet_col_wrap:
2112+
vertical_spacing = args.get("facet_row_spacing", None) or 0.07
2113+
else:
2114+
vertical_spacing = args.get("facet_row_spacing", None) or 0.03
2115+
horizontal_spacing = args.get("facet_col_spacing", None) or 0.02
21132116

2114-
facet_col_wrap = args.get("facet_col_wrap", 0)
21152117
if facet_col_wrap:
21162118
subplot_labels = [None] * nrows * ncols
21172119
while len(col_labels) < nrows * ncols:

0 commit comments

Comments
 (0)