Skip to content

Commit 90efcfc

Browse files
directly compute nrows/ncols
1 parent 2b8b1c8 commit 90efcfc

File tree

1 file changed

+17
-7
lines changed
  • packages/python/plotly/plotly/express

1 file changed

+17
-7
lines changed

Diff for: packages/python/plotly/plotly/express/_core.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -1779,7 +1779,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
17791779
else args["geojson"].__geo_interface__
17801780
)
17811781

1782-
# Compute marginal attribute
1782+
# Compute marginal attribute: copy to appropriate marginal_*
17831783
if "marginal" in args:
17841784
position = "marginal_x" if args["orientation"] == "v" else "marginal_y"
17851785
other_position = "marginal_x" if args["orientation"] == "h" else "marginal_y"
@@ -1879,6 +1879,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18791879

18801880
col_labels = []
18811881
row_labels = []
1882+
nrows = ncols = 1
18821883
for m in grouped_mappings:
18831884
if m.grouper not in sorted_group_values:
18841885
m.val_map[""] = m.sequence[0]
@@ -1887,9 +1888,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18871888
if m.facet == "col":
18881889
prefix = get_label(args, args["facet_col"]) + "="
18891890
col_labels = [prefix + str(s) for s in sorted_values]
1891+
ncols = len(col_labels)
18901892
if m.facet == "row":
18911893
prefix = get_label(args, args["facet_row"]) + "="
18921894
row_labels = [prefix + str(s) for s in sorted_values]
1895+
nrows = len(row_labels)
18931896
for val in sorted_values:
18941897
if val not in m.val_map: # always False if it's an IdentityMap
18951898
m.val_map[val] = m.sequence[len(m.val_map) % len(m.sequence)]
@@ -1899,8 +1902,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
18991902
trace_names_by_frame = {}
19001903
frames = OrderedDict()
19011904
trendline_rows = []
1902-
nrows = ncols = 1
19031905
trace_name_labels = None
1906+
facet_col_wrap = args.get("facet_col_wrap", 0)
19041907
for group_name in sorted_group_names:
19051908
group = grouped.get_group(group_name if len(group_name) > 1 else group_name[0])
19061909
mapping_labels = OrderedDict()
@@ -1981,14 +1984,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
19811984
row = m.val_map[val]
19821985
else:
19831986
if (
1984-
bool(args.get("marginal_x", False))
1985-
and trace_spec.marginal != "x"
1987+
bool(args.get("marginal_x", False)) # there is a marginal
1988+
and trace_spec.marginal != "x" # and we're not it
19861989
):
19871990
row = 2
19881991
else:
19891992
row = 1
19901993

1991-
facet_col_wrap = args.get("facet_col_wrap", 0)
19921994
# Find col for trace, handling facet_col and marginal_y
19931995
if m.facet == "col":
19941996
col = m.val_map[val]
@@ -2001,11 +2003,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20012003
else:
20022004
col = 1
20032005

2004-
nrows = max(nrows, row)
20052006
if row > 1:
20062007
trace._subplot_row = row
20072008

2008-
ncols = max(ncols, col)
20092009
if col > 1:
20102010
trace._subplot_col = col
20112011
if (
@@ -2064,6 +2064,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
20642064
):
20652065
layout_patch["legend"]["itemsizing"] = "constant"
20662066

2067+
if facet_col_wrap:
2068+
nrows = 1 + ncols // facet_col_wrap
2069+
ncols = ncols if ncols < facet_col_wrap else facet_col_wrap
2070+
2071+
if args.get("marginal_x"):
2072+
nrows += 1
2073+
2074+
if args.get("marginal_y"):
2075+
ncols += 1
2076+
20672077
fig = init_figure(
20682078
args, subplot_type, frame_list, nrows, ncols, col_labels, row_labels
20692079
)

0 commit comments

Comments
 (0)