diff --git a/packages/python/plotly/plotly/express/_chart_types.py b/packages/python/plotly/plotly/express/_chart_types.py index 9be0a02c035..334f9ac6f35 100644 --- a/packages/python/plotly/plotly/express/_chart_types.py +++ b/packages/python/plotly/plotly/express/_chart_types.py @@ -840,6 +840,8 @@ def choropleth( lon=None, locations=None, locationmode=None, + geojson=None, + featureidkey=None, color=None, hover_name=None, hover_data=None, @@ -848,6 +850,8 @@ def choropleth( animation_group=None, category_orders={}, labels={}, + color_discrete_sequence=None, + color_discrete_map={}, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -866,7 +870,13 @@ def choropleth( return make_figure( args=locals(), constructor=go.Choropleth, - trace_patch=dict(locationmode=locationmode), + trace_patch=dict( + locationmode=locationmode, + featureidkey=featureidkey, + geojson=geojson + if not hasattr(geojson, "__geo_interface__") # for geopandas + else geojson.__geo_interface__, + ), ) @@ -1003,6 +1013,7 @@ def scatter_mapbox( def choropleth_mapbox( data_frame=None, geojson=None, + featureidkey=None, locations=None, color=None, hover_name=None, @@ -1012,6 +1023,8 @@ def choropleth_mapbox( animation_group=None, category_orders={}, labels={}, + color_discrete_sequence=None, + color_discrete_map={}, color_continuous_scale=None, range_color=None, color_continuous_midpoint=None, @@ -1032,9 +1045,10 @@ def choropleth_mapbox( args=locals(), constructor=go.Choroplethmapbox, trace_patch=dict( + featureidkey=featureidkey, geojson=geojson - if not hasattr(geojson, "__geo_interface__") - else geojson.__geo_interface__ + if not hasattr(geojson, "__geo_interface__") # for geopandas + else geojson.__geo_interface__, ), ) diff --git a/packages/python/plotly/plotly/express/_core.py b/packages/python/plotly/plotly/express/_core.py index 1d919123c65..7bcd461d7e1 100644 --- a/packages/python/plotly/plotly/express/_core.py +++ b/packages/python/plotly/plotly/express/_core.py @@ -1250,8 +1250,9 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): if val not in m.val_map: m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)] try: - m.updater(trace, m.val_map[val]) + m.updater(trace, m.val_map[val]) # covers most cases except ValueError: + # this catches some odd cases like marginals if ( trace_spec != trace_specs[0] and trace_spec.constructor in [go.Violin, go.Box, go.Histogram] @@ -1264,6 +1265,16 @@ def make_figure(args, constructor, trace_patch={}, layout_patch={}): and m.variable == "color" ): trace.update(marker=dict(color=m.val_map[val])) + elif ( + trace_spec.constructor in [go.Choropleth, go.Choroplethmapbox] + and m.variable == "color" + ): + trace.update( + z=[1] * len(group), + colorscale=[m.val_map[val]] * 2, + showscale=False, + showlegend=True, + ) else: raise diff --git a/packages/python/plotly/plotly/express/_doc.py b/packages/python/plotly/plotly/express/_doc.py index 3ee3df4e5a7..676f508ef6d 100644 --- a/packages/python/plotly/plotly/express/_doc.py +++ b/packages/python/plotly/plotly/express/_doc.py @@ -471,6 +471,11 @@ "GeoJSON-formatted dict", "Must contain a Polygon feature collection, with IDs, which are references from `locations`.", ], + featureidkey=[ + "str (default: `'id'`)", + "Path to field in GeoJSON feature object with which to match the values passed in to `locations`." + "The most common alternative to the default is of the form `'properties.`.", + ], cumulative=[ "boolean (default `False`)", "If `True`, histogram values are cumulative.",