diff --git a/plotly/tests/test_core/test_tools/test_make_subplots.py b/plotly/tests/test_core/test_tools/test_make_subplots.py index 9db24b3cd69..5e78e18f247 100644 --- a/plotly/tests/test_core/test_tools/test_make_subplots.py +++ b/plotly/tests/test_core/test_tools/test_make_subplots.py @@ -1946,7 +1946,7 @@ def test_subplot_titles_shared_axes(self): layout=Layout( annotations=Annotations([ Annotation( - x=0.22499999999999998, + x=0.225, y=1.0, xref='paper', yref='paper', @@ -1957,7 +1957,7 @@ def test_subplot_titles_shared_axes(self): yanchor='bottom' ), Annotation( - x=0.7749999999999999, + x=0.775, y=1.0, xref='paper', yref='paper', @@ -1968,7 +1968,7 @@ def test_subplot_titles_shared_axes(self): yanchor='bottom' ), Annotation( - x=0.22499999999999998, + x=0.225, y=0.375, xref='paper', yref='paper', @@ -1979,7 +1979,7 @@ def test_subplot_titles_shared_axes(self): yanchor='bottom' ), Annotation( - x=0.7749999999999999, + x=0.775, y=0.375, xref='paper', yref='paper', @@ -2010,13 +2010,13 @@ def test_subplot_titles_shared_axes(self): ) ) ) + fig = tls.make_subplots(rows=2, cols=2, subplot_titles=('Title 1', 'Title 2', 'Title 3', 'Title 4'), shared_xaxes=True, shared_yaxes=True) self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) - def test_subplot_titles_irregular_layout(self): # make a title for each subplot when the layout is irregular: expected = Figure( @@ -2155,3 +2155,112 @@ def test_large_columns_no_errors(self): fig = tls.make_subplots(100, 1, vertical_spacing=v_space, specs=[[{'is_3d': True}] for _ in range(100)]) + + def test_row_width_and_column_width(self): + + expected = Figure({ + 'data': [], + 'layout': {'annotations': [{'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 1', + 'x': 0.405, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 1.0, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 2', + 'x': 0.9550000000000001, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 1.0, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 3', + 'x': 0.405, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 0.1875, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 4', + 'x': 0.9550000000000001, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 0.1875, + 'yanchor': 'bottom', + 'yref': 'paper'}], + 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.81]}, + 'xaxis2': {'anchor': 'y2', 'domain': [0.91, 1.0]}, + 'xaxis3': {'anchor': 'y3', 'domain': [0.0, 0.81]}, + 'xaxis4': {'anchor': 'y4', 'domain': [0.91, 1.0]}, + 'yaxis': {'anchor': 'x', 'domain': [0.4375, 1.0]}, + 'yaxis2': {'anchor': 'x2', 'domain': [0.4375, 1.0]}, + 'yaxis3': {'anchor': 'x3', 'domain': [0.0, 0.1875]}, + 'yaxis4': {'anchor': 'x4', 'domain': [0.0, 0.1875]}} + }) + fig = tls.make_subplots(rows=2, cols=2, + subplot_titles=('Title 1', 'Title 2', 'Title 3', 'Title 4'), + row_width=[1, 3], column_width=[9, 1]) + self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) + + def test_row_width_and_shared_yaxes(self): + + expected = Figure({ + 'data': [], + 'layout': {'annotations': [{'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 1', + 'x': 0.225, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 1.0, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 2', + 'x': 0.775, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 1.0, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 3', + 'x': 0.225, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 0.1875, + 'yanchor': 'bottom', + 'yref': 'paper'}, + {'font': {'size': 16}, + 'showarrow': False, + 'text': 'Title 4', + 'x': 0.775, + 'xanchor': 'center', + 'xref': 'paper', + 'y': 0.1875, + 'yanchor': 'bottom', + 'yref': 'paper'}], + 'xaxis': {'anchor': 'y', 'domain': [0.0, 0.45]}, + 'xaxis2': {'anchor': 'free', 'domain': [0.55, 1.0], 'position': 0.4375}, + 'xaxis3': {'anchor': 'y2', 'domain': [0.0, 0.45]}, + 'xaxis4': {'anchor': 'free', 'domain': [0.55, 1.0], 'position': 0.0}, + 'yaxis': {'anchor': 'x', 'domain': [0.4375, 1.0]}, + 'yaxis2': {'anchor': 'x3', 'domain': [0.0, 0.1875]}} + }) + + fig = tls.make_subplots(rows=2, cols=2, row_width=[1, 3], shared_yaxes=True, + subplot_titles=('Title 1', 'Title 2', 'Title 3', 'Title 4')) + + self.assertEqual(fig.to_plotly_json(), expected.to_plotly_json()) + + # def test_row_width_and_shared_yaxes(self): \ No newline at end of file diff --git a/plotly/tools.py b/plotly/tools.py index 87b789309b3..3f6fd488069 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -13,6 +13,7 @@ import six import copy +import re from plotly import exceptions, optional_imports, session, utils from plotly.files import (CONFIG_FILE, CREDENTIALS_FILE, FILE_CONTENT, @@ -1001,7 +1002,6 @@ def _checks(item, defaults): ) for c in col_seq ] for r in row_seq ] - # [grid_ref] Initialize the grid and insets' axis-reference lists grid_ref = [[None for c in range(cols)] for r in range(rows)] insets_ref = [None for inset in range(len(insets))] if insets else None @@ -1323,20 +1323,49 @@ def _pad(s, cell_len=cell_len): subtitle_pos_x.append(sum(x_domains) / 2) for y_domains in y_dom: subtitle_pos_y.append(y_domains[1]) + # If shared_axes is True the domin of each subplot is not returned so the # title position must be calculated for each subplot else: - subtitle_pos_x = [None] * cols - subtitle_pos_y = [None] * rows - delt_x = (x_e - x_s) + x_dom_vals = [k for k in layout.to_plotly_json().keys() if 'xaxis' in k] + y_dom_vals = [k for k in layout.to_plotly_json().keys() if 'yaxis' in k] + + # sort xaxis and yaxis layout keys + r = re.compile('\d+') + + def key_func(m): + try: + return int(r.search(m).group(0)) + except AttributeError: + return 0 + + xaxies_labels_sorted = sorted(x_dom_vals, key=key_func) + yaxies_labels_sorted = sorted(y_dom_vals, key=key_func) + + x_dom = [layout[k]['domain'] for k in xaxies_labels_sorted] + y_dom = [layout[k]['domain'] for k in yaxies_labels_sorted] + for index in range(cols): - subtitle_pos_x[index] = ((delt_x / 2) + - ((delt_x + horizontal_spacing) * index)) - subtitle_pos_x *= rows - for index in range(rows): - subtitle_pos_y[index] = (1 - ((y_e + vertical_spacing) * index)) - subtitle_pos_y *= cols - subtitle_pos_y = sorted(subtitle_pos_y, reverse=True) + subtitle_pos_x = [] + for x_domains in x_dom: + subtitle_pos_x.append(sum(x_domains) / 2) + subtitle_pos_x *= rows + + if shared_yaxes: + for index in range(rows): + subtitle_pos_y = [] + for y_domain in y_dom: + subtitle_pos_y.append(y_domain[1]) + subtitle_pos_y *= cols + subtitle_pos_y = sorted(subtitle_pos_y, reverse=True) + + else: + for index in range(rows): + subtitle_pos_y = [] + for y_domain in y_dom: + subtitle_pos_y.append(y_domain[1]) + subtitle_pos_y = sorted(subtitle_pos_y, reverse=True) + subtitle_pos_y *= cols plot_titles = [] for index in range(len(subplot_titles)):