@@ -1779,7 +1779,7 @@ def infer_config(args, constructor, trace_patch, layout_patch):
1779
1779
else args ["geojson" ].__geo_interface__
1780
1780
)
1781
1781
1782
- # Compute marginal attribute
1782
+ # Compute marginal attribute: copy to appropriate marginal_*
1783
1783
if "marginal" in args :
1784
1784
position = "marginal_x" if args ["orientation" ] == "v" else "marginal_y"
1785
1785
other_position = "marginal_x" if args ["orientation" ] == "h" else "marginal_y"
@@ -1879,6 +1879,7 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1879
1879
1880
1880
col_labels = []
1881
1881
row_labels = []
1882
+ nrows = ncols = 1
1882
1883
for m in grouped_mappings :
1883
1884
if m .grouper not in sorted_group_values :
1884
1885
m .val_map ["" ] = m .sequence [0 ]
@@ -1887,9 +1888,11 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1887
1888
if m .facet == "col" :
1888
1889
prefix = get_label (args , args ["facet_col" ]) + "="
1889
1890
col_labels = [prefix + str (s ) for s in sorted_values ]
1891
+ ncols = len (col_labels )
1890
1892
if m .facet == "row" :
1891
1893
prefix = get_label (args , args ["facet_row" ]) + "="
1892
1894
row_labels = [prefix + str (s ) for s in sorted_values ]
1895
+ nrows = len (row_labels )
1893
1896
for val in sorted_values :
1894
1897
if val not in m .val_map : # always False if it's an IdentityMap
1895
1898
m .val_map [val ] = m .sequence [len (m .val_map ) % len (m .sequence )]
@@ -1899,8 +1902,8 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1899
1902
trace_names_by_frame = {}
1900
1903
frames = OrderedDict ()
1901
1904
trendline_rows = []
1902
- nrows = ncols = 1
1903
1905
trace_name_labels = None
1906
+ facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1904
1907
for group_name in sorted_group_names :
1905
1908
group = grouped .get_group (group_name if len (group_name ) > 1 else group_name [0 ])
1906
1909
mapping_labels = OrderedDict ()
@@ -1981,14 +1984,13 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
1981
1984
row = m .val_map [val ]
1982
1985
else :
1983
1986
if (
1984
- bool (args .get ("marginal_x" , False ))
1985
- and trace_spec .marginal != "x"
1987
+ bool (args .get ("marginal_x" , False )) # there is a marginal
1988
+ and trace_spec .marginal != "x" # and we're not it
1986
1989
):
1987
1990
row = 2
1988
1991
else :
1989
1992
row = 1
1990
1993
1991
- facet_col_wrap = args .get ("facet_col_wrap" , 0 )
1992
1994
# Find col for trace, handling facet_col and marginal_y
1993
1995
if m .facet == "col" :
1994
1996
col = m .val_map [val ]
@@ -2001,11 +2003,9 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
2001
2003
else :
2002
2004
col = 1
2003
2005
2004
- nrows = max (nrows , row )
2005
2006
if row > 1 :
2006
2007
trace ._subplot_row = row
2007
2008
2008
- ncols = max (ncols , col )
2009
2009
if col > 1 :
2010
2010
trace ._subplot_col = col
2011
2011
if (
@@ -2064,6 +2064,16 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
2064
2064
):
2065
2065
layout_patch ["legend" ]["itemsizing" ] = "constant"
2066
2066
2067
+ if facet_col_wrap :
2068
+ nrows = 1 + ncols // facet_col_wrap
2069
+ ncols = ncols if ncols < facet_col_wrap else facet_col_wrap
2070
+
2071
+ if args .get ("marginal_x" ):
2072
+ nrows += 1
2073
+
2074
+ if args .get ("marginal_y" ):
2075
+ ncols += 1
2076
+
2067
2077
fig = init_figure (
2068
2078
args , subplot_type , frame_list , nrows , ncols , col_labels , row_labels
2069
2079
)
0 commit comments