diff --git a/plotly/figure_factory/_bullet.py b/plotly/figure_factory/_bullet.py index bcb5b2541d1..e2c42a7855a 100644 --- a/plotly/figure_factory/_bullet.py +++ b/plotly/figure_factory/_bullet.py @@ -78,7 +78,7 @@ def _bullet(df, markers, measures, ranges, subtitles, titles, orientation, for row in range(num_of_lanes): # ranges bars for idx in range(len(df.iloc[row]['ranges'])): - inter_colors = colors.n_colors( + inter_colors = utils.n_colors( range_colors[0], range_colors[1], len(df.iloc[row]['ranges']), 'rgb' ) @@ -104,7 +104,7 @@ def _bullet(df, markers, measures, ranges, subtitles, titles, orientation, # measures bars for idx in range(len(df.iloc[row]['measures'])): - inter_colors = colors.n_colors( + inter_colors = utils.n_colors( measure_colors[0], measure_colors[1], len(df.iloc[row]['measures']), 'rgb' ) @@ -318,7 +318,7 @@ def create_bullet(data, markers=None, measures=None, ranges=None, "of two valid colors." ) colors.validate_colors(colors_list) - colors_list = colors.convert_colors_to_same_type(colors_list, + colors_list = utils.convert_colors_to_same_type(colors_list, 'rgb')[0] # default scatter options diff --git a/plotly/figure_factory/_county_choropleth.py b/plotly/figure_factory/_county_choropleth.py index 8ee9fef7c2c..5c8674ce07a 100644 --- a/plotly/figure_factory/_county_choropleth.py +++ b/plotly/figure_factory/_county_choropleth.py @@ -642,7 +642,7 @@ def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None, if not colorscale: colorscale = [] viridis_colors = colors.colorscale_to_colors( - colors.PLOTLY_SCALES['Viridis'] + utils.PLOTLY_SCALES['Viridis'] ) viridis_colors = colors.color_parser( viridis_colors, colors.hex_to_rgb @@ -674,7 +674,7 @@ def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None, # make R,G,B into int values float_color = colors.unlabel_rgb(float_color) - float_color = colors.unconvert_from_RGB_255(float_color) + float_color = utils.unconvert_from_RGB_255(float_color) int_rgb = colors.convert_to_RGB_255(float_color) int_rgb = colors.label_rgb(int_rgb) diff --git a/plotly/figure_factory/_facet_grid.py b/plotly/figure_factory/_facet_grid.py index a202599d747..20b6ede8326 100644 --- a/plotly/figure_factory/_facet_grid.py +++ b/plotly/figure_factory/_facet_grid.py @@ -935,13 +935,13 @@ def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None, marker_color, kwargs_trace, kwargs_marker ) elif isinstance(colormap, str): - if colormap in colors.PLOTLY_SCALES.keys(): - colorscale_list = colors.PLOTLY_SCALES[colormap] + if colormap in utils.PLOTLY_SCALES.keys(): + colorscale_list = utils.PLOTLY_SCALES[colormap] else: raise exceptions.PlotlyError( "If 'colormap' is a string, it must be the name " "of a Plotly Colorscale. The available colorscale " - "names are {}".format(colors.PLOTLY_SCALES.keys()) + "names are {}".format(utils.PLOTLY_SCALES.keys()) ) fig, annotations = _facet_grid_color_numerical( df, x, y, facet_row, facet_col, color_name, @@ -951,7 +951,7 @@ def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None, marker_color, kwargs_trace, kwargs_marker ) else: - colorscale_list = colors.PLOTLY_SCALES['Reds'] + colorscale_list = utils.PLOTLY_SCALES['Reds'] fig, annotations = _facet_grid_color_numerical( df, x, y, facet_row, facet_col, color_name, colorscale_list, num_of_rows, num_of_cols, diff --git a/plotly/figure_factory/_scatterplot.py b/plotly/figure_factory/_scatterplot.py index 3c2de2452f7..d56af50620d 100644 --- a/plotly/figure_factory/_scatterplot.py +++ b/plotly/figure_factory/_scatterplot.py @@ -1,5 +1,7 @@ from __future__ import absolute_import +import six + from plotly import colors, exceptions, optional_imports from plotly.figure_factory import utils from plotly.graph_objs import graph_objs @@ -1080,6 +1082,19 @@ def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', # Validate colormap if isinstance(colormap, dict): colormap = utils.validate_colors_dict(colormap, 'rgb') + elif isinstance(colormap, six.string_types) and 'rgb' not in colormap and '#' not in colormap: + if colormap not in utils.PLOTLY_SCALES.keys(): + raise exceptions.PlotlyError( + "If 'colormap' is a string, it must be the name " + "of a Plotly Colorscale. The available colorscale " + "names are {}".format(utils.PLOTLY_SCALES.keys()) + ) + else: + # TODO change below to allow the correct Plotly colorscale + colormap = utils.colorscale_to_colors(utils.PLOTLY_SCALES[colormap]) + # keep only first and last item - fix later + colormap = [colormap[0]] + [colormap[-1]] + colormap = utils.validate_colors(colormap, 'rgb') else: colormap = utils.validate_colors(colormap, 'rgb') diff --git a/plotly/figure_factory/_trisurf.py b/plotly/figure_factory/_trisurf.py index 06e9cc04fee..9b84b61cf46 100644 --- a/plotly/figure_factory/_trisurf.py +++ b/plotly/figure_factory/_trisurf.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from plotly import colors, exceptions, optional_imports +from plotly.figure_factory import utils from plotly.graph_objs import graph_objs np = optional_imports.get_module('numpy') @@ -147,8 +148,8 @@ def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale, if mean_dists_are_numbers and show_colorbar is True: # make a colorscale from the colors - colorscale = colors.make_colorscale(colormap, scale) - colorscale = colors.convert_colorscale_to_rgb(colorscale) + colorscale = utils.make_colorscale(colormap, scale) + colorscale = utils.convert_colorscale_to_rgb(colorscale) colorbar = graph_objs.Scatter3d( x=x[:1], @@ -455,7 +456,7 @@ def dist_origin(x, y, z): # Validate colormap colors.validate_colors(colormap) - colormap, scale = colors.convert_colors_to_same_type( + colormap, scale = utils.convert_colors_to_same_type( colormap, colortype='tuple', return_default_colors=True, scale=scale ) diff --git a/plotly/figure_factory/figure_factory/_2d_density.py b/plotly/figure_factory/figure_factory/_2d_density.py new file mode 100644 index 00000000000..8d004bff147 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_2d_density.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + + +def make_linear_colorscale(colors): + """ + Makes a list of colors into a colorscale-acceptable form + + For documentation regarding to the form of the output, see + https://plot.ly/python/reference/#mesh3d-colorscale + """ + scale = 1. / (len(colors) - 1) + return [[i * scale, color] for i, color in enumerate(colors)] + + +def create_2d_density(x, y, colorscale='Earth', ncontours=20, + hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5), + point_size=2, title='2D Density Plot', + height=600, width=600): + """ + Returns figure for a 2D density plot + + :param (list|array) x: x-axis data for plot generation + :param (list|array) y: y-axis data for plot generation + :param (str|tuple|list) colorscale: either a plotly scale name, an rgb + or hex color, a color tuple or a list or tuple of colors. An rgb + color is of the form 'rgb(x, y, z)' where x, y, z belong to the + interval [0, 255] and a color tuple is a tuple of the form + (a, b, c) where a, b and c belong to [0, 1]. If colormap is a + list, it must contain the valid color types aforementioned as its + members. + :param (int) ncontours: the number of 2D contours to draw on the plot + :param (str) hist_color: the color of the plotted histograms + :param (str) point_color: the color of the scatter points + :param (str) point_size: the color of the scatter points + :param (str) title: set the title for the plot + :param (float) height: the height of the chart + :param (float) width: the width of the chart + + Example 1: Simple 2D Density Plot + ``` + import plotly.plotly as py + from plotly.figure_factory create_2d_density + + import numpy as np + + # Make data points + t = np.linspace(-1,1.2,2000) + x = (t**3)+(0.3*np.random.randn(2000)) + y = (t**6)+(0.3*np.random.randn(2000)) + + # Create a figure + fig = create_2D_density(x, y) + + # Plot the data + py.iplot(fig, filename='simple-2d-density') + ``` + + Example 2: Using Parameters + ``` + import plotly.plotly as py + from plotly.figure_factory create_2d_density + + import numpy as np + + # Make data points + t = np.linspace(-1,1.2,2000) + x = (t**3)+(0.3*np.random.randn(2000)) + y = (t**6)+(0.3*np.random.randn(2000)) + + # Create custom colorscale + colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)', + (1, 1, 0.2), (0.98,0.98,0.98)] + + # Create a figure + fig = create_2D_density( + x, y, colorscale=colorscale, + hist_color='rgb(255, 237, 222)', point_size=3) + + # Plot the data + py.iplot(fig, filename='use-parameters') + ``` + """ + + # validate x and y are filled with numbers only + for array in [x, y]: + if not all(isinstance(element, Number) for element in array): + raise exceptions.PlotlyError( + "All elements of your 'x' and 'y' lists must be numbers." + ) + + # validate x and y are the same length + if len(x) != len(y): + raise exceptions.PlotlyError( + "Both lists 'x' and 'y' must be the same length." + ) + + colorscale = utils.validate_colors(colorscale, 'rgb') + colorscale = make_linear_colorscale(colorscale) + + # validate hist_color and point_color + hist_color = utils.validate_colors(hist_color, 'rgb') + point_color = utils.validate_colors(point_color, 'rgb') + + trace1 = graph_objs.Scatter( + x=x, y=y, mode='markers', name='points', + marker=dict( + color=point_color[0], + size=point_size, + opacity=0.4 + ) + ) + trace2 = graph_objs.Histogram2dContour( + x=x, y=y, name='density', ncontours=ncontours, + colorscale=colorscale, reversescale=True, showscale=False + ) + trace3 = graph_objs.Histogram( + x=x, name='x density', + marker=dict(color=hist_color[0]), yaxis='y2' + ) + trace4 = graph_objs.Histogram( + y=y, name='y density', + marker=dict(color=hist_color[0]), xaxis='x2' + ) + data = [trace1, trace2, trace3, trace4] + + layout = graph_objs.Layout( + showlegend=False, + autosize=False, + title=title, + height=height, + width=width, + xaxis=dict( + domain=[0, 0.85], + showgrid=False, + zeroline=False + ), + yaxis=dict( + domain=[0, 0.85], + showgrid=False, + zeroline=False + ), + margin=dict( + t=50 + ), + hovermode='closest', + bargap=0, + xaxis2=dict( + domain=[0.85, 1], + showgrid=False, + zeroline=False + ), + yaxis2=dict( + domain=[0.85, 1], + showgrid=False, + zeroline=False + ) + ) + + fig = graph_objs.Figure(data=data, layout=layout) + return fig diff --git a/plotly/figure_factory/figure_factory/__init__.py b/plotly/figure_factory/figure_factory/__init__.py new file mode 100644 index 00000000000..a8be19872e1 --- /dev/null +++ b/plotly/figure_factory/figure_factory/__init__.py @@ -0,0 +1,24 @@ +from __future__ import absolute_import + +from plotly import optional_imports + +# Require that numpy exists for figure_factory +import numpy + +from plotly.figure_factory._2d_density import create_2d_density +from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap +from plotly.figure_factory._bullet import create_bullet +from plotly.figure_factory._candlestick import create_candlestick +from plotly.figure_factory._dendrogram import create_dendrogram +from plotly.figure_factory._distplot import create_distplot +from plotly.figure_factory._facet_grid import create_facet_grid +from plotly.figure_factory._gantt import create_gantt +from plotly.figure_factory._ohlc import create_ohlc +from plotly.figure_factory._quiver import create_quiver +from plotly.figure_factory._scatterplot import create_scatterplotmatrix +from plotly.figure_factory._streamline import create_streamline +from plotly.figure_factory._table import create_table +from plotly.figure_factory._trisurf import create_trisurf +from plotly.figure_factory._violin import create_violin +if optional_imports.get_module('pandas') is not None: + from plotly.figure_factory._county_choropleth import create_choropleth diff --git a/plotly/figure_factory/figure_factory/_annotated_heatmap.py b/plotly/figure_factory/figure_factory/_annotated_heatmap.py new file mode 100644 index 00000000000..3400aba0e48 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_annotated_heatmap.py @@ -0,0 +1,263 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs +from plotly.validators.heatmap import ColorscaleValidator + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') + + +def validate_annotated_heatmap(z, x, y, annotation_text): + """ + Annotated-heatmap-specific validations + + Check that if a text matrix is supplied, it has the same + dimensions as the z matrix. + + See FigureFactory.create_annotated_heatmap() for params + + :raises: (PlotlyError) If z and text matrices do not have the same + dimensions. + """ + if annotation_text is not None and isinstance(annotation_text, list): + utils.validate_equal_length(z, annotation_text) + for lst in range(len(z)): + if len(z[lst]) != len(annotation_text[lst]): + raise exceptions.PlotlyError("z and text should have the " + "same dimensions") + + if x: + if len(x) != len(z[0]): + raise exceptions.PlotlyError("oops, the x list that you " + "provided does not match the " + "width of your z matrix ") + + if y: + if len(y) != len(z): + raise exceptions.PlotlyError("oops, the y list that you " + "provided does not match the " + "length of your z matrix ") + + +def create_annotated_heatmap(z, x=None, y=None, annotation_text=None, + colorscale='RdBu', font_colors=None, + showscale=False, reversescale=False, + **kwargs): + """ + BETA function that creates annotated heatmaps + + This function adds annotations to each cell of the heatmap. + + :param (list[list]|ndarray) z: z matrix to create heatmap. + :param (list) x: x axis labels. + :param (list) y: y axis labels. + :param (list[list]|ndarray) annotation_text: Text strings for + annotations. Should have the same dimensions as the z matrix. If no + text is added, the values of the z matrix are annotated. Default = + z matrix values. + :param (list|str) colorscale: heatmap colorscale. + :param (list) font_colors: List of two color strings: [min_text_color, + max_text_color] where min_text_color is applied to annotations for + heatmap values < (max_value - min_value)/2. If font_colors is not + defined, the colors are defined logically as black or white + depending on the heatmap's colorscale. + :param (bool) showscale: Display colorscale. Default = False + :param (bool) reversescale: Reverse colorscale. Default = False + :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. + These kwargs describe other attributes about the annotated Heatmap + trace such as the colorscale. For more information on valid kwargs + call help(plotly.graph_objs.Heatmap) + + Example 1: Simple annotated heatmap with default configuration + ``` + import plotly.plotly as py + import plotly.figure_factory as FF + + z = [[0.300000, 0.00000, 0.65, 0.300000], + [1, 0.100005, 0.45, 0.4300], + [0.300000, 0.00000, 0.65, 0.300000], + [1, 0.100005, 0.45, 0.00000]] + + figure = FF.create_annotated_heatmap(z) + py.iplot(figure) + ``` + """ + + # Avoiding mutables in the call signature + font_colors = font_colors if font_colors is not None else [] + validate_annotated_heatmap(z, x, y, annotation_text) + + # validate colorscale + colorscale_validator = ColorscaleValidator() + colorscale = colorscale_validator.validate_coerce(colorscale) + + annotations = _AnnotatedHeatmap(z, x, y, annotation_text, + colorscale, font_colors, reversescale, + **kwargs).make_annotations() + + if x or y: + trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale, + showscale=showscale, reversescale=reversescale, **kwargs) + layout = dict(annotations=annotations, + xaxis=dict(ticks='', dtick=1, side='top', + gridcolor='rgb(0, 0, 0)'), + yaxis=dict(ticks='', dtick=1, ticksuffix=' ')) + else: + trace = dict(type='heatmap', z=z, colorscale=colorscale, + showscale=showscale, reversescale=reversescale, **kwargs) + layout = dict(annotations=annotations, + xaxis=dict(ticks='', side='top', + gridcolor='rgb(0, 0, 0)', + showticklabels=False), + yaxis=dict(ticks='', ticksuffix=' ', + showticklabels=False)) + + data = [trace] + + return graph_objs.Figure(data=data, layout=layout) + + +def to_rgb_color_list(color_str, default): + if 'rgb' in color_str: + return [int(v) for v in color_str.strip('rgb()').split(',')] + elif '#' in color_str: + return utils.hex_to_rgb(color_str) + else: + return default + + +def should_use_black_text(background_color): + return (background_color[0] * 0.299 + + background_color[1] * 0.587 + + background_color[2] * 0.114) > 186 + + +class _AnnotatedHeatmap(object): + """ + Refer to TraceFactory.create_annotated_heatmap() for docstring + """ + def __init__(self, z, x, y, annotation_text, colorscale, + font_colors, reversescale, **kwargs): + + self.z = z + if x: + self.x = x + else: + self.x = range(len(z[0])) + if y: + self.y = y + else: + self.y = range(len(z)) + if annotation_text is not None: + self.annotation_text = annotation_text + else: + self.annotation_text = self.z + self.colorscale = colorscale + self.reversescale = reversescale + self.font_colors = font_colors + + def get_text_color(self): + """ + Get font color for annotations. + + The annotated heatmap can feature two text colors: min_text_color and + max_text_color. The min_text_color is applied to annotations for + heatmap values < (max_value - min_value)/2. The user can define these + two colors. Otherwise the colors are defined logically as black or + white depending on the heatmap's colorscale. + + :rtype (string, string) min_text_color, max_text_color: text + color for annotations for heatmap values < + (max_value - min_value)/2 and text color for annotations for + heatmap values >= (max_value - min_value)/2 + """ + # Plotly colorscales ranging from a lighter shade to a darker shade + colorscales = ['Greys', 'Greens', 'Blues', + 'YIGnBu', 'YIOrRd', 'RdBu', + 'Picnic', 'Jet', 'Hot', 'Blackbody', + 'Earth', 'Electric', 'Viridis', 'Cividis'] + # Plotly colorscales ranging from a darker shade to a lighter shade + colorscales_reverse = ['Reds'] + + white = '#FFFFFF' + black = '#000000' + if self.font_colors: + min_text_color = self.font_colors[0] + max_text_color = self.font_colors[-1] + elif self.colorscale in colorscales and self.reversescale: + min_text_color = black + max_text_color = white + elif self.colorscale in colorscales: + min_text_color = white + max_text_color = black + elif self.colorscale in colorscales_reverse and self.reversescale: + min_text_color = white + max_text_color = black + elif self.colorscale in colorscales_reverse: + min_text_color = black + max_text_color = white + elif isinstance(self.colorscale, list): + + min_col = to_rgb_color_list(self.colorscale[0][1], + [255, 255, 255]) + max_col = to_rgb_color_list(self.colorscale[-1][1], + [255, 255, 255]) + + # swap min/max colors if reverse scale + if self.reversescale: + min_col, max_col = max_col, min_col + + if should_use_black_text(min_col): + min_text_color = black + else: + min_text_color = white + + if should_use_black_text(max_col): + max_text_color = black + else: + max_text_color = white + else: + min_text_color = black + max_text_color = black + return min_text_color, max_text_color + + def get_z_mid(self): + """ + Get the mid value of z matrix + + :rtype (float) z_avg: average val from z matrix + """ + if np and isinstance(self.z, np.ndarray): + z_min = np.amin(self.z) + z_max = np.amax(self.z) + else: + z_min = min(min(self.z)) + z_max = max(max(self.z)) + z_mid = (z_max+z_min) / 2 + return z_mid + + def make_annotations(self): + """ + Get annotations for each cell of the heatmap with graph_objs.Annotation + + :rtype (list[dict]) annotations: list of annotations for each cell of + the heatmap + """ + min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self) + z_mid = _AnnotatedHeatmap.get_z_mid(self) + annotations = [] + for n, row in enumerate(self.z): + for m, val in enumerate(row): + font_color = min_text_color if val < z_mid else max_text_color + annotations.append( + graph_objs.layout.Annotation( + text=str(self.annotation_text[n][m]), + x=self.x[m], + y=self.y[n], + xref='x1', + yref='y1', + font=dict(color=font_color), + showarrow=False)) + return annotations diff --git a/plotly/figure_factory/figure_factory/_bullet.py b/plotly/figure_factory/figure_factory/_bullet.py new file mode 100644 index 00000000000..046dea3224a --- /dev/null +++ b/plotly/figure_factory/figure_factory/_bullet.py @@ -0,0 +1,340 @@ +from __future__ import absolute_import + +import collections +import math + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils + +import plotly +import plotly.graph_objs as go + +pd = optional_imports.get_module('pandas') + + +def _bullet(df, markers, measures, ranges, subtitles, titles, orientation, + range_colors, measure_colors, horizontal_spacing, + vertical_spacing, scatter_options, layout_options): + + num_of_lanes = len(df) + num_of_rows = num_of_lanes if orientation == 'h' else 1 + num_of_cols = 1 if orientation == 'h' else num_of_lanes + if not horizontal_spacing: + horizontal_spacing = 1./num_of_lanes + if not vertical_spacing: + vertical_spacing = 1./num_of_lanes + fig = plotly.tools.make_subplots( + num_of_rows, num_of_cols, print_grid=False, + horizontal_spacing=horizontal_spacing, + vertical_spacing=vertical_spacing + ) + + # layout + fig['layout'].update( + dict(shapes=[]), + title='Bullet Chart', + height=600, + width=1000, + showlegend=False, + barmode='stack', + annotations=[], + margin=dict(l=120 if orientation == 'h' else 80), + ) + + # update layout + fig['layout'].update(layout_options) + + if orientation == 'h': + width_axis = 'yaxis' + length_axis = 'xaxis' + else: + width_axis = 'xaxis' + length_axis = 'yaxis' + + for key in fig['layout']: + if 'xaxis' in key or 'yaxis' in key: + fig['layout'][key]['showgrid'] = False + fig['layout'][key]['zeroline'] = False + if length_axis in key: + fig['layout'][key]['tickwidth'] = 1 + if width_axis in key: + fig['layout'][key]['showticklabels'] = False + fig['layout'][key]['range'] = [0, 1] + + # narrow domain if 1 bar + if num_of_lanes <= 1: + fig['layout'][width_axis + '1']['domain'] = [0.4, 0.6] + + if not range_colors: + range_colors = ['rgb(200, 200, 200)', 'rgb(245, 245, 245)'] + if not measure_colors: + measure_colors = ['rgb(31, 119, 180)', 'rgb(176, 196, 221)'] + + for row in range(num_of_lanes): + # ranges bars + for idx in range(len(df.iloc[row]['ranges'])): + inter_colors = utils.n_colors( + range_colors[0], range_colors[1], + len(df.iloc[row]['ranges']), 'rgb' + ) + x = ([sorted(df.iloc[row]['ranges'])[-1 - idx]] if + orientation == 'h' else [0]) + y = ([0] if orientation == 'h' else + [sorted(df.iloc[row]['ranges'])[-1 - idx]]) + bar = go.Bar( + x=x, + y=y, + marker=dict( + color=inter_colors[-1 - idx] + ), + name='ranges', + hoverinfo='x' if orientation == 'h' else 'y', + orientation=orientation, + width=2, + base=0, + xaxis='x{}'.format(row + 1), + yaxis='y{}'.format(row + 1) + ) + fig.add_trace(bar) + + # measures bars + for idx in range(len(df.iloc[row]['measures'])): + inter_colors = utils.n_colors( + measure_colors[0], measure_colors[1], + len(df.iloc[row]['measures']), 'rgb' + ) + x = ([sorted(df.iloc[row]['measures'])[-1 - idx]] if + orientation == 'h' else [0.5]) + y = ([0.5] if orientation == 'h' + else [sorted(df.iloc[row]['measures'])[-1 - idx]]) + bar = go.Bar( + x=x, + y=y, + marker=dict( + color=inter_colors[-1 - idx] + ), + name='measures', + hoverinfo='x' if orientation == 'h' else 'y', + orientation=orientation, + width=0.4, + base=0, + xaxis='x{}'.format(row + 1), + yaxis='y{}'.format(row + 1) + ) + fig.add_trace(bar) + + # markers + x = df.iloc[row]['markers'] if orientation == 'h' else [0.5] + y = [0.5] if orientation == 'h' else df.iloc[row]['markers'] + markers = go.Scatter( + x=x, + y=y, + name='markers', + hoverinfo='x' if orientation == 'h' else 'y', + xaxis='x{}'.format(row + 1), + yaxis='y{}'.format(row + 1), + **scatter_options + ) + + fig.add_trace(markers) + + # titles and subtitles + title = df.iloc[row]['titles'] + if 'subtitles' in df: + subtitle = '
{}'.format(df.iloc[row]['subtitles']) + else: + subtitle = '' + label = '{}'.format(title) + subtitle + annot = utils.annotation_dict_for_label( + label, + (num_of_lanes - row if orientation == 'h' else row + 1), + num_of_lanes, + vertical_spacing if orientation == 'h' else horizontal_spacing, + 'row' if orientation == 'h' else 'col', + True if orientation == 'h' else False, + False + ) + fig['layout']['annotations'] += (annot,) + + return fig + + +def create_bullet(data, markers=None, measures=None, ranges=None, + subtitles=None, titles=None, orientation='h', + range_colors=('rgb(200, 200, 200)', 'rgb(245, 245, 245)'), + measure_colors=('rgb(31, 119, 180)', 'rgb(176, 196, 221)'), + horizontal_spacing=None, vertical_spacing=None, + scatter_options={}, **layout_options): + """ + Returns figure for bullet chart. + + :param (pd.DataFrame | list | tuple) data: either a list/tuple of + dictionaries or a pandas DataFrame. + :param (str) markers: the column name or dictionary key for the markers in + each subplot. + :param (str) measures: the column name or dictionary key for the measure + bars in each subplot. This bar usually represents the quantitative + measure of performance, usually a list of two values [a, b] and are + the blue bars in the foreground of each subplot by default. + :param (str) ranges: the column name or dictionary key for the qualitative + ranges of performance, usually a 3-item list [bad, okay, good]. They + correspond to the grey bars in the background of each chart. + :param (str) subtitles: the column name or dictionary key for the subtitle + of each subplot chart. The subplots are displayed right underneath + each title. + :param (str) titles: the column name or dictionary key for the main label + of each subplot chart. + :param (bool) orientation: if 'h', the bars are placed horizontally as + rows. If 'v' the bars are placed vertically in the chart. + :param (list) range_colors: a tuple of two colors between which all + the rectangles for the range are drawn. These rectangles are meant to + be qualitative indicators against which the marker and measure bars + are compared. + Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)') + :param (list) measure_colors: a tuple of two colors which is used to color + the thin quantitative bars in the bullet chart. + Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)') + :param (float) horizontal_spacing: see the 'horizontal_spacing' param in + plotly.tools.make_subplots. Ranges between 0 and 1. + :param (float) vertical_spacing: see the 'vertical_spacing' param in + plotly.tools.make_subplots. Ranges between 0 and 1. + :param (dict) scatter_options: describes attributes for the scatter trace + in each subplot such as name and marker size. Call + help(plotly.graph_objs.Scatter) for more information on valid params. + :param layout_options: describes attributes for the layout of the figure + such as title, height and width. Call help(plotly.graph_objs.Layout) + for more information on valid params. + + Example 1: Use a Dictionary + ``` + import plotly + import plotly.plotly as py + import plotly.figure_factory as ff + + data = [ + {"label": "Revenue", "sublabel": "US$, in thousands", + "range": [150, 225, 300], "performance": [220,270], "point": [250]}, + {"label": "Profit", "sublabel": "%", "range": [20, 25, 30], + "performance": [21, 23], "point": [26]}, + {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600], + "performance": [100,320],"point": [550]}, + {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500], + "performance": [1000, 1650],"point": [2100]}, + {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5], + "performance": [3.2, 4.7], "point": [4.4]} + ] + + fig = ff.create_bullet( + data, titles='label', subtitles='sublabel', markers='point', + measures='performance', ranges='range', orientation='h', + title='my simple bullet chart' + ) + py.iplot(fig) + ``` + + Example 2: Use a DataFrame with Custom Colors + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json') + + fig = ff.create_bullet( + data, titles='title', markers='markers', measures='measures', + orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'], + scatter_options={'marker': {'symbol': 'circle'}}, width=700 + + ) + py.iplot(fig) + ``` + """ + # validate df + if not pd: + raise exceptions.ImportError( + "'pandas' must be installed for this figure factory." + ) + + if utils.is_sequence(data): + if not all(isinstance(item, dict) for item in data): + raise exceptions.PlotlyError( + 'Every entry of the data argument list, tuple, etc must ' + 'be a dictionary.' + ) + + elif not isinstance(data, pd.DataFrame): + raise exceptions.PlotlyError( + 'You must input a pandas DataFrame, or a list of dictionaries.' + ) + + # make DataFrame from data with correct column headers + col_names = ['titles', 'subtitle', 'markers', 'measures', 'ranges'] + if utils.is_sequence(data): + df = pd.DataFrame( + [ + [d[titles] for d in data] if titles else [''] * len(data), + [d[subtitles] for d in data] if subtitles else [''] * len(data), + [d[markers] for d in data] if markers else [[]] * len(data), + [d[measures] for d in data] if measures else [[]] * len(data), + [d[ranges] for d in data] if ranges else [[]] * len(data), + ], + index=col_names + ) + elif isinstance(data, pd.DataFrame): + df = pd.DataFrame( + [ + data[titles].tolist() if titles else [''] * len(data), + data[subtitles].tolist() if subtitles else [''] * len(data), + data[markers].tolist() if markers else [[]] * len(data), + data[measures].tolist() if measures else [[]] * len(data), + data[ranges].tolist() if ranges else [[]] * len(data), + ], + index=col_names + ) + df = pd.DataFrame.transpose(df) + + # make sure ranges, measures, 'markers' are not NAN or NONE + for needed_key in ['ranges', 'measures', 'markers']: + for idx, r in enumerate(df[needed_key]): + try: + r_is_nan = math.isnan(r) + if r_is_nan or r is None: + df[needed_key][idx] = [] + except TypeError: + pass + + # validate custom colors + for colors_list in [range_colors, measure_colors]: + if colors_list: + if len(colors_list) != 2: + raise exceptions.PlotlyError( + "Both 'range_colors' or 'measure_colors' must be a list " + "of two valid colors." + ) + utils.validate_colors(colors_list) + colors_list = utils.convert_colors_to_same_type(colors_list, + 'rgb')[0] + + # default scatter options + default_scatter = { + 'marker': {'size': 12, + 'symbol': 'diamond-tall', + 'color': 'rgb(0, 0, 0)'} + } + + if scatter_options == {}: + scatter_options.update(default_scatter) + else: + # add default options to scatter_options if they are not present + for k in default_scatter['marker']: + if k not in scatter_options['marker']: + scatter_options['marker'][k] = default_scatter['marker'][k] + + fig = _bullet( + df, markers, measures, ranges, subtitles, titles, orientation, + range_colors, measure_colors, horizontal_spacing, vertical_spacing, + scatter_options, layout_options, + ) + + return fig diff --git a/plotly/figure_factory/figure_factory/_candlestick.py b/plotly/figure_factory/figure_factory/_candlestick.py new file mode 100644 index 00000000000..925b4c1a62b --- /dev/null +++ b/plotly/figure_factory/figure_factory/_candlestick.py @@ -0,0 +1,294 @@ +from __future__ import absolute_import + +from plotly.figure_factory import utils +from plotly.figure_factory._ohlc import (_DEFAULT_INCREASING_COLOR, + _DEFAULT_DECREASING_COLOR, + validate_ohlc) +from plotly.graph_objs import graph_objs + + +def make_increasing_candle(open, high, low, close, dates, **kwargs): + """ + Makes boxplot trace for increasing candlesticks + + _make_increasing_candle() and _make_decreasing_candle separate the + increasing traces from the decreasing traces so kwargs (such as + color) can be passed separately to increasing or decreasing traces + when direction is set to 'increasing' or 'decreasing' in + FigureFactory.create_candlestick() + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_incr_data: list of the box trace for + increasing candlesticks. + """ + increase_x, increase_y = _Candlestick( + open, high, low, close, dates, **kwargs).get_candle_increase() + + if 'line' in kwargs: + kwargs.setdefault('fillcolor', kwargs['line']['color']) + else: + kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR) + if 'name' in kwargs: + kwargs.setdefault('showlegend', True) + else: + kwargs.setdefault('showlegend', False) + kwargs.setdefault('name', 'Increasing') + kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR)) + + candle_incr_data = dict(type='box', + x=increase_x, + y=increase_y, + whiskerwidth=0, + boxpoints=False, + **kwargs) + + return [candle_incr_data] + + +def make_decreasing_candle(open, high, low, close, dates, **kwargs): + """ + Makes boxplot trace for decreasing candlesticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to decreasing trace via + plotly.graph_objs.Scatter. + + :rtype (list) candle_decr_data: list of the box trace for + decreasing candlesticks. + """ + + decrease_x, decrease_y = _Candlestick( + open, high, low, close, dates, **kwargs).get_candle_decrease() + + if 'line' in kwargs: + kwargs.setdefault('fillcolor', kwargs['line']['color']) + else: + kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR) + kwargs.setdefault('showlegend', False) + kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR)) + kwargs.setdefault('name', 'Decreasing') + + candle_decr_data = dict(type='box', + x=decrease_x, + y=decrease_y, + whiskerwidth=0, + boxpoints=False, + **kwargs) + + return [candle_decr_data] + + +def create_candlestick(open, high, low, close, dates=None, direction='both', + **kwargs): + """ + BETA function that creates a candlestick chart + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all candlesticks where the close value is greater than + the corresponding open value, and when the direction is + 'decreasing', the returned figure consists of all candlesticks + where the close value is less than or equal to the corresponding + open value. When the direction is 'both', both increasing and + decreasing candlesticks are returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of candlestick chart figure. + + Example 1: Simple candlestick chart from a Pandas DataFrame + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1)) + fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) + py.plot(fig, filename='finance/aapl-candlestick', validate=False) + ``` + + Example 2: Add text and annotations to the candlestick chart + ``` + fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index) + # Update the fig - all options here: https://plot.ly/python/reference/#Layout + fig['layout'].update({ + 'title': 'The Great Recession', + 'yaxis': {'title': 'AAPL Stock'}, + 'shapes': [{ + 'x0': '2007-12-01', 'x1': '2007-12-01', + 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', + 'line': {'color': 'rgb(30,30,30)', 'width': 1} + }], + 'annotations': [{ + 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper', + 'showarrow': False, 'xanchor': 'left', + 'text': 'Official start of the recession' + }] + }) + py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False) + ``` + + Example 3: Customize the candlestick colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + from plotly.graph_objs import Line, Marker + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1)) + + # Make increasing candlesticks and customize their color and name + fig_increasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, + direction='increasing', name='AAPL', + marker=Marker(color='rgb(150, 200, 250)'), + line=Line(color='rgb(150, 200, 250)')) + + # Make decreasing candlesticks and customize their color and name + fig_decreasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index, + direction='decreasing', + marker=Marker(color='rgb(128, 128, 128)'), + line=Line(color='rgb(128, 128, 128)')) + + # Initialize the figure + fig = fig_increasing + + # Add decreasing data with .extend() + fig['data'].extend(fig_decreasing['data']) + + py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False) + ``` + + Example 4: Candlestick chart with datetime objects + ``` + import plotly.plotly as py + from plotly.figure_factory import create_candlestick + + from datetime import datetime + + # Add data + open_data = [33.0, 33.3, 33.5, 33.0, 34.1] + high_data = [33.1, 33.3, 33.6, 33.2, 34.8] + low_data = [32.7, 32.7, 32.8, 32.6, 32.8] + close_data = [33.0, 32.9, 33.3, 33.1, 33.1] + dates = [datetime(year=2013, month=10, day=10), + datetime(year=2013, month=11, day=10), + datetime(year=2013, month=12, day=10), + datetime(year=2014, month=1, day=10), + datetime(year=2014, month=2, day=10)] + + # Create ohlc + fig = create_candlestick(open_data, high_data, + low_data, close_data, dates=dates) + + py.iplot(fig, filename='finance/simple-candlestick', validate=False) + ``` + """ + if dates is not None: + utils.validate_equal_length(open, high, low, close, dates) + else: + utils.validate_equal_length(open, high, low, close) + validate_ohlc(open, high, low, close, direction, **kwargs) + + if direction is 'increasing': + candle_incr_data = make_increasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_incr_data + elif direction is 'decreasing': + candle_decr_data = make_decreasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_decr_data + else: + candle_incr_data = make_increasing_candle(open, high, low, close, + dates, **kwargs) + candle_decr_data = make_decreasing_candle(open, high, low, close, + dates, **kwargs) + data = candle_incr_data + candle_decr_data + + layout = graph_objs.Layout() + return graph_objs.Figure(data=data, layout=layout) + + +class _Candlestick(object): + """ + Refer to FigureFactory.create_candlestick() for docstring. + """ + def __init__(self, open, high, low, close, dates, **kwargs): + self.open = open + self.high = high + self.low = low + self.close = close + if dates is not None: + self.x = dates + else: + self.x = [x for x in range(len(self.open))] + self.get_candle_increase() + + def get_candle_increase(self): + """ + Separate increasing data from decreasing data. + + The data is increasing when close value > open value + and decreasing when the close value <= open value. + """ + increase_y = [] + increase_x = [] + for index in range(len(self.open)): + if self.close[index] > self.open[index]: + increase_y.append(self.low[index]) + increase_y.append(self.open[index]) + increase_y.append(self.close[index]) + increase_y.append(self.close[index]) + increase_y.append(self.close[index]) + increase_y.append(self.high[index]) + increase_x.append(self.x[index]) + + increase_x = [[x, x, x, x, x, x] for x in increase_x] + increase_x = utils.flatten(increase_x) + + return increase_x, increase_y + + def get_candle_decrease(self): + """ + Separate increasing data from decreasing data. + + The data is increasing when close value > open value + and decreasing when the close value <= open value. + """ + decrease_y = [] + decrease_x = [] + for index in range(len(self.open)): + if self.close[index] <= self.open[index]: + decrease_y.append(self.low[index]) + decrease_y.append(self.open[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.close[index]) + decrease_y.append(self.high[index]) + decrease_x.append(self.x[index]) + + decrease_x = [[x, x, x, x, x, x] for x in decrease_x] + decrease_x = utils.flatten(decrease_x) + + return decrease_x, decrease_y diff --git a/plotly/figure_factory/figure_factory/_county_choropleth.py b/plotly/figure_factory/figure_factory/_county_choropleth.py new file mode 100644 index 00000000000..d2688439d86 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_county_choropleth.py @@ -0,0 +1,940 @@ +from plotly import exceptions, optional_imports + +from plotly.figure_factory import utils + +import io +import numpy as np +import os +import pandas as pd +import warnings + +from math import log, floor +from numbers import Number + +pd.options.mode.chained_assignment = None + +shapely = optional_imports.get_module('shapely') +shapefile = optional_imports.get_module('shapefile') +gp = optional_imports.get_module('geopandas') + + +def _create_us_counties_df(st_to_state_name_dict, state_to_st_dict): + # URLS + abs_file_path = os.path.realpath(__file__) + abs_dir_path = os.path.dirname(abs_file_path) + + abs_plotly_dir_path = os.path.dirname(abs_dir_path) + + abs_package_data_dir_path = os.path.join(abs_plotly_dir_path, + 'package_data') + + shape_pre2010 = 'gz_2010_us_050_00_500k.shp' + shape_pre2010 = os.path.join(abs_package_data_dir_path, shape_pre2010) + + df_shape_pre2010 = gp.read_file(shape_pre2010) + df_shape_pre2010['FIPS'] = (df_shape_pre2010['STATE'] + + df_shape_pre2010['COUNTY']) + df_shape_pre2010['FIPS'] = pd.to_numeric(df_shape_pre2010['FIPS']) + + states_path = 'cb_2016_us_state_500k.shp' + states_path = os.path.join(abs_package_data_dir_path, states_path) + + # state df + df_state = gp.read_file(states_path) + df_state = df_state[['STATEFP', 'NAME', 'geometry']] + df_state = df_state.rename(columns={'NAME': 'STATE_NAME'}) + + filenames = ['cb_2016_us_county_500k.dbf', + 'cb_2016_us_county_500k.shp', + 'cb_2016_us_county_500k.shx'] + + for j in range(len(filenames)): + filenames[j] = os.path.join(abs_package_data_dir_path, filenames[j]) + + dbf = io.open(filenames[0], 'rb') + shp = io.open(filenames[1], 'rb') + shx = io.open(filenames[2], 'rb') + + r = shapefile.Reader(shp=shp, shx=shx, dbf=dbf) + + attributes, geometry = [], [] + field_names = [field[0] for field in r.fields[1:]] + for row in r.shapeRecords(): + geometry.append(shapely.geometry.shape(row.shape.__geo_interface__)) + attributes.append(dict(zip(field_names, row.record))) + + gdf = gp.GeoDataFrame(data=attributes, geometry=geometry) + + gdf['FIPS'] = gdf['STATEFP'] + gdf['COUNTYFP'] + gdf['FIPS'] = pd.to_numeric(gdf['FIPS']) + + # add missing counties + f = 46113 + singlerow = pd.DataFrame( + [ + [st_to_state_name_dict['SD'], 'SD', + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0], + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0], + '46', 'Shannon'] + ], + columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'], + index=[max(gdf.index) + 1] + ) + gdf = gdf.append(singlerow) + + f = 51515 + singlerow = pd.DataFrame( + [ + [st_to_state_name_dict['VA'], 'VA', + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0], + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0], + '51', 'Bedford City'] + ], + columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'], + index=[max(gdf.index) + 1] + ) + gdf = gdf.append(singlerow) + + f = 2270 + singlerow = pd.DataFrame( + [ + [st_to_state_name_dict['AK'], 'AK', + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0], + df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0], + '02', 'Wade Hampton'] + ], + columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'], + index=[max(gdf.index) + 1] + ) + gdf = gdf.append(singlerow) + + row_2198 = gdf[gdf['FIPS'] == 2198] + row_2198.index = [max(gdf.index) + 1] + row_2198.loc[row_2198.index[0], 'FIPS'] = 2201 + row_2198.loc[row_2198.index[0], 'STATEFP'] = '02' + gdf = gdf.append(row_2198) + + row_2105 = gdf[gdf['FIPS'] == 2105] + row_2105.index = [max(gdf.index) + 1] + row_2105.loc[row_2105.index[0], 'FIPS'] = 2232 + row_2105.loc[row_2105.index[0], 'STATEFP'] = '02' + gdf = gdf.append(row_2105) + gdf = gdf.rename(columns={'NAME': 'COUNTY_NAME'}) + + gdf_reduced = gdf[['FIPS', 'STATEFP', 'COUNTY_NAME', 'geometry']] + gdf_statefp = gdf_reduced.merge(df_state[['STATEFP', 'STATE_NAME']], + on='STATEFP') + + ST = [] + for n in gdf_statefp['STATE_NAME']: + ST.append(state_to_st_dict[n]) + + gdf_statefp['ST'] = ST + return gdf_statefp, df_state + + +st_to_state_name_dict = { + 'AK': 'Alaska', + 'AL': 'Alabama', + 'AR': 'Arkansas', + 'AZ': 'Arizona', + 'CA': 'California', + 'CO': 'Colorado', + 'CT': 'Connecticut', + 'DC': 'District of Columbia', + 'DE': 'Delaware', + 'FL': 'Florida', + 'GA': 'Georgia', + 'HI': 'Hawaii', + 'IA': 'Iowa', + 'ID': 'Idaho', + 'IL': 'Illinois', + 'IN': 'Indiana', + 'KS': 'Kansas', + 'KY': 'Kentucky', + 'LA': 'Louisiana', + 'MA': 'Massachusetts', + 'MD': 'Maryland', + 'ME': 'Maine', + 'MI': 'Michigan', + 'MN': 'Minnesota', + 'MO': 'Missouri', + 'MS': 'Mississippi', + 'MT': 'Montana', + 'NC': 'North Carolina', + 'ND': 'North Dakota', + 'NE': 'Nebraska', + 'NH': 'New Hampshire', + 'NJ': 'New Jersey', + 'NM': 'New Mexico', + 'NV': 'Nevada', + 'NY': 'New York', + 'OH': 'Ohio', + 'OK': 'Oklahoma', + 'OR': 'Oregon', + 'PA': 'Pennsylvania', + 'RI': 'Rhode Island', + 'SC': 'South Carolina', + 'SD': 'South Dakota', + 'TN': 'Tennessee', + 'TX': 'Texas', + 'UT': 'Utah', + 'VA': 'Virginia', + 'VT': 'Vermont', + 'WA': 'Washington', + 'WI': 'Wisconsin', + 'WV': 'West Virginia', + 'WY': 'Wyoming' +} + +state_to_st_dict = { + 'Alabama': 'AL', + 'Alaska': 'AK', + 'American Samoa': 'AS', + 'Arizona': 'AZ', + 'Arkansas': 'AR', + 'California': 'CA', + 'Colorado': 'CO', + 'Commonwealth of the Northern Mariana Islands': 'MP', + 'Connecticut': 'CT', + 'Delaware': 'DE', + 'District of Columbia': 'DC', + 'Florida': 'FL', + 'Georgia': 'GA', + 'Guam': 'GU', + 'Hawaii': 'HI', + 'Idaho': 'ID', + 'Illinois': 'IL', + 'Indiana': 'IN', + 'Iowa': 'IA', + 'Kansas': 'KS', + 'Kentucky': 'KY', + 'Louisiana': 'LA', + 'Maine': 'ME', + 'Maryland': 'MD', + 'Massachusetts': 'MA', + 'Michigan': 'MI', + 'Minnesota': 'MN', + 'Mississippi': 'MS', + 'Missouri': 'MO', + 'Montana': 'MT', + 'Nebraska': 'NE', + 'Nevada': 'NV', + 'New Hampshire': 'NH', + 'New Jersey': 'NJ', + 'New Mexico': 'NM', + 'New York': 'NY', + 'North Carolina': 'NC', + 'North Dakota': 'ND', + 'Ohio': 'OH', + 'Oklahoma': 'OK', + 'Oregon': 'OR', + 'Pennsylvania': 'PA', + 'Puerto Rico': '', + 'Rhode Island': 'RI', + 'South Carolina': 'SC', + 'South Dakota': 'SD', + 'Tennessee': 'TN', + 'Texas': 'TX', + 'United States Virgin Islands': 'VI', + 'Utah': 'UT', + 'Vermont': 'VT', + 'Virginia': 'VA', + 'Washington': 'WA', + 'West Virginia': 'WV', + 'Wisconsin': 'WI', + 'Wyoming': 'WY' +} + +USA_XRANGE = [-125.0, -65.0] +USA_YRANGE = [25.0, 49.0] + + +def _human_format(number): + units = ['', 'K', 'M', 'G', 'T', 'P'] + k = 1000.0 + magnitude = int(floor(log(number, k))) + return '%.2f%s' % (number / k**magnitude, units[magnitude]) + + +def _intervals_as_labels(array_of_intervals, round_legend_values, exponent_format): + """ + Transform an number interval to a clean string for legend + + Example: [-inf, 30] to '< 30' + """ + infs = [float('-inf'), float('inf')] + string_intervals = [] + for interval in array_of_intervals: + # round to 2nd decimal place + if round_legend_values: + rnd_interval = [ + (int(interval[i]) if interval[i] not in infs else + interval[i]) + for i in range(2) + ] + else: + rnd_interval = [round(interval[0], 2), + round(interval[1], 2)] + + num0 = rnd_interval[0] + num1 = rnd_interval[1] + if exponent_format: + if num0 not in infs: + num0 = _human_format(num0) + if num1 not in infs: + num1 = _human_format(num1) + else: + if num0 not in infs: + num0 = "{:,}".format(num0) + if num1 not in infs: + num1 = "{:,}".format(num1) + + if num0 == float('-inf'): + as_str = '< {}'.format(num1) + elif num1 == float('inf'): + as_str = '> {}'.format(num0) + else: + as_str = '{} - {}'.format(num0, num1) + string_intervals.append(as_str) + return string_intervals + + +def _calculations(df, fips, values, index, f, simplify_county, level, + x_centroids, y_centroids, centroid_text, x_traces, + y_traces, fips_polygon_map): + # 0-pad FIPS code to ensure exactly 5 digits + padded_f = str(f).zfill(5) + if fips_polygon_map[f].type == 'Polygon': + x = fips_polygon_map[f].simplify( + simplify_county + ).exterior.xy[0].tolist() + y = fips_polygon_map[f].simplify( + simplify_county + ).exterior.xy[1].tolist() + + x_c, y_c = fips_polygon_map[f].centroid.xy + county_name_str = str(df[df['FIPS'] == f]['COUNTY_NAME'].iloc[0]) + state_name_str = str(df[df['FIPS'] == f]['STATE_NAME'].iloc[0]) + + t_c = ( + 'County: ' + county_name_str + '
' + + 'State: ' + state_name_str + '
' + + 'FIPS: ' + padded_f + '
Value: ' + str(values[index]) + ) + + x_centroids.append(x_c[0]) + y_centroids.append(y_c[0]) + centroid_text.append(t_c) + + x_traces[level] = x_traces[level] + x + [np.nan] + y_traces[level] = y_traces[level] + y + [np.nan] + elif fips_polygon_map[f].type == 'MultiPolygon': + x = ([poly.simplify(simplify_county).exterior.xy[0].tolist() for + poly in fips_polygon_map[f]]) + y = ([poly.simplify(simplify_county).exterior.xy[1].tolist() for + poly in fips_polygon_map[f]]) + + x_c = [poly.centroid.xy[0].tolist() for poly in fips_polygon_map[f]] + y_c = [poly.centroid.xy[1].tolist() for poly in fips_polygon_map[f]] + + county_name_str = str(df[df['FIPS'] == f]['COUNTY_NAME'].iloc[0]) + state_name_str = str(df[df['FIPS'] == f]['STATE_NAME'].iloc[0]) + text = ( + 'County: ' + county_name_str + '
' + + 'State: ' + state_name_str + '
' + + 'FIPS: ' + padded_f + '
Value: ' + str(values[index]) + ) + t_c = [text for poly in fips_polygon_map[f]] + x_centroids = x_c + x_centroids + y_centroids = y_c + y_centroids + centroid_text = t_c + centroid_text + for x_y_idx in range(len(x)): + x_traces[level] = x_traces[level] + x[x_y_idx] + [np.nan] + y_traces[level] = y_traces[level] + y[x_y_idx] + [np.nan] + + return x_traces, y_traces, x_centroids, y_centroids, centroid_text + + +def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None, + colorscale=None, order=None, simplify_county=0.02, + simplify_state=0.02, asp=None, show_hover=True, + show_state_data=True, state_outline=None, + county_outline=None, centroid_marker=None, + round_legend_values=False, exponent_format=False, + legend_title='', **layout_options): + """ + Returns figure for county choropleth. Uses data from package_data. + + :param (list) fips: list of FIPS values which correspond to the con + catination of state and county ids. An example is '01001'. + :param (list) values: list of numbers/strings which correspond to the + fips list. These are the values that will determine how the counties + are colored. + :param (list) scope: list of states and/or states abbreviations. Fits + all states in the camera tightly. Selecting ['usa'] is the equivalent + of appending all 50 states into your scope list. Selecting only 'usa' + does not include 'Alaska', 'Puerto Rico', 'American Samoa', + 'Commonwealth of the Northern Mariana Islands', 'Guam', + 'United States Virgin Islands'. These must be added manually to the + list. + Default = ['usa'] + :param (list) binning_endpoints: ascending numbers which implicitly define + real number intervals which are used as bins. The colorscale used must + have the same number of colors as the number of bins and this will + result in a categorical colormap. + :param (list) colorscale: a list of colors with length equal to the + number of categories of colors. The length must match either all + unique numbers in the 'values' list or if endpoints is being used, the + number of categories created by the endpoints.\n + For example, if binning_endpoints = [4, 6, 8], then there are 4 bins: + [-inf, 4), [4, 6), [6, 8), [8, inf) + :param (list) order: a list of the unique categories (numbers/bins) in any + desired order. This is helpful if you want to order string values to + a chosen colorscale. + :param (float) simplify_county: determines the simplification factor + for the counties. The larger the number, the fewer vertices and edges + each polygon has. See + http://toblerity.org/shapely/manual.html#object.simplify for more + information. + Default = 0.02 + :param (float) simplify_state: simplifies the state outline polygon. + See http://toblerity.org/shapely/manual.html#object.simplify for more + information. + Default = 0.02 + :param (float) asp: the width-to-height aspect ratio for the camera. + Default = 2.5 + :param (bool) show_hover: show county hover and centroid info + :param (bool) show_state_data: reveals state boundary lines + :param (dict) state_outline: dict of attributes of the state outline + including width and color. See + https://plot.ly/python/reference/#scatter-marker-line for all valid + params + :param (dict) county_outline: dict of attributes of the county outline + including width and color. See + https://plot.ly/python/reference/#scatter-marker-line for all valid + params + :param (dict) centroid_marker: dict of attributes of the centroid marker. + The centroid markers are invisible by default and appear visible on + selection. See https://plot.ly/python/reference/#scatter-marker for + all valid params + :param (bool) round_legend_values: automatically round the numbers that + appear in the legend to the nearest integer. + Default = False + :param (bool) exponent_format: if set to True, puts numbers in the K, M, + B number format. For example 4000.0 becomes 4.0K + Default = False + :param (str) legend_title: title that appears above the legend + :param **layout_options: a **kwargs argument for all layout parameters + + + Example 1: Florida + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import numpy as np + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'] == 'Florida'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + + binning_endpoints = list(np.mgrid[min(values):max(values):4j]) + colorscale = ["#030512","#1d1d3b","#323268","#3d4b94","#3e6ab0", + "#4989bc","#60a7c7","#85c5d3","#b7e0e4","#eafcfd"] + fig = ff.create_choropleth( + fips=fips, values=values, scope=['Florida'], show_state_data=True, + colorscale=colorscale, binning_endpoints=binning_endpoints, + round_legend_values=True, plot_bgcolor='rgb(229,229,229)', + paper_bgcolor='rgb(229,229,229)', legend_title='Florida Population', + county_outline={'color': 'rgb(255,255,255)', 'width': 0.5}, + exponent_format=True, + ) + py.iplot(fig, filename='choropleth_florida') + ``` + + Example 2: New England + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + NE_states = ['Connecticut', 'Maine', 'Massachusetts', + 'New Hampshire', 'Rhode Island'] + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'].isin(NE_states)] + colorscale = ['rgb(68.0, 1.0, 84.0)', + 'rgb(66.0, 64.0, 134.0)', + 'rgb(38.0, 130.0, 142.0)', + 'rgb(63.0, 188.0, 115.0)', + 'rgb(216.0, 226.0, 25.0)'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + fig = ff.create_choropleth( + fips=fips, values=values, scope=NE_states, show_state_data=True + ) + py.iplot(fig, filename='choropleth_new_england') + ``` + + Example 3: California and Surrounding States + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv' + ) + df_sample_r = df_sample[df_sample['STNAME'] == 'California'] + + values = df_sample_r['TOT_POP'].tolist() + fips = df_sample_r['FIPS'].tolist() + + colorscale = [ + 'rgb(193, 193, 193)', + 'rgb(239,239,239)', + 'rgb(195, 196, 222)', + 'rgb(144,148,194)', + 'rgb(101,104,168)', + 'rgb(65, 53, 132)' + ] + + fig = ff.create_choropleth( + fips=fips, values=values, colorscale=colorscale, + scope=['CA', 'AZ', 'Nevada', 'Oregon', ' Idaho'], + binning_endpoints=[14348, 63983, 134827, 426762, 2081313], + county_outline={'color': 'rgb(255,255,255)', 'width': 0.5}, + legend_title='California Counties', + title='California and Nearby States' + ) + py.iplot(fig, filename='choropleth_california_and_surr_states_outlines') + ``` + + Example 4: USA + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import numpy as np + import pandas as pd + + df_sample = pd.read_csv( + 'https://raw.githubusercontent.com/plotly/datasets/master/laucnty16.csv' + ) + df_sample['State FIPS Code'] = df_sample['State FIPS Code'].apply( + lambda x: str(x).zfill(2) + ) + df_sample['County FIPS Code'] = df_sample['County FIPS Code'].apply( + lambda x: str(x).zfill(3) + ) + df_sample['FIPS'] = ( + df_sample['State FIPS Code'] + df_sample['County FIPS Code'] + ) + + binning_endpoints = list(np.linspace(1, 12, len(colorscale) - 1)) + colorscale = ["#f7fbff", "#ebf3fb", "#deebf7", "#d2e3f3", "#c6dbef", + "#b3d2e9", "#9ecae1", "#85bcdb", "#6baed6", "#57a0ce", + "#4292c6", "#3082be", "#2171b5", "#1361a9", "#08519c", + "#0b4083","#08306b"] + fips = df_sample['FIPS'] + values = df_sample['Unemployment Rate (%)'] + fig = ff.create_choropleth( + fips=fips, values=values, scope=['usa'], + binning_endpoints=binning_endpoints, colorscale=colorscale, + show_hover=True, centroid_marker={'opacity': 0}, + asp=2.9, title='USA by Unemployment %', + legend_title='Unemployment %' + ) + + py.iplot(fig, filename='choropleth_full_usa') + ``` + """ + # ensure optional modules imported + if not gp or not shapefile or not shapely: + raise ImportError( + "geopandas, pyshp and shapely must be installed for this figure " + "factory.\n\nRun the following commands to install the correct " + "versions of the following modules:\n\n" + "```\n" + "pip install geopandas==0.3.0\n" + "pip install pyshp==1.2.10\n" + "pip install shapely==1.6.3\n" + "```\n" + "If you are using Windows, follow this post to properly " + "install geopandas and dependencies:" + "http://geoffboeing.com/2014/09/using-geopandas-windows/\n\n" + "If you are using Anaconda, do not use PIP to install the " + "packages above. Instead use conda to install them:\n\n" + "```\n" + "conda install plotly\n" + "conda install geopandas\n" + "```" + ) + + df, df_state = _create_us_counties_df(st_to_state_name_dict, + state_to_st_dict) + + fips_polygon_map = dict( + zip( + df['FIPS'].tolist(), + df['geometry'].tolist() + ) + ) + + if not state_outline: + state_outline = {'color': 'rgb(240, 240, 240)', + 'width': 1} + if not county_outline: + county_outline = {'color': 'rgb(0, 0, 0)', + 'width': 0} + if not centroid_marker: + centroid_marker = {'size': 3, 'color': 'white', 'opacity': 1} + + # ensure centroid markers appear on selection + if 'opacity' not in centroid_marker: + centroid_marker.update({'opacity': 1}) + + if len(fips) != len(values): + raise exceptions.PlotlyError( + 'fips and values must be the same length' + ) + + # make fips, values into lists + if isinstance(fips, pd.core.series.Series): + fips = fips.tolist() + if isinstance(values, pd.core.series.Series): + values = values.tolist() + + # make fips numeric + fips = map(lambda x: int(x), fips) + + if binning_endpoints: + intervals = utils.endpts_to_intervals(binning_endpoints) + LEVELS = _intervals_as_labels(intervals, round_legend_values, + exponent_format) + else: + if not order: + LEVELS = sorted(list(set(values))) + else: + # check if order is permutation + # of unique color col values + same_sets = sorted(list(set(values))) == set(order) + no_duplicates = not any(order.count(x) > 1 for x in order) + if same_sets and no_duplicates: + LEVELS = order + else: + raise exceptions.PlotlyError( + 'if you are using a custom order of unique values from ' + 'your color column, you must: have all the unique values ' + 'in your order and have no duplicate items' + ) + + if not colorscale: + colorscale = [] + viridis_colors = utils.colorscale_to_colors( + utils.PLOTLY_SCALES['Viridis'] + ) + viridis_colors = utils.color_parser( + viridis_colors, utils.hex_to_rgb + ) + viridis_colors = utils.color_parser( + viridis_colors, utils.label_rgb + ) + viri_len = len(viridis_colors) + 1 + viri_intervals = utils.endpts_to_intervals( + list(np.linspace(0, 1, viri_len)) + )[1:-1] + + for L in np.linspace(0, 1, len(LEVELS)): + for idx, inter in enumerate(viri_intervals): + if L == 0: + break + elif inter[0] < L <= inter[1]: + break + + intermed = ((L - viri_intervals[idx][0]) / + (viri_intervals[idx][1] - viri_intervals[idx][0])) + + float_color = utils.find_intermediate_color( + viridis_colors[idx], + viridis_colors[idx], + intermed, + colortype='rgb' + ) + + # make R,G,B into int values + float_color = utils.unlabel_rgb(float_color) + float_color = utils.unconvert_from_RGB_255(float_color) + int_rgb = utils.convert_to_RGB_255(float_color) + int_rgb = utils.label_rgb(int_rgb) + + colorscale.append(int_rgb) + + if len(colorscale) < len(LEVELS): + raise exceptions.PlotlyError( + "You have {} LEVELS. Your number of colors in 'colorscale' must " + "be at least the number of LEVELS: {}. If you are " + "using 'binning_endpoints' then 'colorscale' must have at " + "least len(binning_endpoints) + 2 colors".format( + len(LEVELS), min(LEVELS, LEVELS[:20]) + ) + ) + + color_lookup = dict(zip(LEVELS, colorscale)) + x_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))])) + y_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))])) + + # scope + if isinstance(scope, str): + raise exceptions.PlotlyError( + "'scope' must be a list/tuple/sequence" + ) + + scope_names = [] + extra_states = ['Alaska', 'Commonwealth of the Northern Mariana Islands', + 'Puerto Rico', 'Guam', 'United States Virgin Islands', + 'American Samoa'] + for state in scope: + if state.lower() == 'usa': + scope_names = df['STATE_NAME'].unique() + scope_names = list(scope_names) + for ex_st in extra_states: + try: + scope_names.remove(ex_st) + except ValueError: + pass + else: + if state in st_to_state_name_dict.keys(): + state = st_to_state_name_dict[state] + scope_names.append(state) + df_state = df_state[df_state['STATE_NAME'].isin(scope_names)] + + plot_data = [] + x_centroids = [] + y_centroids = [] + centroid_text = [] + fips_not_in_shapefile = [] + if not binning_endpoints: + for index, f in enumerate(fips): + level = values[index] + try: + fips_polygon_map[f].type + + (x_traces, y_traces, x_centroids, + y_centroids, centroid_text) = _calculations( + df, fips, values, index, f, simplify_county, level, + x_centroids, y_centroids, centroid_text, x_traces, + y_traces, fips_polygon_map + ) + except KeyError: + fips_not_in_shapefile.append(f) + + else: + for index, f in enumerate(fips): + for j, inter in enumerate(intervals): + if inter[0] < values[index] <= inter[1]: + break + level = LEVELS[j] + + try: + fips_polygon_map[f].type + + (x_traces, y_traces, x_centroids, + y_centroids, centroid_text) = _calculations( + df, fips, values, index, f, simplify_county, level, + x_centroids, y_centroids, centroid_text, x_traces, + y_traces, fips_polygon_map + ) + except KeyError: + fips_not_in_shapefile.append(f) + + if len(fips_not_in_shapefile) > 0: + msg = ( + 'Unrecognized FIPS Values\n\nWhoops! It looks like you are ' + 'trying to pass at least one FIPS value that is not in ' + 'our shapefile of FIPS and data for the counties. Your ' + 'choropleth will still show up but these counties cannot ' + 'be shown.\nUnrecognized FIPS are: {}'.format( + fips_not_in_shapefile + ) + ) + warnings.warn(msg) + + x_states = [] + y_states = [] + for index, row in df_state.iterrows(): + if df_state['geometry'][index].type == 'Polygon': + x = row.geometry.simplify(simplify_state).exterior.xy[0].tolist() + y = row.geometry.simplify(simplify_state).exterior.xy[1].tolist() + x_states = x_states + x + y_states = y_states + y + elif df_state['geometry'][index].type == 'MultiPolygon': + x = ([poly.simplify(simplify_state).exterior.xy[0].tolist() for + poly in df_state['geometry'][index]]) + y = ([poly.simplify(simplify_state).exterior.xy[1].tolist() for + poly in df_state['geometry'][index]]) + for segment in range(len(x)): + x_states = x_states + x[segment] + y_states = y_states + y[segment] + x_states.append(np.nan) + y_states.append(np.nan) + x_states.append(np.nan) + y_states.append(np.nan) + + for lev in LEVELS: + county_data = dict( + type='scatter', + mode='lines', + x=x_traces[lev], + y=y_traces[lev], + line=county_outline, + fill='toself', + fillcolor=color_lookup[lev], + name=lev, + hoverinfo='none', + ) + plot_data.append(county_data) + + if show_hover: + hover_points = dict( + type='scatter', + showlegend=False, + legendgroup='centroids', + x=x_centroids, + y=y_centroids, + text=centroid_text, + name='US Counties', + mode='markers', + marker={'color': 'white', 'opacity': 0}, + hoverinfo='text' + ) + centroids_on_select = dict( + selected=dict(marker=centroid_marker), + unselected=dict(marker=dict(opacity=0)) + ) + hover_points.update(centroids_on_select) + plot_data.append(hover_points) + + if show_state_data: + state_data = dict( + type='scatter', + legendgroup='States', + line=state_outline, + x=x_states, + y=y_states, + hoverinfo='text', + showlegend=False, + mode='lines' + ) + plot_data.append(state_data) + + DEFAULT_LAYOUT = dict( + hovermode='closest', + xaxis=dict( + autorange=False, + range=USA_XRANGE, + showgrid=False, + zeroline=False, + fixedrange=True, + showticklabels=False + ), + yaxis=dict( + autorange=False, + range=USA_YRANGE, + showgrid=False, + zeroline=False, + fixedrange=True, + showticklabels=False + ), + margin=dict(t=40, b=20, r=20, l=20), + width=900, + height=450, + dragmode='select', + legend=dict( + traceorder='reversed', + xanchor='right', + yanchor='top', + x=1, + y=1 + ), + annotations=[] + ) + fig = dict(data=plot_data, layout=DEFAULT_LAYOUT) + fig['layout'].update(layout_options) + fig['layout']['annotations'].append( + dict( + x=1, + y=1.05, + xref='paper', + yref='paper', + xanchor='right', + showarrow=False, + text='' + legend_title + '' + ) + ) + + if len(scope) == 1 and scope[0].lower() == 'usa': + xaxis_range_low = -125.0 + xaxis_range_high = -55.0 + yaxis_range_low = 25.0 + yaxis_range_high = 49.0 + else: + xaxis_range_low = float('inf') + xaxis_range_high = float('-inf') + yaxis_range_low = float('inf') + yaxis_range_high = float('-inf') + for trace in fig['data']: + if all(isinstance(n, Number) for n in trace['x']): + calc_x_min = min(trace['x'] or [float('inf')]) + calc_x_max = max(trace['x'] or [float('-inf')]) + if calc_x_min < xaxis_range_low: + xaxis_range_low = calc_x_min + if calc_x_max > xaxis_range_high: + xaxis_range_high = calc_x_max + if all(isinstance(n, Number) for n in trace['y']): + calc_y_min = min(trace['y'] or [float('inf')]) + calc_y_max = max(trace['y'] or [float('-inf')]) + if calc_y_min < yaxis_range_low: + yaxis_range_low = calc_y_min + if calc_y_max > yaxis_range_high: + yaxis_range_high = calc_y_max + + # camera zoom + fig['layout']['xaxis']['range'] = [xaxis_range_low, xaxis_range_high] + fig['layout']['yaxis']['range'] = [yaxis_range_low, yaxis_range_high] + + # aspect ratio + if asp is None: + usa_x_range = USA_XRANGE[1] - USA_XRANGE[0] + usa_y_range = USA_YRANGE[1] - USA_YRANGE[0] + asp = usa_x_range / usa_y_range + + # based on your figure + width = float(fig['layout']['xaxis']['range'][1] - + fig['layout']['xaxis']['range'][0]) + height = float(fig['layout']['yaxis']['range'][1] - + fig['layout']['yaxis']['range'][0]) + + center = (sum(fig['layout']['xaxis']['range']) / 2., + sum(fig['layout']['yaxis']['range']) / 2.) + + if height / width > (1 / asp): + new_width = asp * height + fig['layout']['xaxis']['range'][0] = center[0] - new_width * 0.5 + fig['layout']['xaxis']['range'][1] = center[0] + new_width * 0.5 + else: + new_height = (1 / asp) * width + fig['layout']['yaxis']['range'][0] = center[1] - new_height * 0.5 + fig['layout']['yaxis']['range'][1] = center[1] + new_height * 0.5 + + return fig diff --git a/plotly/figure_factory/figure_factory/_dendrogram.py b/plotly/figure_factory/figure_factory/_dendrogram.py new file mode 100644 index 00000000000..4bafc976c1d --- /dev/null +++ b/plotly/figure_factory/figure_factory/_dendrogram.py @@ -0,0 +1,330 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import + +from collections import OrderedDict + +from plotly import exceptions, optional_imports +from plotly.graph_objs import graph_objs + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') +scp = optional_imports.get_module('scipy') +sch = optional_imports.get_module('scipy.cluster.hierarchy') +scs = optional_imports.get_module('scipy.spatial') + + +def create_dendrogram(X, orientation="bottom", labels=None, + colorscale=None, distfun=None, + linkagefun=lambda x: sch.linkage(x, 'complete'), + hovertext=None, color_threshold=None): + """ + BETA function that returns a dendrogram Plotly figure object. + + :param (ndarray) X: Matrix of observations as array of arrays + :param (str) orientation: 'top', 'right', 'bottom', or 'left' + :param (list) labels: List of axis category labels(observation labels) + :param (list) colorscale: Optional colorscale for dendrogram tree + :param (function) distfun: Function to compute the pairwise distance from + the observations + :param (function) linkagefun: Function to compute the linkage matrix from + the pairwise distances + :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram + clusters + :param (double) color_threshold: Value at which the separation of clusters will be made + + Example 1: Simple bottom oriented dendrogram + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + + X = np.random.rand(10,10) + dendro = create_dendrogram(X) + plot_url = py.plot(dendro, filename='simple-dendrogram') + + ``` + + Example 2: Dendrogram to put on the left of the heatmap + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + + X = np.random.rand(5,5) + names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark'] + dendro = create_dendrogram(X, orientation='right', labels=names) + dendro['layout'].update({'width':700, 'height':500}) + + py.iplot(dendro, filename='vertical-dendrogram') + ``` + + Example 3: Dendrogram with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_dendrogram + + import numpy as np + import pandas as pd + + Index= ['A','B','C','D','E','F','G','H','I','J'] + df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index) + fig = create_dendrogram(df, labels=Index) + url = py.plot(fig, filename='pandas-dendrogram') + ``` + """ + if not scp or not scs or not sch: + raise ImportError("FigureFactory.create_dendrogram requires scipy, \ + scipy.spatial and scipy.hierarchy") + + s = X.shape + if len(s) != 2: + exceptions.PlotlyError("X should be 2-dimensional array.") + + if distfun is None: + distfun = scs.distance.pdist + + dendrogram = _Dendrogram(X, orientation, labels, colorscale, + distfun=distfun, linkagefun=linkagefun, + hovertext=hovertext, color_threshold=color_threshold) + + return graph_objs.Figure(data=dendrogram.data, + layout=dendrogram.layout) + + +class _Dendrogram(object): + """Refer to FigureFactory.create_dendrogram() for docstring.""" + + def __init__(self, X, orientation='bottom', labels=None, colorscale=None, + width=np.inf, height=np.inf, xaxis='xaxis', yaxis='yaxis', + distfun=None, + linkagefun=lambda x: sch.linkage(x, 'complete'), + hovertext=None, color_threshold=None): + self.orientation = orientation + self.labels = labels + self.xaxis = xaxis + self.yaxis = yaxis + self.data = [] + self.leaves = [] + self.sign = {self.xaxis: 1, self.yaxis: 1} + self.layout = {self.xaxis: {}, self.yaxis: {}} + + if self.orientation in ['left', 'bottom']: + self.sign[self.xaxis] = 1 + else: + self.sign[self.xaxis] = -1 + + if self.orientation in ['right', 'bottom']: + self.sign[self.yaxis] = 1 + else: + self.sign[self.yaxis] = -1 + + if distfun is None: + distfun = scs.distance.pdist + + (dd_traces, xvals, yvals, + ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale, + distfun, + linkagefun, + hovertext, + color_threshold) + + self.labels = ordered_labels + self.leaves = leaves + yvals_flat = yvals.flatten() + xvals_flat = xvals.flatten() + + self.zero_vals = [] + + for i in range(len(yvals_flat)): + if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals: + self.zero_vals.append(xvals_flat[i]) + + if len(self.zero_vals) > len(yvals) + 1: + # If the length of zero_vals is larger than the length of yvals, + # it means that there are wrong vals because of the identicial samples. + # Three and more identicial samples will make the yvals of spliting center into 0 and it will \ + # accidentally take it as leaves. + l_border = int(min(self.zero_vals)) + r_border = int(max(self.zero_vals)) + correct_leaves_pos = range(l_border, + r_border + 1, + int((r_border - l_border) / len(yvals))) + # Regenerating the leaves pos from the self.zero_vals with equally intervals. + self.zero_vals = [v for v in correct_leaves_pos] + + self.zero_vals.sort() + self.layout = self.set_figure_layout(width, height) + self.data = dd_traces + + def get_color_dict(self, colorscale): + """ + Returns colorscale used for dendrogram tree clusters. + + :param (list) colorscale: Colors to use for the plot in rgb format. + :rtype (dict): A dict of default colors mapped to the user colorscale. + + """ + + # These are the color codes returned for dendrograms + # We're replacing them with nicer colors + d = {'r': 'red', + 'g': 'green', + 'b': 'blue', + 'c': 'cyan', + 'm': 'magenta', + 'y': 'yellow', + 'k': 'black', + 'w': 'white'} + default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0])) + + if colorscale is None: + colorscale = [ + 'rgb(0,116,217)', # blue + 'rgb(35,205,205)', # cyan + 'rgb(61,153,112)', # green + 'rgb(40,35,35)', # black + 'rgb(133,20,75)', # magenta + 'rgb(255,65,54)', # red + 'rgb(255,255,255)', # white + 'rgb(255,220,0)'] # yellow + + for i in range(len(default_colors.keys())): + k = list(default_colors.keys())[i] # PY3 won't index keys + if i < len(colorscale): + default_colors[k] = colorscale[i] + + return default_colors + + def set_axis_layout(self, axis_key): + """ + Sets and returns default axis object for dendrogram figure. + + :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc. + :rtype (dict): An axis_key dictionary with set parameters. + + """ + axis_defaults = { + 'type': 'linear', + 'ticks': 'outside', + 'mirror': 'allticks', + 'rangemode': 'tozero', + 'showticklabels': True, + 'zeroline': False, + 'showgrid': False, + 'showline': True, + } + + if len(self.labels) != 0: + axis_key_labels = self.xaxis + if self.orientation in ['left', 'right']: + axis_key_labels = self.yaxis + if axis_key_labels not in self.layout: + self.layout[axis_key_labels] = {} + self.layout[axis_key_labels]['tickvals'] = \ + [zv*self.sign[axis_key] for zv in self.zero_vals] + self.layout[axis_key_labels]['ticktext'] = self.labels + self.layout[axis_key_labels]['tickmode'] = 'array' + + self.layout[axis_key].update(axis_defaults) + + return self.layout[axis_key] + + def set_figure_layout(self, width, height): + """ + Sets and returns default layout object for dendrogram figure. + + """ + self.layout.update({ + 'showlegend': False, + 'autosize': False, + 'hovermode': 'closest', + 'width': width, + 'height': height + }) + + self.set_axis_layout(self.xaxis) + self.set_axis_layout(self.yaxis) + + return self.layout + + def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun, hovertext, color_threshold): + """ + Calculates all the elements needed for plotting a dendrogram. + + :param (ndarray) X: Matrix of observations as array of arrays + :param (list) colorscale: Color scale for dendrogram tree clusters + :param (function) distfun: Function to compute the pairwise distance + from the observations + :param (function) linkagefun: Function to compute the linkage matrix + from the pairwise distances + :param (list) hovertext: List of hovertext for constituent traces of dendrogram + :rtype (tuple): Contains all the traces in the following order: + (a) trace_list: List of Plotly trace objects for dendrogram tree + (b) icoord: All X points of the dendrogram tree as array of arrays + with length 4 + (c) dcoord: All Y points of the dendrogram tree as array of arrays + with length 4 + (d) ordered_labels: leaf labels in the order they are going to + appear on the plot + (e) P['leaves']: left-to-right traversal of the leaves + + """ + d = distfun(X) + Z = linkagefun(d) + P = sch.dendrogram(Z, orientation=self.orientation, + labels=self.labels, no_plot=True, + color_threshold=color_threshold) + + icoord = scp.array(P['icoord']) + dcoord = scp.array(P['dcoord']) + ordered_labels = scp.array(P['ivl']) + color_list = scp.array(P['color_list']) + colors = self.get_color_dict(colorscale) + + trace_list = [] + + for i in range(len(icoord)): + # xs and ys are arrays of 4 points that make up the '∩' shapes + # of the dendrogram tree + if self.orientation in ['top', 'bottom']: + xs = icoord[i] + else: + xs = dcoord[i] + + if self.orientation in ['top', 'bottom']: + ys = dcoord[i] + else: + ys = icoord[i] + color_key = color_list[i] + hovertext_label = None + if hovertext: + hovertext_label = hovertext[i] + trace = dict( + type='scatter', + x=np.multiply(self.sign[self.xaxis], xs), + y=np.multiply(self.sign[self.yaxis], ys), + mode='lines', + marker=dict(color=colors[color_key]), + text=hovertext_label, + hoverinfo='text' + ) + + try: + x_index = int(self.xaxis[-1]) + except ValueError: + x_index = '' + + try: + y_index = int(self.yaxis[-1]) + except ValueError: + y_index = '' + + trace['xaxis'] = 'x' + x_index + trace['yaxis'] = 'y' + y_index + + trace_list.append(trace) + + return trace_list, icoord, dcoord, ordered_labels, P['leaves'] diff --git a/plotly/figure_factory/figure_factory/_distplot.py b/plotly/figure_factory/figure_factory/_distplot.py new file mode 100644 index 00000000000..0d88847eeb4 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_distplot.py @@ -0,0 +1,390 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +# Optional imports, may be None for users that only use our core functionality. +np = optional_imports.get_module('numpy') +pd = optional_imports.get_module('pandas') +scipy = optional_imports.get_module('scipy') +scipy_stats = optional_imports.get_module('scipy.stats') + + +DEFAULT_HISTNORM = 'probability density' +ALTERNATIVE_HISTNORM = 'probability' + + +def validate_distplot(hist_data, curve_type): + """ + Distplot-specific validations + + :raises: (PlotlyError) If hist_data is not a list of lists + :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or + 'normal'). + """ + hist_data_types = (list,) + if np: + hist_data_types += (np.ndarray,) + if pd: + hist_data_types += (pd.core.series.Series,) + + if not isinstance(hist_data[0], hist_data_types): + raise exceptions.PlotlyError("Oops, this function was written " + "to handle multiple datasets, if " + "you want to plot just one, make " + "sure your hist_data variable is " + "still a list of lists, i.e. x = " + "[1, 2, 3] -> x = [[1, 2, 3]]") + + curve_opts = ('kde', 'normal') + if curve_type not in curve_opts: + raise exceptions.PlotlyError("curve_type must be defined as " + "'kde' or 'normal'") + + if not scipy: + raise ImportError("FigureFactory.create_distplot requires scipy") + + +def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde', + colors=None, rug_text=None, histnorm=DEFAULT_HISTNORM, + show_hist=True, show_curve=True, show_rug=True): + """ + BETA function that creates a distplot similar to seaborn.distplot + + The distplot can be composed of all or any combination of the following + 3 components: (1) histogram, (2) curve: (a) kernel density estimation + or (b) normal curve, and (3) rug plot. Additionally, multiple distplots + (from multiple datasets) can be created in the same plot. + + :param (list[list]) hist_data: Use list of lists to plot multiple data + sets on the same plot. + :param (list[str]) group_labels: Names for each data set. + :param (list[float]|float) bin_size: Size of histogram bins. + Default = 1. + :param (str) curve_type: 'kde' or 'normal'. Default = 'kde' + :param (str) histnorm: 'probability density' or 'probability' + Default = 'probability density' + :param (bool) show_hist: Add histogram to distplot? Default = True + :param (bool) show_curve: Add curve to distplot? Default = True + :param (bool) show_rug: Add rug to distplot? Default = True + :param (list[str]) colors: Colors for traces. + :param (list[list]) rug_text: Hovertext values for rug_plot, + :return (dict): Representation of a distplot figure. + + Example 1: Simple distplot of 1 data set + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + + hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5, + 3.5, 4.1, 4.4, 4.5, 4.5, + 5.0, 5.0, 5.2, 5.5, 5.5, + 5.5, 5.5, 5.5, 6.1, 7.0]] + + group_labels = ['distplot example'] + + fig = create_distplot(hist_data, group_labels) + + url = py.plot(fig, filename='Simple distplot', validate=False) + ``` + + Example 2: Two data sets and added rug text + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + + # Add histogram data + hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6, + -0.9, -0.07, 1.95, 0.9, -0.2, + -0.5, 0.3, 0.4, -0.37, 0.6] + hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59, + 1.0, 0.8, 1.7, 0.5, 0.8, + -0.3, 1.2, 0.56, 0.3, 2.2] + + # Group data together + hist_data = [hist1_x, hist2_x] + + group_labels = ['2012', '2013'] + + # Add text + rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1', + 'f1', 'g1', 'h1', 'i1', 'j1', + 'k1', 'l1', 'm1', 'n1', 'o1'] + + rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2', + 'f2', 'g2', 'h2', 'i2', 'j2', + 'k2', 'l2', 'm2', 'n2', 'o2'] + + # Group text together + rug_text_all = [rug_text_1, rug_text_2] + + # Create distplot + fig = create_distplot( + hist_data, group_labels, rug_text=rug_text_all, bin_size=.2) + + # Add title + fig['layout'].update(title='Dist Plot') + + # Plot! + url = py.plot(fig, filename='Distplot with rug text', validate=False) + ``` + + Example 3: Plot with normal curve and hide rug plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + import numpy as np + + x1 = np.random.randn(190) + x2 = np.random.randn(200)+1 + x3 = np.random.randn(200)-1 + x4 = np.random.randn(210)+2 + + hist_data = [x1, x2, x3, x4] + group_labels = ['2012', '2013', '2014', '2015'] + + fig = create_distplot( + hist_data, group_labels, curve_type='normal', + show_rug=False, bin_size=.4) + + url = py.plot(fig, filename='hist and normal curve', validate=False) + + Example 4: Distplot with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_distplot + import numpy as np + import pandas as pd + + df = pd.DataFrame({'2012': np.random.randn(200), + '2013': np.random.randn(200)+1}) + py.iplot(create_distplot([df[c] for c in df.columns], df.columns), + filename='examples/distplot with pandas', + validate=False) + ``` + """ + if colors is None: + colors = [] + if rug_text is None: + rug_text = [] + + validate_distplot(hist_data, curve_type) + utils.validate_equal_length(hist_data, group_labels) + + if isinstance(bin_size, (float, int)): + bin_size = [bin_size] * len(hist_data) + + hist = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_hist() + + if curve_type == 'normal': + curve = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_normal() + else: + curve = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_kde() + + rug = _Distplot( + hist_data, histnorm, group_labels, bin_size, + curve_type, colors, rug_text, + show_hist, show_curve).make_rug() + + data = [] + if show_hist: + data.append(hist) + if show_curve: + data.append(curve) + if show_rug: + data.append(rug) + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0.35, 1], + anchor='free', + position=0.0), + yaxis2=dict(domain=[0, 0.25], + anchor='x1', + dtick=1, + showticklabels=False)) + else: + layout = graph_objs.Layout( + barmode='overlay', + hovermode='closest', + legend=dict(traceorder='reversed'), + xaxis1=dict(domain=[0.0, 1.0], + anchor='y2', + zeroline=False), + yaxis1=dict(domain=[0., 1], + anchor='free', + position=0.0)) + + data = sum(data, []) + return graph_objs.Figure(data=data, layout=layout) + + +class _Distplot(object): + """ + Refer to TraceFactory.create_distplot() for docstring + """ + def __init__(self, hist_data, histnorm, group_labels, + bin_size, curve_type, colors, + rug_text, show_hist, show_curve): + self.hist_data = hist_data + self.histnorm = histnorm + self.group_labels = group_labels + self.bin_size = bin_size + self.show_hist = show_hist + self.show_curve = show_curve + self.trace_number = len(hist_data) + if rug_text: + self.rug_text = rug_text + else: + self.rug_text = [None] * self.trace_number + + self.start = [] + self.end = [] + if colors: + self.colors = colors + else: + self.colors = [ + "rgb(31, 119, 180)", "rgb(255, 127, 14)", + "rgb(44, 160, 44)", "rgb(214, 39, 40)", + "rgb(148, 103, 189)", "rgb(140, 86, 75)", + "rgb(227, 119, 194)", "rgb(127, 127, 127)", + "rgb(188, 189, 34)", "rgb(23, 190, 207)"] + self.curve_x = [None] * self.trace_number + self.curve_y = [None] * self.trace_number + + for trace in self.hist_data: + self.start.append(min(trace) * 1.) + self.end.append(max(trace) * 1.) + + def make_hist(self): + """ + Makes the histogram(s) for FigureFactory.create_distplot(). + + :rtype (list) hist: list of histogram representations + """ + hist = [None] * self.trace_number + + for index in range(self.trace_number): + hist[index] = dict(type='histogram', + x=self.hist_data[index], + xaxis='x1', + yaxis='y1', + histnorm=self.histnorm, + name=self.group_labels[index], + legendgroup=self.group_labels[index], + marker=dict(color=self.colors[index % len(self.colors)]), + autobinx=False, + xbins=dict(start=self.start[index], + end=self.end[index], + size=self.bin_size[index]), + opacity=.7) + return hist + + def make_kde(self): + """ + Makes the kernel density estimation(s) for create_distplot(). + + This is called when curve_type = 'kde' in create_distplot(). + + :rtype (list) curve: list of kde representations + """ + curve = [None] * self.trace_number + for index in range(self.trace_number): + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = (scipy_stats.gaussian_kde + (self.hist_data[index]) + (self.curve_x[index])) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index % len(self.colors)])) + return curve + + def make_normal(self): + """ + Makes the normal curve(s) for create_distplot(). + + This is called when curve_type = 'normal' in create_distplot(). + + :rtype (list) curve: list of normal curve representations + """ + curve = [None] * self.trace_number + mean = [None] * self.trace_number + sd = [None] * self.trace_number + + for index in range(self.trace_number): + mean[index], sd[index] = (scipy_stats.norm.fit + (self.hist_data[index])) + self.curve_x[index] = [self.start[index] + + x * (self.end[index] - self.start[index]) + / 500 for x in range(500)] + self.curve_y[index] = scipy_stats.norm.pdf( + self.curve_x[index], loc=mean[index], scale=sd[index]) + + if self.histnorm == ALTERNATIVE_HISTNORM: + self.curve_y[index] *= self.bin_size[index] + + for index in range(self.trace_number): + curve[index] = dict(type='scatter', + x=self.curve_x[index], + y=self.curve_y[index], + xaxis='x1', + yaxis='y1', + mode='lines', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=False if self.show_hist else True, + marker=dict(color=self.colors[index % len(self.colors)])) + return curve + + def make_rug(self): + """ + Makes the rug plot(s) for create_distplot(). + + :rtype (list) rug: list of rug plot representations + """ + rug = [None] * self.trace_number + for index in range(self.trace_number): + + rug[index] = dict(type='scatter', + x=self.hist_data[index], + y=([self.group_labels[index]] * + len(self.hist_data[index])), + xaxis='x1', + yaxis='y2', + mode='markers', + name=self.group_labels[index], + legendgroup=self.group_labels[index], + showlegend=(False if self.show_hist or + self.show_curve else True), + text=self.rug_text[index], + marker=dict(color=self.colors[index % len(self.colors)], + symbol='line-ns-open')) + return rug diff --git a/plotly/figure_factory/figure_factory/_facet_grid.py b/plotly/figure_factory/figure_factory/_facet_grid.py new file mode 100644 index 00000000000..8ef1825ebf8 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_facet_grid.py @@ -0,0 +1,1111 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.tools import make_subplots + +import math +from numbers import Number + +pd = optional_imports.get_module('pandas') + +TICK_COLOR = '#969696' +AXIS_TITLE_COLOR = '#0f0f0f' +AXIS_TITLE_SIZE = 12 +GRID_COLOR = '#ffffff' +LEGEND_COLOR = '#efefef' +PLOT_BGCOLOR = '#ededed' +ANNOT_RECT_COLOR = '#d0d0d0' +LEGEND_BORDER_WIDTH = 1 +LEGEND_ANNOT_X = 1.05 +LEGEND_ANNOT_Y = 0.5 +MAX_TICKS_PER_AXIS = 5 +THRES_FOR_FLIPPED_FACET_TITLES = 10 +GRID_WIDTH = 1 + +VALID_TRACE_TYPES = ['scatter', 'scattergl', 'histogram', 'bar', 'box'] + +CUSTOM_LABEL_ERROR = ( + "If you are using a dictionary for custom labels for the facet row/col, " + "make sure each key in that column of the dataframe is in your facet " + "labels. The keys you need are {}" +) + + +def _is_flipped(num): + if num >= THRES_FOR_FLIPPED_FACET_TITLES: + flipped = True + else: + flipped = False + return flipped + + +def _return_label(original_label, facet_labels, facet_var): + if isinstance(facet_labels, dict): + label = facet_labels[original_label] + elif isinstance(facet_labels, str): + label = '{}: {}'.format(facet_var, original_label) + else: + label = original_label + return label + + +def _legend_annotation(color_name): + legend_title = dict( + textangle=0, + xanchor='left', + yanchor='middle', + x=LEGEND_ANNOT_X, + y=1.03, + showarrow=False, + xref='paper', + yref='paper', + text='factor({})'.format(color_name), + font=dict( + size=13, + color='#000000' + ) + ) + return legend_title + + +def _annotation_dict(text, lane, num_of_lanes, SUBPLOT_SPACING, row_col='col', + flipped=True): + l = (1 - (num_of_lanes - 1) * SUBPLOT_SPACING) / (num_of_lanes) + if not flipped: + xanchor = 'center' + yanchor = 'middle' + if row_col == 'col': + x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l + y = 1.03 + textangle = 0 + elif row_col == 'row': + y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l + x = 1.03 + textangle = 90 + else: + if row_col == 'col': + xanchor = 'center' + yanchor = 'bottom' + x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l + y = 1.0 + textangle = 270 + elif row_col == 'row': + xanchor = 'left' + yanchor = 'middle' + y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l + x = 1.0 + textangle = 0 + + annotation_dict = dict( + textangle=textangle, + xanchor=xanchor, + yanchor=yanchor, + x=x, + y=y, + showarrow=False, + xref='paper', + yref='paper', + text=str(text), + font=dict( + size=13, + color=AXIS_TITLE_COLOR + ) + ) + return annotation_dict + + +def _axis_title_annotation(text, x_or_y_axis): + if x_or_y_axis == 'x': + x_pos = 0.5 + y_pos = -0.1 + textangle = 0 + elif x_or_y_axis == 'y': + x_pos = -0.1 + y_pos = 0.5 + textangle = 270 + + if not text: + text = '' + + annot = {'font': {'color': '#000000', 'size': AXIS_TITLE_SIZE}, + 'showarrow': False, + 'text': text, + 'textangle': textangle, + 'x': x_pos, + 'xanchor': 'center', + 'xref': 'paper', + 'y': y_pos, + 'yanchor': 'middle', + 'yref': 'paper'} + return annot + + +def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False, + flipped_cols=False): + shapes_list = [] + for key in fig['layout'].to_plotly_json().keys(): + if 'axis' in key and fig['layout'][key]['domain'] != [0.0, 1.0]: + shape = { + 'fillcolor': annot_rect_color, + 'layer': 'below', + 'line': {'color': annot_rect_color, 'width': 1}, + 'type': 'rect', + 'xref': 'paper', + 'yref': 'paper' + } + + if 'xaxis' in key: + shape['x0'] = fig['layout'][key]['domain'][0] + shape['x1'] = fig['layout'][key]['domain'][1] + shape['y0'] = 1.005 + shape['y1'] = 1.05 + + if flipped_cols: + shape['y1'] += 0.5 + shapes_list.append(shape) + + elif 'yaxis' in key: + shape['x0'] = 1.005 + shape['x1'] = 1.05 + shape['y0'] = fig['layout'][key]['domain'][0] + shape['y1'] = fig['layout'][key]['domain'][1] + + if flipped_rows: + shape['x1'] += 1 + shapes_list.append(shape) + + fig['layout']['shapes'] = shapes_list + + +def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker): + if trace_type in ['scatter', 'scattergl']: + trace['mode'] = 'markers' + trace['marker'] = dict(color=color, **kwargs_marker) + return trace + + +def _facet_grid_color_categorical(df, x, y, facet_row, facet_col, color_name, + colormap, num_of_rows, num_of_cols, + facet_row_labels, facet_col_labels, + trace_type, flipped_rows, flipped_cols, + show_boxes, SUBPLOT_SPACING, marker_color, + kwargs_trace, kwargs_marker): + + fig = make_subplots(rows=num_of_rows, cols=num_of_cols, + shared_xaxes=True, shared_yaxes=True, + horizontal_spacing=SUBPLOT_SPACING, + vertical_spacing=SUBPLOT_SPACING, print_grid=False) + + annotations = [] + if not facet_row and not facet_col: + color_groups = list(df.groupby(color_name)) + for group in color_groups: + trace = dict( + type=trace_type, + name=group[0], + marker=dict( + color=colormap[group[0]], + ), + **kwargs_trace + ) + if x: + trace['x'] = group[1][x] + if y: + trace['y'] = group[1][y] + trace = _make_trace_for_scatter( + trace, trace_type, colormap[group[0]], **kwargs_marker + ) + + fig.append_trace(trace, 1, 1) + + elif (facet_row and not facet_col) or (not facet_row and facet_col): + groups_by_facet = list( + df.groupby(facet_row if facet_row else facet_col) + ) + for j, group in enumerate(groups_by_facet): + for color_val in df[color_name].unique(): + data_by_color = group[1][group[1][color_name] == color_val] + trace = dict( + type=trace_type, + name=color_val, + marker=dict( + color=colormap[color_val], + ), + **kwargs_trace + ) + if x: + trace['x'] = data_by_color[x] + if y: + trace['y'] = data_by_color[y] + trace = _make_trace_for_scatter( + trace, trace_type, colormap[color_val], **kwargs_marker + ) + + fig.append_trace(trace, + j + 1 if facet_row else 1, + 1 if facet_row else j + 1) + + label = _return_label( + group[0], + facet_row_labels if facet_row else facet_col_labels, + facet_row if facet_row else facet_col + ) + + annotations.append( + _annotation_dict( + label, + num_of_rows - j if facet_row else j + 1, + num_of_rows if facet_row else num_of_cols, + SUBPLOT_SPACING, + 'row' if facet_row else 'col', + flipped_rows) + ) + + elif facet_row and facet_col: + groups_by_facets = list(df.groupby([facet_row, facet_col])) + tuple_to_facet_group = {item[0]: item[1] for + item in groups_by_facets} + + row_values = df[facet_row].unique() + col_values = df[facet_col].unique() + color_vals = df[color_name].unique() + for row_count, x_val in enumerate(row_values): + for col_count, y_val in enumerate(col_values): + try: + group = tuple_to_facet_group[(x_val, y_val)] + except KeyError: + group = pd.DataFrame([[None, None, None]], + columns=[x, y, color_name]) + + for color_val in color_vals: + if group.values.tolist() != [[None, None, None]]: + group_filtered = group[group[color_name] == color_val] + + trace = dict( + type=trace_type, + name=color_val, + marker=dict( + color=colormap[color_val], + ), + **kwargs_trace + ) + new_x = group_filtered[x] + new_y = group_filtered[y] + else: + trace = dict( + type=trace_type, + name=color_val, + marker=dict( + color=colormap[color_val], + ), + showlegend=False, + **kwargs_trace + ) + new_x = group[x] + new_y = group[y] + + if x: + trace['x'] = new_x + if y: + trace['y'] = new_y + trace = _make_trace_for_scatter( + trace, trace_type, colormap[color_val], + **kwargs_marker + ) + + fig.append_trace(trace, row_count + 1, col_count + 1) + if row_count == 0: + label = _return_label(col_values[col_count], + facet_col_labels, facet_col) + annotations.append( + _annotation_dict(label, col_count + 1, num_of_cols, + SUBPLOT_SPACING, + row_col='col', flipped=flipped_cols) + ) + label = _return_label(row_values[row_count], + facet_row_labels, facet_row) + annotations.append( + _annotation_dict(label, num_of_rows - row_count, num_of_rows, + SUBPLOT_SPACING, + row_col='row', flipped=flipped_rows) + ) + + return fig, annotations + + +def _facet_grid_color_numerical(df, x, y, facet_row, facet_col, color_name, + colormap, num_of_rows, + num_of_cols, facet_row_labels, + facet_col_labels, trace_type, + flipped_rows, flipped_cols, show_boxes, + SUBPLOT_SPACING, marker_color, kwargs_trace, + kwargs_marker): + + fig = make_subplots(rows=num_of_rows, cols=num_of_cols, + shared_xaxes=True, shared_yaxes=True, + horizontal_spacing=SUBPLOT_SPACING, + vertical_spacing=SUBPLOT_SPACING, print_grid=False) + + annotations = [] + if not facet_row and not facet_col: + trace = dict( + type=trace_type, + marker=dict( + color=df[color_name], + colorscale=colormap, + showscale=True, + ), + **kwargs_trace + ) + if x: + trace['x'] = df[x] + if y: + trace['y'] = df[y] + trace = _make_trace_for_scatter( + trace, trace_type, df[color_name], **kwargs_marker + ) + + fig.append_trace(trace, 1, 1) + + if (facet_row and not facet_col) or (not facet_row and facet_col): + groups_by_facet = list( + df.groupby(facet_row if facet_row else facet_col) + ) + for j, group in enumerate(groups_by_facet): + trace = dict( + type=trace_type, + marker=dict( + color=df[color_name], + colorscale=colormap, + showscale=True, + colorbar=dict(x=1.15), + ), + **kwargs_trace + ) + if x: + trace['x'] = group[1][x] + if y: + trace['y'] = group[1][y] + trace = _make_trace_for_scatter( + trace, trace_type, df[color_name], **kwargs_marker + ) + + fig.append_trace( + trace, + j + 1 if facet_row else 1, + 1 if facet_row else j + 1 + ) + + labels = facet_row_labels if facet_row else facet_col_labels + label = _return_label( + group[0], labels, facet_row if facet_row else facet_col + ) + + annotations.append( + _annotation_dict( + label, + num_of_rows - j if facet_row else j + 1, + num_of_rows if facet_row else num_of_cols, + SUBPLOT_SPACING, + 'row' if facet_row else 'col', + flipped=flipped_rows) + ) + + elif facet_row and facet_col: + groups_by_facets = list(df.groupby([facet_row, facet_col])) + tuple_to_facet_group = {item[0]: item[1] for + item in groups_by_facets} + + row_values = df[facet_row].unique() + col_values = df[facet_col].unique() + for row_count, x_val in enumerate(row_values): + for col_count, y_val in enumerate(col_values): + try: + group = tuple_to_facet_group[(x_val, y_val)] + except KeyError: + group = pd.DataFrame([[None, None, None]], + columns=[x, y, color_name]) + + if group.values.tolist() != [[None, None, None]]: + trace = dict( + type=trace_type, + marker=dict( + color=df[color_name], + colorscale=colormap, + showscale=(row_count == 0), + colorbar=dict(x=1.15), + ), + **kwargs_trace + ) + + else: + trace = dict( + type=trace_type, + showlegend=False, + **kwargs_trace + ) + + if x: + trace['x'] = group[x] + if y: + trace['y'] = group[y] + trace = _make_trace_for_scatter( + trace, trace_type, df[color_name], **kwargs_marker + ) + + fig.append_trace(trace, row_count + 1, col_count + 1) + if row_count == 0: + label = _return_label(col_values[col_count], + facet_col_labels, facet_col) + annotations.append( + _annotation_dict(label, col_count + 1, num_of_cols, + SUBPLOT_SPACING, + row_col='col', flipped=flipped_cols) + ) + label = _return_label(row_values[row_count], + facet_row_labels, facet_row) + annotations.append( + _annotation_dict(row_values[row_count], + num_of_rows - row_count, num_of_rows, SUBPLOT_SPACING, + row_col='row', flipped=flipped_rows) + ) + + return fig, annotations + + +def _facet_grid(df, x, y, facet_row, facet_col, num_of_rows, + num_of_cols, facet_row_labels, facet_col_labels, + trace_type, flipped_rows, flipped_cols, show_boxes, + SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker): + + fig = make_subplots(rows=num_of_rows, cols=num_of_cols, + shared_xaxes=True, shared_yaxes=True, + horizontal_spacing=SUBPLOT_SPACING, + vertical_spacing=SUBPLOT_SPACING, print_grid=False) + annotations = [] + if not facet_row and not facet_col: + trace = dict( + type=trace_type, + marker=dict( + color=marker_color, + line=kwargs_marker['line'], + ), + **kwargs_trace + ) + + if x: + trace['x'] = df[x] + if y: + trace['y'] = df[y] + trace = _make_trace_for_scatter( + trace, trace_type, marker_color, **kwargs_marker + ) + + fig.append_trace(trace, 1, 1) + + elif (facet_row and not facet_col) or (not facet_row and facet_col): + groups_by_facet = list( + df.groupby(facet_row if facet_row else facet_col) + ) + for j, group in enumerate(groups_by_facet): + trace = dict( + type=trace_type, + marker=dict( + color=marker_color, + line=kwargs_marker['line'], + ), + **kwargs_trace + ) + + if x: + trace['x'] = group[1][x] + if y: + trace['y'] = group[1][y] + trace = _make_trace_for_scatter( + trace, trace_type, marker_color, **kwargs_marker + ) + + fig.append_trace(trace, + j + 1 if facet_row else 1, + 1 if facet_row else j + 1) + + label = _return_label( + group[0], + facet_row_labels if facet_row else facet_col_labels, + facet_row if facet_row else facet_col + ) + + annotations.append( + _annotation_dict( + label, + num_of_rows - j if facet_row else j + 1, + num_of_rows if facet_row else num_of_cols, + SUBPLOT_SPACING, + 'row' if facet_row else 'col', + flipped_rows + ) + ) + + elif facet_row and facet_col: + groups_by_facets = list(df.groupby([facet_row, facet_col])) + tuple_to_facet_group = {item[0]: item[1] for + item in groups_by_facets} + + row_values = df[facet_row].unique() + col_values = df[facet_col].unique() + for row_count, x_val in enumerate(row_values): + for col_count, y_val in enumerate(col_values): + try: + group = tuple_to_facet_group[(x_val, y_val)] + except KeyError: + group = pd.DataFrame([[None, None]], columns=[x, y]) + trace = dict( + type=trace_type, + marker=dict( + color=marker_color, + line=kwargs_marker['line'], + ), + **kwargs_trace + ) + if x: + trace['x'] = group[x] + if y: + trace['y'] = group[y] + trace = _make_trace_for_scatter( + trace, trace_type, marker_color, **kwargs_marker + ) + + fig.append_trace(trace, row_count + 1, col_count + 1) + if row_count == 0: + label = _return_label(col_values[col_count], + facet_col_labels, + facet_col) + annotations.append( + _annotation_dict(label, col_count + 1, num_of_cols, SUBPLOT_SPACING, + row_col='col', flipped=flipped_cols) + ) + + label = _return_label(row_values[row_count], + facet_row_labels, + facet_row) + annotations.append( + _annotation_dict(label, num_of_rows - row_count, num_of_rows, SUBPLOT_SPACING, + row_col='row', flipped=flipped_rows) + ) + + return fig, annotations + + +def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None, + color_name=None, colormap=None, color_is_cat=False, + facet_row_labels=None, facet_col_labels=None, + height=None, width=None, trace_type='scatter', + scales='fixed', dtick_x=None, dtick_y=None, + show_boxes=True, ggplot2=False, binsize=1, **kwargs): + """ + Returns figure for facet grid. + + :param (pd.DataFrame) df: the dataframe of columns for the facet grid. + :param (str) x: the name of the dataframe column for the x axis data. + :param (str) y: the name of the dataframe column for the y axis data. + :param (str) facet_row: the name of the dataframe column that is used to + facet the grid into row panels. + :param (str) facet_col: the name of the dataframe column that is used to + facet the grid into column panels. + :param (str) color_name: the name of your dataframe column that will + function as the colormap variable. + :param (str|list|dict) colormap: the param that determines how the + color_name column colors the data. If the dataframe contains numeric + data, then a dictionary of colors will group the data categorically + while a Plotly Colorscale name or a custom colorscale will treat it + numerically. To learn more about colors and types of colormap, run + `help(plotly.colors)`. + :param (bool) color_is_cat: determines whether a numerical column for the + colormap will be treated as categorical (True) or sequential (False). + Default = False. + :param (str|dict) facet_row_labels: set to either 'name' or a dictionary + of all the unique values in the faceting row mapped to some text to + show up in the label annotations. If None, labeling works like usual. + :param (str|dict) facet_col_labels: set to either 'name' or a dictionary + of all the values in the faceting row mapped to some text to show up + in the label annotations. If None, labeling works like usual. + :param (int) height: the height of the facet grid figure. + :param (int) width: the width of the facet grid figure. + :param (str) trace_type: decides the type of plot to appear in the + facet grid. The options are 'scatter', 'scattergl', 'histogram', + 'bar', and 'box'. + Default = 'scatter'. + :param (str) scales: determines if axes have fixed ranges or not. Valid + settings are 'fixed' (all axes fixed), 'free_x' (x axis free only), + 'free_y' (y axis free only) or 'free' (both axes free). + :param (float) dtick_x: determines the distance between each tick on the + x-axis. Default is None which means dtick_x is set automatically. + :param (float) dtick_y: determines the distance between each tick on the + y-axis. Default is None which means dtick_y is set automatically. + :param (bool) show_boxes: draws grey boxes behind the facet titles. + :param (bool) ggplot2: draws the facet grid in the style of `ggplot2`. See + http://ggplot2.tidyverse.org/reference/facet_grid.html for reference. + Default = False + :param (int) binsize: groups all data into bins of a given length. + :param (dict) kwargs: a dictionary of scatterplot arguments. + + Examples 1: One Way Faceting + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt') + + fig = ff.create_facet_grid( + mpg, + x='displ', + y='cty', + facet_col='cyl', + ) + py.iplot(fig, filename='facet_grid_mpg_one_way_facet') + ``` + + Example 2: Two Way Faceting + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt') + + fig = ff.create_facet_grid( + mpg, + x='displ', + y='cty', + facet_row='drv', + facet_col='cyl', + ) + py.iplot(fig, filename='facet_grid_mpg_two_way_facet') + ``` + + Example 3: Categorical Coloring + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt') + + fig = ff.create_facet_grid( + mtcars, + x='mpg', + y='wt', + facet_col='cyl', + color_name='cyl', + color_is_cat=True, + ) + py.iplot(fig, filename='facet_grid_mpg_default_colors') + ``` + + Example 4: Sequential Coloring + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + tips = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/tips.csv') + + fig = ff.create_facet_grid( + tips, + x='total_bill', + y='tip', + facet_row='sex', + facet_col='smoker', + color_name='size', + colormap='Viridis', + ) + py.iplot(fig, filename='facet_grid_tips_sequential_colors') + ``` + + Example 5: Custom labels + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv') + + fig = ff.create_facet_grid( + mtcars, + x='wt', + y='mpg', + facet_col='cyl', + facet_col_labels={4: "$\\alpha$", 6: '$\\beta$', 8: '$\sqrt[y]{x}$'}, + ) + + py.iplot(fig, filename='facet_grid_mtcars_custom_labels') + ``` + + Example 6: Other Trace Type + ``` + import plotly.plotly as py + import plotly.figure_factory as ff + + import pandas as pd + + mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv') + + fig = ff.create_facet_grid( + mtcars, + x='wt', + facet_col='cyl', + trace_type='histogram', + ) + + py.iplot(fig, filename='facet_grid_mtcars_other_trace_type') + ``` + """ + if not pd: + raise exceptions.ImportError( + "'pandas' must be installed for this figure_factory." + ) + + if not isinstance(df, pd.DataFrame): + raise exceptions.PlotlyError( + "You must input a pandas DataFrame." + ) + + # make sure all columns are of homogenous datatype + utils.validate_dataframe(df) + + if trace_type in ['scatter', 'scattergl']: + if not x or not y: + raise exceptions.PlotlyError( + "You need to input 'x' and 'y' if you are you are using a " + "trace_type of 'scatter' or 'scattergl'." + ) + + for key in [x, y, facet_row, facet_col, color_name]: + if key is not None: + try: + df[key] + except KeyError: + raise exceptions.PlotlyError( + "x, y, facet_row, facet_col and color_name must be keys " + "in your dataframe." + ) + # autoscale histogram bars + if trace_type not in ['scatter', 'scattergl']: + scales = 'free' + + # validate scales + if scales not in ['fixed', 'free_x', 'free_y', 'free']: + raise exceptions.PlotlyError( + "'scales' must be set to 'fixed', 'free_x', 'free_y' and 'free'." + ) + + if trace_type not in VALID_TRACE_TYPES: + raise exceptions.PlotlyError( + "'trace_type' must be in {}".format(VALID_TRACE_TYPES) + ) + + if trace_type == 'histogram': + SUBPLOT_SPACING = 0.06 + else: + SUBPLOT_SPACING = 0.015 + + # seperate kwargs for marker and else + if 'marker' in kwargs: + kwargs_marker = kwargs['marker'] + else: + kwargs_marker = {} + marker_color = kwargs_marker.pop('color', None) + kwargs.pop('marker', None) + kwargs_trace = kwargs + + if 'size' not in kwargs_marker: + if ggplot2: + kwargs_marker['size'] = 5 + else: + kwargs_marker['size'] = 8 + + if 'opacity' not in kwargs_marker: + if not ggplot2: + kwargs_trace['opacity'] = 0.6 + + if 'line' not in kwargs_marker: + if not ggplot2: + kwargs_marker['line'] = {'color': 'darkgrey', 'width': 1} + else: + kwargs_marker['line'] = {} + + # default marker size + if not ggplot2: + if not marker_color: + marker_color = 'rgb(31, 119, 180)' + else: + marker_color = 'rgb(0, 0, 0)' + + num_of_rows = 1 + num_of_cols = 1 + flipped_rows = False + flipped_cols = False + if facet_row: + num_of_rows = len(df[facet_row].unique()) + flipped_rows = _is_flipped(num_of_rows) + if isinstance(facet_row_labels, dict): + for key in df[facet_row].unique(): + if key not in facet_row_labels.keys(): + unique_keys = df[facet_row].unique().tolist() + raise exceptions.PlotlyError( + CUSTOM_LABEL_ERROR.format(unique_keys) + ) + if facet_col: + num_of_cols = len(df[facet_col].unique()) + flipped_cols = _is_flipped(num_of_cols) + if isinstance(facet_col_labels, dict): + for key in df[facet_col].unique(): + if key not in facet_col_labels.keys(): + unique_keys = df[facet_col].unique().tolist() + raise exceptions.PlotlyError( + CUSTOM_LABEL_ERROR.format(unique_keys) + ) + show_legend = False + if color_name: + if isinstance(df[color_name].iloc[0], str) or color_is_cat: + show_legend = True + if isinstance(colormap, dict): + utils.validate_colors_dict(colormap, 'rgb') + + for val in df[color_name].unique(): + if val not in colormap.keys(): + raise exceptions.PlotlyError( + "If using 'colormap' as a dictionary, make sure " + "all the values of the colormap column are in " + "the keys of your dictionary." + ) + else: + # use default plotly colors for dictionary + default_colors = utils.DEFAULT_PLOTLY_COLORS + colormap = {} + j = 0 + for val in df[color_name].unique(): + if j >= len(default_colors): + j = 0 + colormap[val] = default_colors[j] + j += 1 + fig, annotations = _facet_grid_color_categorical( + df, x, y, facet_row, facet_col, color_name, colormap, + num_of_rows, num_of_cols, facet_row_labels, facet_col_labels, + trace_type, flipped_rows, flipped_cols, show_boxes, + SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker + ) + + elif isinstance(df[color_name].iloc[0], Number): + if isinstance(colormap, dict): + show_legend = True + utils.validate_colors_dict(colormap, 'rgb') + + for val in df[color_name].unique(): + if val not in colormap.keys(): + raise exceptions.PlotlyError( + "If using 'colormap' as a dictionary, make sure " + "all the values of the colormap column are in " + "the keys of your dictionary." + ) + fig, annotations = _facet_grid_color_categorical( + df, x, y, facet_row, facet_col, color_name, colormap, + num_of_rows, num_of_cols, facet_row_labels, + facet_col_labels, trace_type, flipped_rows, + flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color, + kwargs_trace, kwargs_marker + ) + + elif isinstance(colormap, list): + colorscale_list = colormap + utils.validate_colorscale(colorscale_list) + + fig, annotations = _facet_grid_color_numerical( + df, x, y, facet_row, facet_col, color_name, + colorscale_list, num_of_rows, num_of_cols, + facet_row_labels, facet_col_labels, trace_type, + flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING, + marker_color, kwargs_trace, kwargs_marker + ) + elif isinstance(colormap, str): + if colormap in utils.PLOTLY_SCALES.keys(): + colorscale_list = utils.PLOTLY_SCALES[colormap] + else: + raise exceptions.PlotlyError( + "If 'colormap' is a string, it must be the name " + "of a Plotly Colorscale. The available colorscale " + "names are {}".format(utils.PLOTLY_SCALES.keys()) + ) + fig, annotations = _facet_grid_color_numerical( + df, x, y, facet_row, facet_col, color_name, + colorscale_list, num_of_rows, num_of_cols, + facet_row_labels, facet_col_labels, trace_type, + flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING, + marker_color, kwargs_trace, kwargs_marker + ) + else: + colorscale_list = utils.PLOTLY_SCALES['Reds'] + fig, annotations = _facet_grid_color_numerical( + df, x, y, facet_row, facet_col, color_name, + colorscale_list, num_of_rows, num_of_cols, + facet_row_labels, facet_col_labels, trace_type, + flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING, + marker_color, kwargs_trace, kwargs_marker + ) + + else: + fig, annotations = _facet_grid( + df, x, y, facet_row, facet_col, num_of_rows, num_of_cols, + facet_row_labels, facet_col_labels, trace_type, flipped_rows, + flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color, + kwargs_trace, kwargs_marker + ) + + if not height: + height = max(600, 100 * num_of_rows) + if not width: + width = max(600, 100 * num_of_cols) + + fig['layout'].update(height=height, width=width, title='', + paper_bgcolor='rgb(251, 251, 251)') + if ggplot2: + fig['layout'].update(plot_bgcolor=PLOT_BGCOLOR, + paper_bgcolor='rgb(255, 255, 255)', + hovermode='closest') + + # axis titles + x_title_annot = _axis_title_annotation(x, 'x') + y_title_annot = _axis_title_annotation(y, 'y') + + # annotations + annotations.append(x_title_annot) + annotations.append(y_title_annot) + + # legend + fig['layout']['showlegend'] = show_legend + fig['layout']['legend']['bgcolor'] = LEGEND_COLOR + fig['layout']['legend']['borderwidth'] = LEGEND_BORDER_WIDTH + fig['layout']['legend']['x'] = 1.05 + fig['layout']['legend']['y'] = 1 + fig['layout']['legend']['yanchor'] = 'top' + + if show_legend: + fig['layout']['showlegend'] = show_legend + if ggplot2: + if color_name: + legend_annot = _legend_annotation(color_name) + annotations.append(legend_annot) + fig['layout']['margin']['r'] = 150 + + # assign annotations to figure + fig['layout']['annotations'] = annotations + + # add shaded boxes behind axis titles + if show_boxes and ggplot2: + _add_shapes_to_fig(fig, ANNOT_RECT_COLOR, flipped_rows, flipped_cols) + + # all xaxis and yaxis labels + axis_labels = {'x': [], 'y': []} + for key in fig['layout']: + if 'xaxis' in key: + axis_labels['x'].append(key) + elif 'yaxis' in key: + axis_labels['y'].append(key) + + string_number_in_data = False + for var in [v for v in [x, y] if v]: + if isinstance(df[var].tolist()[0], str): + for item in df[var]: + try: + int(item) + string_number_in_data = True + except ValueError: + pass + + if string_number_in_data: + for x_y in axis_labels.keys(): + for axis_name in axis_labels[x_y]: + fig['layout'][axis_name]['type'] = 'category' + + if scales == 'fixed': + fixed_axes = ['x', 'y'] + elif scales == 'free_x': + fixed_axes = ['y'] + elif scales == 'free_y': + fixed_axes = ['x'] + elif scales == 'free': + fixed_axes = [] + + # fixed ranges + for x_y in fixed_axes: + min_ranges = [] + max_ranges = [] + for trace in fig['data']: + if trace[x_y] is not None and len(trace[x_y]) > 0: + min_ranges.append(min(trace[x_y])) + max_ranges.append(max(trace[x_y])) + while None in min_ranges: + min_ranges.remove(None) + while None in max_ranges: + max_ranges.remove(None) + + min_range = min(min_ranges) + max_range = max(max_ranges) + + range_are_numbers = (isinstance(min_range, Number) and + isinstance(max_range, Number)) + + if range_are_numbers: + min_range = math.floor(min_range) + max_range = math.ceil(max_range) + + # extend widen frame by 5% on each side + min_range -= 0.05 * (max_range - min_range) + max_range += 0.05 * (max_range - min_range) + + if x_y == 'x': + if dtick_x: + dtick = dtick_x + else: + dtick = math.floor( + (max_range - min_range) / MAX_TICKS_PER_AXIS + ) + elif x_y == 'y': + if dtick_y: + dtick = dtick_y + else: + dtick = math.floor( + (max_range - min_range) / MAX_TICKS_PER_AXIS + ) + else: + dtick = 1 + + for axis_title in axis_labels[x_y]: + fig['layout'][axis_title]['dtick'] = dtick + fig['layout'][axis_title]['ticklen'] = 0 + fig['layout'][axis_title]['zeroline'] = False + if ggplot2: + fig['layout'][axis_title]['tickwidth'] = 1 + fig['layout'][axis_title]['ticklen'] = 4 + fig['layout'][axis_title]['gridwidth'] = GRID_WIDTH + + fig['layout'][axis_title]['gridcolor'] = GRID_COLOR + fig['layout'][axis_title]['gridwidth'] = 2 + fig['layout'][axis_title]['tickfont'] = { + 'color': TICK_COLOR, 'size': 10 + } + + # insert ranges into fig + if x_y in fixed_axes: + for key in fig['layout']: + if '{}axis'.format(x_y) in key and range_are_numbers: + fig['layout'][key]['range'] = [min_range, max_range] + + return fig diff --git a/plotly/figure_factory/figure_factory/_gantt.py b/plotly/figure_factory/figure_factory/_gantt.py new file mode 100644 index 00000000000..36024e6c1b5 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_gantt.py @@ -0,0 +1,780 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +pd = optional_imports.get_module('pandas') + +REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish'] + + +def validate_gantt(df): + """ + Validates the inputted dataframe or list + """ + if pd and isinstance(df, pd.core.frame.DataFrame): + # validate that df has all the required keys + for key in REQUIRED_GANTT_KEYS: + if key not in df: + raise exceptions.PlotlyError( + "The columns in your dataframe must include the " + "following keys: {0}".format( + ', '.join(REQUIRED_GANTT_KEYS)) + ) + + num_of_rows = len(df.index) + chart = [] + for index in range(num_of_rows): + task_dict = {} + for key in df: + task_dict[key] = df.ix[index][key] + chart.append(task_dict) + + return chart + + # validate if df is a list + if not isinstance(df, list): + raise exceptions.PlotlyError("You must input either a dataframe " + "or a list of dictionaries.") + + # validate if df is empty + if len(df) <= 0: + raise exceptions.PlotlyError("Your list is empty. It must contain " + "at least one dictionary.") + if not isinstance(df[0], dict): + raise exceptions.PlotlyError("Your list must only " + "include dictionaries.") + return df + + +def gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, height, + width, tasks=None, task_names=None, data=None, group_tasks=False): + """ + Refer to create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + color_index = 0 + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + # check if colors need to be looped + if color_index >= len(colors): + color_index = 0 + tasks[index]['fillcolor'] = colors[color_index] + # Add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + color_index += 1 + + layout = dict( + title=title, + showlegend=False, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = graph_objs.Figure(data=data, layout=layout) + return fig + + +def gantt_colorscale(chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, tasks=None, + task_names=None, data=None, group_tasks=False): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + + # compute the color for task based on indexing column + if isinstance(chart[0][index_col], Number): + # check that colors has at least 2 colors + if len(colors) < 2: + raise exceptions.PlotlyError( + "You must use at least 2 colors in 'colors' if you " + "are using a colorscale. However only the first two " + "colors given will be used for the lower and upper " + "bounds on the colormap." + ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + # unlabel color + colors = utils.color_parser(colors, utils.unlabel_rgb) + lowcolor = colors[0] + highcolor = colors[1] + + intermed = (chart[index][index_col]) / 100.0 + intermed_color = utils.find_intermediate_color( + lowcolor, highcolor, intermed + ) + intermed_color = utils.color_parser( + intermed_color, utils.label_rgb + ) + tasks[index]['fillcolor'] = intermed_color + # relabel colors with 'rgb' + colors = utils.color_parser(colors, utils.label_rgb) + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data for colorscale visibility + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[index, index], + name='', + marker={'color': 'white', + 'colorscale': [[0, colors[0]], [1, colors[1]]], + 'showscale': True, + 'cmax': 100, + 'cmin': 0} + ) + ) + + if isinstance(chart[0][index_col], str): + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + if len(colors) < len(index_vals): + raise exceptions.PlotlyError( + "Error. The number of colors in 'colors' must be no less " + "than the number of unique index values in your group " + "column." + ) + + # make a dictionary assignment to each index value + index_vals_dict = {} + # define color index + c_index = 0 + for key in index_vals: + if c_index > len(colors) - 1: + c_index = 0 + index_vals_dict[key] = colors[c_index] + c_index += 1 + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted + # first are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + tasks[index]['fillcolor'] = index_vals_dict[ + chart[index][index_col] + ] + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data to generate legend + showlegend = True + for k, index_value in enumerate(index_vals): + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[k, k], + showlegend=True, + name=str(index_value), + hoverinfo='none', + marker=dict( + color=colors[k], + size=1 + ) + ) + ) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = dict(data=data, layout=layout) + return fig + + +def gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, tasks=None, + task_names=None, data=None, group_tasks=False): + """ + Refer to FigureFactory.create_gantt() for docstring + """ + if tasks is None: + tasks = [] + if task_names is None: + task_names = [] + if data is None: + data = [] + showlegend = False + + for index in range(len(chart)): + task = dict(x0=chart[index]['Start'], + x1=chart[index]['Finish'], + name=chart[index]['Task']) + if 'Description' in chart[index]: + task['description'] = chart[index]['Description'] + tasks.append(task) + + shape_template = { + 'type': 'rect', + 'xref': 'x', + 'yref': 'y', + 'opacity': 1, + 'line': { + 'width': 0, + } + } + + index_vals = [] + for row in range(len(tasks)): + if chart[row][index_col] not in index_vals: + index_vals.append(chart[row][index_col]) + + index_vals.sort() + + # verify each value in index column appears in colors dictionary + for key in index_vals: + if key not in colors: + raise exceptions.PlotlyError( + "If you are using colors as a dictionary, all of its " + "keys must be all the values in the index column." + ) + + # create the list of task names + for index in range(len(tasks)): + tn = tasks[index]['name'] + # Is added to task_names if group_tasks is set to False, + # or if the option is used (True) it only adds them if the + # name is not already in the list + if not group_tasks or tn not in task_names: + task_names.append(tn) + # Guarantees that for grouped tasks the tasks that are inserted first + # are shown at the top + if group_tasks: + task_names.reverse() + + for index in range(len(tasks)): + tn = tasks[index]['name'] + del tasks[index]['name'] + tasks[index].update(shape_template) + + # If group_tasks is True, all tasks with the same name belong + # to the same row. + groupID = index + if group_tasks: + groupID = task_names.index(tn) + tasks[index]['y0'] = groupID - bar_width + tasks[index]['y1'] = groupID + bar_width + + tasks[index]['fillcolor'] = colors[chart[index][index_col]] + + # add a line for hover text and autorange + entry = dict( + x=[tasks[index]['x0'], tasks[index]['x1']], + y=[groupID, groupID], + showlegend=False, + name='', + marker={'color': 'white'} + ) + if "description" in tasks[index]: + entry['text'] = tasks[index]['description'] + del tasks[index]['description'] + data.append(entry) + + if show_colorbar is True: + # generate dummy data to generate legend + showlegend = True + for k, index_value in enumerate(index_vals): + data.append( + dict( + x=[tasks[index]['x0'], tasks[index]['x0']], + y=[k, k], + showlegend=True, + hoverinfo='none', + name=str(index_value), + marker=dict( + color=colors[index_value], + size=1 + ) + ) + ) + + layout = dict( + title=title, + showlegend=showlegend, + height=height, + width=width, + shapes=[], + hovermode='closest', + yaxis=dict( + showgrid=showgrid_y, + ticktext=task_names, + tickvals=list(range(len(task_names))), + range=[-1, len(task_names) + 1], + autorange=False, + zeroline=False, + ), + xaxis=dict( + showgrid=showgrid_x, + zeroline=False, + rangeselector=dict( + buttons=list([ + dict(count=7, + label='1w', + step='day', + stepmode='backward'), + dict(count=1, + label='1m', + step='month', + stepmode='backward'), + dict(count=6, + label='6m', + step='month', + stepmode='backward'), + dict(count=1, + label='YTD', + step='year', + stepmode='todate'), + dict(count=1, + label='1y', + step='year', + stepmode='backward'), + dict(step='all') + ]) + ), + type='date' + ) + ) + layout['shapes'] = tasks + + fig = dict(data=data, layout=layout) + return fig + + +def create_gantt(df, colors=None, index_col=None, show_colorbar=False, + reverse_colors=False, title='Gantt Chart', bar_width=0.2, + showgrid_x=False, showgrid_y=False, height=600, width=900, + tasks=None, task_names=None, data=None, group_tasks=False): + """ + Returns figure for a gantt chart + + :param (array|list) df: input data for gantt chart. Must be either a + a dataframe or a list. If dataframe, the columns must include + 'Task', 'Start' and 'Finish'. Other columns can be included and + used for indexing. If a list, its elements must be dictionaries + with the same required column headers: 'Task', 'Start' and + 'Finish'. + :param (str|list|dict|tuple) colors: either a plotly scale name, an + rgb or hex color, a color tuple or a list of colors. An rgb color + is of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colors is a list, it must + contain the valid color types aforementioned as its members. + If a dictionary, all values of the indexing column must be keys in + colors. + :param (str|float) index_col: the column header (if df is a data + frame) that will function as the indexing column. If df is a list, + index_col must be one of the keys in all the items of df. + :param (bool) show_colorbar: determines if colorbar will be visible. + Only applies if values in the index column are numeric. + :param (bool) reverse_colors: reverses the order of selected colors + :param (str) title: the title of the chart + :param (float) bar_width: the width of the horizontal bars in the plot + :param (bool) showgrid_x: show/hide the x-axis grid + :param (bool) showgrid_y: show/hide the y-axis grid + :param (float) height: the height of the chart + :param (float) width: the width of the chart + + Example 1: Simple Gantt Chart + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'), + dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'), + dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')] + + # Create a figure + fig = create_gantt(df) + + # Plot the data + py.iplot(fig, filename='Simple Gantt Chart', world_readable=True) + ``` + + Example 2: Index by Column with Numerical Entries + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Complete=10), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Complete=60), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Complete=95)] + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors='Blues', index_col='Complete', + show_colorbar=True, bar_width=0.5, + showgrid_x=True, showgrid_y=True) + + # Plot the data + py.iplot(fig, filename='Numerical Entries', world_readable=True) + ``` + + Example 3: Index by Column with String Entries + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Resource='Apple'), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Resource='Grape'), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Resource='Banana')] + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'], + index_col='Resource', reverse_colors=True, + show_colorbar=True) + + # Plot the data + py.iplot(fig, filename='String Entries', world_readable=True) + ``` + + Example 4: Use a dictionary for colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + # Make data for chart + df = [dict(Task="Job A", Start='2009-01-01', + Finish='2009-02-30', Resource='Apple'), + dict(Task="Job B", Start='2009-03-05', + Finish='2009-04-15', Resource='Grape'), + dict(Task="Job C", Start='2009-02-20', + Finish='2009-05-30', Resource='Banana')] + + # Make a dictionary of colors + colors = {'Apple': 'rgb(255, 0, 0)', + 'Grape': 'rgb(170, 14, 200)', + 'Banana': (1, 1, 0.2)} + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors=colors, index_col='Resource', + show_colorbar=True) + + # Plot the data + py.iplot(fig, filename='dictioanry colors', world_readable=True) + ``` + + Example 5: Use a pandas dataframe + ``` + import plotly.plotly as py + from plotly.figure_factory import create_gantt + + import pandas as pd + + # Make data as a dataframe + df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10], + ['Fast', '2011-01-01', '2012-06-05', 55], + ['Eat', '2012-01-05', '2013-07-05', 94]], + columns=['Task', 'Start', 'Finish', 'Complete']) + + # Create a figure with Plotly colorscale + fig = create_gantt(df, colors='Blues', index_col='Complete', + show_colorbar=True, bar_width=0.5, + showgrid_x=True, showgrid_y=True) + + # Plot the data + py.iplot(fig, filename='data with dataframe', world_readable=True) + ``` + """ + # validate gantt input data + chart = validate_gantt(df) + + if index_col: + if index_col not in chart[0]: + raise exceptions.PlotlyError( + "In order to use an indexing column and assign colors to " + "the values of the index, you must choose an actual " + "column name in the dataframe or key if a list of " + "dictionaries is being used.") + + # validate gantt index column + index_list = [] + for dictionary in chart: + index_list.append(dictionary[index_col]) + utils.validate_index(index_list) + + # Validate colors + if isinstance(colors, dict): + colors = utils.validate_colors_dict(colors, 'rgb') + else: + colors = utils.validate_colors(colors, 'rgb') + + if reverse_colors is True: + colors.reverse() + + if not index_col: + if isinstance(colors, dict): + raise exceptions.PlotlyError( + "Error. You have set colors to a dictionary but have not " + "picked an index. An index is required if you are " + "assigning colors to particular values in a dictioanry." + ) + fig = gantt( + chart, colors, title, bar_width, showgrid_x, showgrid_y, + height, width, tasks=None, task_names=None, data=None, + group_tasks=group_tasks + ) + return fig + else: + if not isinstance(colors, dict): + fig = gantt_colorscale( + chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, + tasks=None, task_names=None, data=None, group_tasks=group_tasks + ) + return fig + else: + fig = gantt_dict( + chart, colors, title, index_col, show_colorbar, bar_width, + showgrid_x, showgrid_y, height, width, + tasks=None, task_names=None, data=None, group_tasks=group_tasks + ) + return fig diff --git a/plotly/figure_factory/figure_factory/_ohlc.py b/plotly/figure_factory/figure_factory/_ohlc.py new file mode 100644 index 00000000000..b5e84cd6d93 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_ohlc.py @@ -0,0 +1,380 @@ +from __future__ import absolute_import + +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +# Default colours for finance charts +_DEFAULT_INCREASING_COLOR = '#3D9970' # http://clrs.cc +_DEFAULT_DECREASING_COLOR = '#FF4136' + + +def validate_ohlc(open, high, low, close, direction, **kwargs): + """ + ohlc and candlestick specific validations + + Specifically, this checks that the high value is the greatest value and + the low value is the lowest value in each unit. + + See FigureFactory.create_ohlc() or FigureFactory.create_candlestick() + for params + + :raises: (PlotlyError) If the high value is not the greatest value in + each unit. + :raises: (PlotlyError) If the low value is not the lowest value in each + unit. + :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing' + """ + for lst in [open, low, close]: + for index in range(len(high)): + if high[index] < lst[index]: + raise exceptions.PlotlyError("Oops! Looks like some of " + "your high values are less " + "the corresponding open, " + "low, or close values. " + "Double check that your data " + "is entered in O-H-L-C order") + + for lst in [open, high, close]: + for index in range(len(low)): + if low[index] > lst[index]: + raise exceptions.PlotlyError("Oops! Looks like some of " + "your low values are greater " + "than the corresponding high" + ", open, or close values. " + "Double check that your data " + "is entered in O-H-L-C order") + + direction_opts = ('increasing', 'decreasing', 'both') + if direction not in direction_opts: + raise exceptions.PlotlyError("direction must be defined as " + "'increasing', 'decreasing', or " + "'both'") + + +def make_increasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes increasing ohlc sticks + + _make_increasing_ohlc() and _make_decreasing_ohlc separate the + increasing trace from the decreasing trace so kwargs (such as + color) can be passed separately to increasing or decreasing traces + when direction is set to 'increasing' or 'decreasing' in + FigureFactory.create_candlestick() + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc + sticks. + """ + (flat_increase_x, + flat_increase_y, + text_increase) = _OHLC(open, high, low, close, dates).get_increase() + + if 'name' in kwargs: + showlegend = True + else: + kwargs.setdefault('name', 'Increasing') + showlegend = False + + kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR, + width=1)) + kwargs.setdefault('text', text_increase) + + ohlc_incr = dict(type='scatter', + x=flat_increase_x, + y=flat_increase_y, + mode='lines', + showlegend=showlegend, + **kwargs) + return ohlc_incr + + +def make_decreasing_ohlc(open, high, low, close, dates, **kwargs): + """ + Makes decreasing ohlc sticks + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing values + :param (list) dates: list of datetime objects. Default: None + :param kwargs: kwargs to be passed to increasing trace via + plotly.graph_objs.Scatter. + + :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc + sticks. + """ + (flat_decrease_x, + flat_decrease_y, + text_decrease) = _OHLC(open, high, low, close, dates).get_decrease() + + kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR, + width=1)) + kwargs.setdefault('text', text_decrease) + kwargs.setdefault('showlegend', False) + kwargs.setdefault('name', 'Decreasing') + + ohlc_decr = dict(type='scatter', + x=flat_decrease_x, + y=flat_decrease_y, + mode='lines', + **kwargs) + return ohlc_decr + + +def create_ohlc(open, high, low, close, dates=None, direction='both', + **kwargs): + """ + BETA function that creates an ohlc chart + + :param (list) open: opening values + :param (list) high: high values + :param (list) low: low values + :param (list) close: closing + :param (list) dates: list of datetime objects. Default: None + :param (string) direction: direction can be 'increasing', 'decreasing', + or 'both'. When the direction is 'increasing', the returned figure + consists of all units where the close value is greater than the + corresponding open value, and when the direction is 'decreasing', + the returned figure consists of all units where the close value is + less than or equal to the corresponding open value. When the + direction is 'both', both increasing and decreasing units are + returned. Default: 'both' + :param kwargs: kwargs passed through plotly.graph_objs.Scatter. + These kwargs describe other attributes about the ohlc Scatter trace + such as the color or the legend name. For more information on valid + kwargs call help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of an ohlc chart figure. + + Example 1: Simple OHLC chart from a Pandas DataFrame + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), + datetime(2008, 10, 15)) + fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) + + py.plot(fig, filename='finance/aapl-ohlc') + ``` + + Example 2: Add text and annotations to the OHLC chart + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15), + datetime(2008, 10, 15)) + fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index) + + # Update the fig - options here: https://plot.ly/python/reference/#Layout + fig['layout'].update({ + 'title': 'The Great Recession', + 'yaxis': {'title': 'AAPL Stock'}, + 'shapes': [{ + 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line', + 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper', + 'line': {'color': 'rgb(40,40,40)', 'width': 0.5} + }], + 'annotations': [{ + 'text': "the fall of Lehman Brothers", + 'x': '2008-09-15', 'y': 1.02, + 'xref': 'x', 'yref': 'paper', + 'showarrow': False, 'xanchor': 'left' + }] + }) + + py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False) + ``` + + Example 3: Customize the OHLC colors + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + from plotly.graph_objs import Line, Marker + from datetime import datetime + + import pandas.io.data as web + + df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), + datetime(2009, 4, 1)) + + # Make increasing ohlc sticks and customize their color and name + fig_increasing = create_ohlc(df.Open, df.High, df.Low, df.Close, + dates=df.index, direction='increasing', + name='AAPL', + line=Line(color='rgb(150, 200, 250)')) + + # Make decreasing ohlc sticks and customize their color and name + fig_decreasing = create_ohlc(df.Open, df.High, df.Low, df.Close, + dates=df.index, direction='decreasing', + line=Line(color='rgb(128, 128, 128)')) + + # Initialize the figure + fig = fig_increasing + + # Add decreasing data with .extend() + fig['data'].extend(fig_decreasing['data']) + + py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False) + ``` + + Example 4: OHLC chart with datetime objects + ``` + import plotly.plotly as py + from plotly.figure_factory import create_ohlc + + from datetime import datetime + + # Add data + open_data = [33.0, 33.3, 33.5, 33.0, 34.1] + high_data = [33.1, 33.3, 33.6, 33.2, 34.8] + low_data = [32.7, 32.7, 32.8, 32.6, 32.8] + close_data = [33.0, 32.9, 33.3, 33.1, 33.1] + dates = [datetime(year=2013, month=10, day=10), + datetime(year=2013, month=11, day=10), + datetime(year=2013, month=12, day=10), + datetime(year=2014, month=1, day=10), + datetime(year=2014, month=2, day=10)] + + # Create ohlc + fig = create_ohlc(open_data, high_data, low_data, close_data, dates=dates) + + py.iplot(fig, filename='finance/simple-ohlc', validate=False) + ``` + """ + if dates is not None: + utils.validate_equal_length(open, high, low, close, dates) + else: + utils.validate_equal_length(open, high, low, close) + validate_ohlc(open, high, low, close, direction, **kwargs) + + if direction is 'increasing': + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_incr] + elif direction is 'decreasing': + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_decr] + else: + ohlc_incr = make_increasing_ohlc(open, high, low, close, dates, + **kwargs) + ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates, + **kwargs) + data = [ohlc_incr, ohlc_decr] + + layout = graph_objs.Layout(xaxis=dict(zeroline=False), + hovermode='closest') + + return graph_objs.Figure(data=data, layout=layout) + + +class _OHLC(object): + """ + Refer to FigureFactory.create_ohlc_increase() for docstring. + """ + def __init__(self, open, high, low, close, dates, **kwargs): + self.open = open + self.high = high + self.low = low + self.close = close + self.empty = [None] * len(open) + self.dates = dates + + self.all_x = [] + self.all_y = [] + self.increase_x = [] + self.increase_y = [] + self.decrease_x = [] + self.decrease_y = [] + self.get_all_xy() + self.separate_increase_decrease() + + def get_all_xy(self): + """ + Zip data to create OHLC shape + + OHLC shape: low to high vertical bar with + horizontal branches for open and close values. + If dates were added, the smallest date difference is calculated and + multiplied by .2 to get the length of the open and close branches. + If no date data was provided, the x-axis is a list of integers and the + length of the open and close branches is .2. + """ + self.all_y = list(zip(self.open, self.open, self.high, + self.low, self.close, self.close, self.empty)) + if self.dates is not None: + date_dif = [] + for i in range(len(self.dates) - 1): + date_dif.append(self.dates[i + 1] - self.dates[i]) + date_dif_min = (min(date_dif)) / 5 + self.all_x = [[x - date_dif_min, x, x, x, x, x + + date_dif_min, None] for x in self.dates] + else: + self.all_x = [[x - .2, x, x, x, x, x + .2, None] + for x in range(len(self.open))] + + def separate_increase_decrease(self): + """ + Separate data into two groups: increase and decrease + + (1) Increase, where close > open and + (2) Decrease, where close <= open + """ + for index in range(len(self.open)): + if self.close[index] is None: + pass + elif self.close[index] > self.open[index]: + self.increase_x.append(self.all_x[index]) + self.increase_y.append(self.all_y[index]) + else: + self.decrease_x.append(self.all_x[index]) + self.decrease_y.append(self.all_y[index]) + + def get_increase(self): + """ + Flatten increase data and get increase text + + :rtype (list, list, list): flat_increase_x: x-values for the increasing + trace, flat_increase_y: y=values for the increasing trace and + text_increase: hovertext for the increasing trace + """ + flat_increase_x = utils.flatten(self.increase_x) + flat_increase_y = utils.flatten(self.increase_y) + text_increase = (("Open", "Open", "High", + "Low", "Close", "Close", '') + * (len(self.increase_x))) + + return flat_increase_x, flat_increase_y, text_increase + + def get_decrease(self): + """ + Flatten decrease data and get decrease text + + :rtype (list, list, list): flat_decrease_x: x-values for the decreasing + trace, flat_decrease_y: y=values for the decreasing trace and + text_decrease: hovertext for the decreasing trace + """ + flat_decrease_x = utils.flatten(self.decrease_x) + flat_decrease_y = utils.flatten(self.decrease_y) + text_decrease = (("Open", "Open", "High", + "Low", "Close", "Close", '') + * (len(self.decrease_x))) + + return flat_decrease_x, flat_decrease_y, text_decrease diff --git a/plotly/figure_factory/figure_factory/_quiver.py b/plotly/figure_factory/figure_factory/_quiver.py new file mode 100644 index 00000000000..ef9d00e80dd --- /dev/null +++ b/plotly/figure_factory/figure_factory/_quiver.py @@ -0,0 +1,283 @@ +from __future__ import absolute_import + +import math + +from plotly import exceptions +from plotly.graph_objs import graph_objs +from plotly.figure_factory import utils + + +def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3, + angle=math.pi / 9, scaleratio=None, **kwargs): + """ + Returns data for a quiver plot. + + :param (list|ndarray) x: x coordinates of the arrow locations + :param (list|ndarray) y: y coordinates of the arrow locations + :param (list|ndarray) u: x components of the arrow vectors + :param (list|ndarray) v: y components of the arrow vectors + :param (float in [0,1]) scale: scales size of the arrows(ideally to + avoid overlap). Default = .1 + :param (float in [0,1]) arrow_scale: value multiplied to length of barb + to get length of arrowhead. Default = .3 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (positive float) scaleratio: the ratio between the scale of the y-axis + and the scale of the x-axis (scale_y / scale_x). Default = None, the + scale ratio is not fixed. + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of quiver figure. + + Example 1: Trivial Quiver + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + + import math + + # 1 Arrow from (0,0) to (1,1) + fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1) + + py.plot(fig, filename='quiver') + ``` + + Example 2: Quiver plot using meshgrid + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + + import numpy as np + import math + + # Add data + x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2)) + u = np.cos(x)*y + v = np.sin(x)*y + + #Create quiver + fig = create_quiver(x, y, u, v) + + # Plot + py.plot(fig, filename='quiver') + ``` + + Example 3: Styling the quiver plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + import numpy as np + import math + + # Add data + x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5), + np.arange(-math.pi, math.pi, .5)) + u = np.cos(x)*y + v = np.sin(x)*y + + # Create quiver + fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6, + name='Wind Velocity', line=dict(width=1)) + + # Add title to layout + fig['layout'].update(title='Quiver Plot') + + # Plot + py.plot(fig, filename='quiver') + ``` + + Example 4: Forcing a fix scale ratio to maintain the arrow length + ``` + import plotly.plotly as py + from plotly.figure_factory import create_quiver + + import numpy as np + + # Add data + x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5)) + u = x + v = y + angle = np.arctan(v / u) + norm = 0.25 + u = norm * np.cos(angle) + v = norm * np.sin(angle) + + # Create quiver with a fix scale ratio + fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5) + + # Plot + py.plot(fig, filename='quiver') + ``` + """ + utils.validate_equal_length(x, y, u, v) + utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale) + + if scaleratio is None: + quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle) + else: + quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio) + + barb_x, barb_y = quiver_obj.get_barbs() + arrow_x, arrow_y = quiver_obj.get_quiver_arrows() + + quiver_plot = graph_objs.Scatter(x=barb_x + arrow_x, + y=barb_y + arrow_y, + mode='lines', **kwargs) + + data = [quiver_plot] + + if scaleratio is None: + layout = graph_objs.Layout(hovermode='closest') + else: + layout = graph_objs.Layout( + hovermode='closest', + yaxis=dict( + scaleratio = scaleratio, + scaleanchor = "x" + ) + ) + + return graph_objs.Figure(data=data, layout=layout) + +class _Quiver(object): + """ + Refer to FigureFactory.create_quiver() for docstring + """ + def __init__(self, x, y, u, v, + scale, arrow_scale, angle, scaleratio=1, **kwargs): + try: + x = utils.flatten(x) + except exceptions.PlotlyError: + pass + + try: + y = utils.flatten(y) + except exceptions.PlotlyError: + pass + + try: + u = utils.flatten(u) + except exceptions.PlotlyError: + pass + + try: + v = utils.flatten(v) + except exceptions.PlotlyError: + pass + + self.x = x + self.y = y + self.u = u + self.v = v + self.scale = scale + self.scaleratio = scaleratio + self.arrow_scale = arrow_scale + self.angle = angle + self.end_x = [] + self.end_y = [] + self.scale_uv() + barb_x, barb_y = self.get_barbs() + arrow_x, arrow_y = self.get_quiver_arrows() + + def scale_uv(self): + """ + Scales u and v to avoid overlap of the arrows. + + u and v are added to x and y to get the + endpoints of the arrows so a smaller scale value will + result in less overlap of arrows. + """ + self.u = [i * self.scale * self.scaleratio for i in self.u] + self.v = [i * self.scale for i in self.v] + + def get_barbs(self): + """ + Creates x and y startpoint and endpoint pairs + + After finding the endpoint of each barb this zips startpoint and + endpoint pairs to create 2 lists: x_values for barbs and y values + for barbs + + :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint + x_value pairs separated by a None to create the barb of the arrow, + and list of startpoint and endpoint y_value pairs separated by a + None to create the barb of the arrow. + """ + self.end_x = [i + j for i, j in zip(self.x, self.u)] + self.end_y = [i + j for i, j in zip(self.y, self.v)] + empty = [None] * len(self.x) + barb_x = utils.flatten(zip(self.x, self.end_x, empty)) + barb_y = utils.flatten(zip(self.y, self.end_y, empty)) + return barb_x, barb_y + + def get_quiver_arrows(self): + """ + Creates lists of x and y values to plot the arrows + + Gets length of each barb then calculates the length of each side of + the arrow. Gets angle of barb and applies angle to each side of the + arrowhead. Next uses arrow_scale to scale the length of arrowhead and + creates x and y values for arrowhead point1 and point2. Finally x and y + values for point1, endpoint and point2s for each arrowhead are + separated by a None and zipped to create lists of x and y values for + the arrows. + + :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2 + x_values separated by a None to create the arrowhead and list of + point1, endpoint, point2 y_values separated by a None to create + the barb of the arrow. + """ + dif_x = [i - j for i, j in zip(self.end_x, self.x)] + dif_y = [i - j for i, j in zip(self.end_y, self.y)] + + # Get barb lengths(default arrow length = 30% barb length) + barb_len = [None] * len(self.x) + for index in range(len(barb_len)): + barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index]) + + # Make arrow lengths + arrow_len = [None] * len(self.x) + arrow_len = [i * self.arrow_scale for i in barb_len] + + # Get barb angles + barb_ang = [None] * len(self.x) + for index in range(len(barb_ang)): + barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio) + + # Set angles to create arrow + ang1 = [i + self.angle for i in barb_ang] + ang2 = [i - self.angle for i in barb_ang] + + cos_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + cos_ang1[index] = math.cos(ang1[index]) + seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)] + + sin_ang1 = [None] * len(ang1) + for index in range(len(ang1)): + sin_ang1[index] = math.sin(ang1[index]) + seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)] + + cos_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + cos_ang2[index] = math.cos(ang2[index]) + seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)] + + sin_ang2 = [None] * len(ang2) + for index in range(len(ang2)): + sin_ang2[index] = math.sin(ang2[index]) + seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)] + + # Set coordinates to create arrow + for index in range(len(self.end_x)): + point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)] + point1_y = [i - j for i, j in zip(self.end_y, seg1_y)] + point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)] + point2_y = [i - j for i, j in zip(self.end_y, seg2_y)] + + # Combine lists to create arrow + empty = [None] * len(self.end_x) + arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty)) + arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty)) + return arrow_x, arrow_y diff --git a/plotly/figure_factory/figure_factory/_scatterplot.py b/plotly/figure_factory/figure_factory/_scatterplot.py new file mode 100644 index 00000000000..25ae4f0976f --- /dev/null +++ b/plotly/figure_factory/figure_factory/_scatterplot.py @@ -0,0 +1,1104 @@ +from __future__ import absolute_import + +import six + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs +from plotly.tools import make_subplots + +pd = optional_imports.get_module('pandas') + +DIAG_CHOICES = ['scatter', 'histogram', 'box'] +VALID_COLORMAP_TYPES = ['cat', 'seq'] + + +def hide_tick_labels_from_box_subplots(fig): + """ + Hides tick labels for box plots in scatterplotmatrix subplots. + """ + boxplot_xaxes = [] + for trace in fig['data']: + if trace['type'] == 'box': + # stores the xaxes which correspond to boxplot subplots + # since we use xaxis1, xaxis2, etc, in plotly.py + boxplot_xaxes.append( + 'xaxis{}'.format(trace['xaxis'][1:]) + ) + for xaxis in boxplot_xaxes: + fig['layout'][xaxis]['showticklabels'] = False + + +def validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs): + """ + Validates basic inputs for FigureFactory.create_scatterplotmatrix() + + :raises: (PlotlyError) If pandas is not imported + :raises: (PlotlyError) If pandas dataframe is not inputted + :raises: (PlotlyError) If pandas dataframe has <= 1 columns + :raises: (PlotlyError) If diagonal plot choice (diag) is not one of + the viable options + :raises: (PlotlyError) If colormap_type is not a valid choice + :raises: (PlotlyError) If kwargs contains 'size', 'color' or + 'colorscale' + """ + if not pd: + raise ImportError("FigureFactory.scatterplotmatrix requires " + "a pandas DataFrame.") + + # Check if pandas dataframe + if not isinstance(df, pd.core.frame.DataFrame): + raise exceptions.PlotlyError("Dataframe not inputed. Please " + "use a pandas dataframe to pro" + "duce a scatterplot matrix.") + + # Check if dataframe is 1 column or less + if len(df.columns) <= 1: + raise exceptions.PlotlyError("Dataframe has only one column. To " + "use the scatterplot matrix, use at " + "least 2 columns.") + + # Check that diag parameter is a valid selection + if diag not in DIAG_CHOICES: + raise exceptions.PlotlyError("Make sure diag is set to " + "one of {}".format(DIAG_CHOICES)) + + # Check that colormap_types is a valid selection + if colormap_type not in VALID_COLORMAP_TYPES: + raise exceptions.PlotlyError("Must choose a valid colormap type. " + "Either 'cat' or 'seq' for a cate" + "gorical and sequential colormap " + "respectively.") + + # Check for not 'size' or 'color' in 'marker' of **kwargs + if 'marker' in kwargs: + FORBIDDEN_PARAMS = ['size', 'color', 'colorscale'] + if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS): + raise exceptions.PlotlyError("Your kwargs dictionary cannot " + "include the 'size', 'color' or " + "'colorscale' key words inside " + "the marker dict since 'size' is " + "already an argument of the " + "scatterplot matrix function and " + "both 'color' and 'colorscale " + "are set internally.") + + +def scatterplot(dataframe, headers, diag, size, height, width, title, + **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix without index + + """ + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + # Insert traces into trace_list + for listy in dataframe: + for listx in dataframe: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + name=None, + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + trace_list.append(trace) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size), + showlegend=False, + **kwargs + ) + trace_list.append(trace) + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + fig.append_trace(trace_list[trace_index], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True + ) + + hide_tick_labels_from_box_subplots(fig) + + return fig + + +def scatterplot_dict(dataframe, headers, diag, size, + height, width, title, index, index_vals, + endpts, colormap, colormap_type, **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix with both index and colormap picked. + Used if colormap is a dictionary with index values as keys pointing to + colors. Forces colormap_type to behave categorically because it would + not make sense colors are assigned to each index value and thus + implies that a categorical approach should be taken + + """ + + theme = colormap + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # create a dictionary for index_vals + unique_index_vals = {} + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals[name] = [] + + # Fill all the rest of the names into the dictionary + for name in sorted(unique_index_vals.keys()): + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if index_vals[j] == name: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[name]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[name]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[name]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[name]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[name]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[name] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[name]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace( + trace_list[trace_index][name], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + +def scatterplot_theme(dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs): + """ + Refer to FigureFactory.create_scatterplotmatrix() for docstring + + Returns fig for scatterplotmatrix with both index and colormap picked + + """ + + # Check if index is made of string values + if isinstance(index_vals[0], str): + unique_index_vals = [] + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals.append(name) + n_colors_len = len(unique_index_vals) + + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = colors.color_parser(colormap, colors.unlabel_rgb) + foo = utils.n_colors(foo[0], foo[1], n_colors_len) + theme = colors.color_parser(foo, colors.label_rgb) + + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # create a dictionary for index_vals + unique_index_vals = {} + for name in index_vals: + if name not in unique_index_vals: + unique_index_vals[name] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for name in sorted(unique_index_vals.keys()): + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if index_vals[j] == name: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=name, + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + unique_index_vals[name] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(unique_index_vals) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for name in sorted(trace_list[trace_index].keys()): + fig.append_trace( + trace_list[trace_index][name], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + if endpts: + intervals = utils.endpts_to_intervals(endpts) + + # Convert colormap to list of n RGB tuples + if colormap_type == 'seq': + foo = colors.color_parser(colormap, colors.unlabel_rgb) + foo = utils.n_colors(foo[0], foo[1], len(intervals)) + theme = colors.color_parser(foo, colors.label_rgb) + + if colormap_type == 'cat': + # leave list of colors the same way + theme = colormap + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Work over all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + interval_labels = {} + for interval in intervals: + interval_labels[str(interval)] = [] + + c_indx = 0 # color index + # Fill all the rest of the names into the dictionary + for interval in intervals: + new_listx = [] + new_listy = [] + for j in range(len(index_vals)): + if interval[0] < index_vals[j] <= interval[1]: + new_listx.append(listx[j]) + new_listy.append(listy[j]) + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=True + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + (kwargs['marker'] + ['color']) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + showlegend=True, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=True, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=new_listx, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=new_listx, + name=None, + marker=dict( + color=theme[c_indx]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + (kwargs['marker'] + ['color']) = theme[c_indx] + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=new_listx, + y=new_listy, + mode='markers', + name=str(interval), + marker=dict( + size=size, + color=theme[c_indx]), + showlegend=False, + **kwargs + ) + # Push the trace into dictionary + interval_labels[str(interval)] = trace + if c_indx >= (len(theme) - 1): + c_indx = -1 + c_indx += 1 + trace_list.append(interval_labels) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + for interval in intervals: + fig.append_trace( + trace_list[trace_index][str(interval)], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + theme = colormap + + # add a copy of rgb color to theme if it contains one color + if len(theme) <= 1: + theme.append(theme[0]) + + color = [] + for incr in range(len(theme)): + color.append([1. / (len(theme) - 1) * incr, theme[incr]]) + + dim = len(dataframe) + fig = make_subplots(rows=dim, cols=dim, print_grid=False) + trace_list = [] + legend_param = 0 + # Run through all permutations of list pairs + for listy in dataframe: + for listx in dataframe: + # Generate trace with VISIBLE icon + if legend_param == 1: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = index_vals + kwargs['marker']['colorscale'] = color + kwargs['marker']['showscale'] = True + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=True), + showlegend=False, + **kwargs + ) + # Generate trace with INVISIBLE icon + else: + if (listx == listy) and (diag == 'histogram'): + trace = graph_objs.Histogram( + x=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + elif (listx == listy) and (diag == 'box'): + trace = graph_objs.Box( + y=listx, + marker=dict( + color=theme[0]), + showlegend=False + ) + else: + if 'marker' in kwargs: + kwargs['marker']['size'] = size + kwargs['marker']['color'] = index_vals + kwargs['marker']['colorscale'] = color + kwargs['marker']['showscale'] = False + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + showlegend=False, + **kwargs + ) + else: + trace = graph_objs.Scatter( + x=listx, + y=listy, + mode='markers', + marker=dict( + size=size, + color=index_vals, + colorscale=color, + showscale=False), + showlegend=False, + **kwargs + ) + # Push the trace into list + trace_list.append(trace) + legend_param += 1 + + trace_index = 0 + indices = range(1, dim + 1) + for y_index in indices: + for x_index in indices: + fig.append_trace(trace_list[trace_index], + y_index, + x_index) + trace_index += 1 + + # Insert headers into the figure + for j in range(dim): + xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j) + fig['layout'][xaxis_key].update(title=headers[j]) + for j in range(dim): + yaxis_key = 'yaxis{}'.format(1 + (dim * j)) + fig['layout'][yaxis_key].update(title=headers[j]) + + hide_tick_labels_from_box_subplots(fig) + + if diag == 'histogram': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True, + barmode='stack') + return fig + + elif diag == 'box': + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + else: + fig['layout'].update( + height=height, width=width, + title=title, + showlegend=True) + return fig + + +def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter', + height=500, width=500, size=6, + title='Scatterplot Matrix', colormap=None, + colormap_type='cat', dataframe=None, + headers=None, index_vals=None, **kwargs): + """ + Returns data for a scatterplot matrix. + + :param (array) df: array of the data with column headers + :param (str) index: name of the index column in data array + :param (list|tuple) endpts: takes an increasing sequece of numbers + that defines intervals on the real line. They are used to group + the entries in an index of numbers into their corresponding + interval and therefore can be treated as categorical data + :param (str) diag: sets the chart type for the main diagonal plots. + The options are 'scatter', 'histogram' and 'box'. + :param (int|float) height: sets the height of the chart + :param (int|float) width: sets the width of the chart + :param (float) size: sets the marker size (in px) + :param (str) title: the title label of the scatterplot matrix + :param (str|tuple|list|dict) colormap: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colormap is a list, it must contain valid color types as its + members. + If colormap is a dictionary, all the string entries in + the index column must be a key in colormap. In this case, the + colormap_type is forced to 'cat' or categorical + :param (str) colormap_type: determines how colormap is interpreted. + Valid choices are 'seq' (sequential) and 'cat' (categorical). If + 'seq' is selected, only the first two colors in colormap will be + considered (when colormap is a list) and the index values will be + linearly interpolated between those two colors. This option is + forced if all index values are numeric. + If 'cat' is selected, a color from colormap will be assigned to + each category from index, including the intervals if endpts is + being used + :param (dict) **kwargs: a dictionary of scatterplot arguments + The only forbidden parameters are 'size', 'color' and + 'colorscale' in 'marker' + + Example 1: Vanilla Scatterplot Matrix + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe + df = pd.DataFrame(np.random.randn(10, 2), + columns=['Column 1', 'Column 2']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df) + + # Plot + py.iplot(fig, filename='Vanilla Scatterplot Matrix') + ``` + + Example 2: Indexing a Column + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with index + df = pd.DataFrame(np.random.randn(10, 2), + columns=['A', 'B']) + + # Add another column of strings to the dataframe + df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + 'grape', 'pear', 'pear', 'apple', 'pear']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df, index='Fruit', size=10) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix with Index') + ``` + + Example 3: Styling the Diagonal Subplots + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with index + df = pd.DataFrame(np.random.randn(10, 4), + columns=['A', 'B', 'C', 'D']) + + # Add another column of strings to the dataframe + df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple', + 'grape', 'pear', 'pear', 'apple', 'pear']) + + # Create scatterplot matrix + fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000, + width=1000) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling') + ``` + + Example 4: Use a Theme to Style the Subplots + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['A', 'B', 'C']) + + # Create scatterplot matrix using a built-in + # Plotly palette scale and indexing column 'A' + fig = create_scatterplotmatrix(df, diag='histogram', index='A', + colormap='Blues', height=800, width=800) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme') + ``` + + Example 5: Example 4 with Interval Factoring + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['A', 'B', 'C']) + + # Create scatterplot matrix using a list of 2 rgb tuples + # and endpoints at -1, 0 and 1 + fig = create_scatterplotmatrix(df, diag='histogram', index='A', + colormap=['rgb(140, 255, 50)', + 'rgb(170, 60, 115)', '#6c4774', + (0.5, 0.1, 0.8)], + endpts=[-1, 0, 1], height=800, width=800) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - Intervals') + ``` + + Example 6: Using the colormap as a Dictionary + ``` + import plotly.plotly as py + from plotly.graph_objs import graph_objs + from plotly.figure_factory import create_scatterplotmatrix + + import numpy as np + import pandas as pd + import random + + # Create dataframe with random data + df = pd.DataFrame(np.random.randn(100, 3), + columns=['Column A', + 'Column B', + 'Column C']) + + # Add new color column to dataframe + new_column = [] + strange_colors = ['turquoise', 'limegreen', 'goldenrod'] + + for j in range(100): + new_column.append(random.choice(strange_colors)) + df['Colors'] = pd.Series(new_column, index=df.index) + + # Create scatterplot matrix using a dictionary of hex color values + # which correspond to actual color names in 'Colors' column + fig = create_scatterplotmatrix( + df, diag='box', index='Colors', + colormap= dict( + turquoise = '#00F5FF', + limegreen = '#32CD32', + goldenrod = '#DAA520' + ), + colormap_type='cat', + height=800, width=800 + ) + + # Plot + py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ') + ``` + """ + # TODO: protected until #282 + if dataframe is None: + dataframe = [] + if headers is None: + headers = [] + if index_vals is None: + index_vals = [] + + validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs) + + # Validate colormap + if isinstance(colormap, dict): + colormap = utils.validate_colors_dict(colormap, 'rgb') + elif isinstance(colormap, six.string_types) and 'rgb' not in colormap and '#' not in colormap: + if colormap not in utils.PLOTLY_SCALES.keys(): + raise exceptions.PlotlyError( + "If 'colormap' is a string, it must be the name " + "of a Plotly Colorscale. The available colorscale " + "names are {}".format(utils.PLOTLY_SCALES.keys()) + ) + else: + # TODO change below to allow the correct Plotly colorscale + colormap = utils.colorscale_to_colors(utils.PLOTLY_SCALES[colormap]) + # keep only first and last item - fix later + colormap = [colormap[0]] + [colormap[-1]] + colormap = utils.validate_colors(colormap, 'rgb') + else: + colormap = utils.validate_colors(colormap, 'rgb') + + if not index: + for name in df: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + # Check for same data-type in df columns + utils.validate_dataframe(dataframe) + figure = scatterplot(dataframe, headers, diag, size, height, width, + title, **kwargs) + return figure + else: + # Validate index selection + if index not in df: + raise exceptions.PlotlyError("Make sure you set the index " + "input variable to one of the " + "column names of your " + "dataframe.") + index_vals = df[index].values.tolist() + for name in df: + if name != index: + headers.append(name) + for name in headers: + dataframe.append(df[name].values.tolist()) + + # check for same data-type in each df column + utils.validate_dataframe(dataframe) + utils.validate_index(index_vals) + + # check if all colormap keys are in the index + # if colormap is a dictionary + if isinstance(colormap, dict): + for key in colormap: + if not all(index in colormap for index in index_vals): + raise exceptions.PlotlyError("If colormap is a " + "dictionary, all the " + "names in the index " + "must be keys.") + figure = scatterplot_dict( + dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs + ) + return figure + + else: + figure = scatterplot_theme( + dataframe, headers, diag, size, height, width, title, + index, index_vals, endpts, colormap, colormap_type, + **kwargs + ) + return figure diff --git a/plotly/figure_factory/figure_factory/_streamline.py b/plotly/figure_factory/figure_factory/_streamline.py new file mode 100644 index 00000000000..a6773420f48 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_streamline.py @@ -0,0 +1,415 @@ +from __future__ import absolute_import + +import math + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +np = optional_imports.get_module('numpy') + + +def validate_streamline(x, y): + """ + Streamline-specific validations + + Specifically, this checks that x and y are both evenly spaced, + and that the package numpy is available. + + See FigureFactory.create_streamline() for params + + :raises: (ImportError) If numpy is not available. + :raises: (PlotlyError) If x is not evenly spaced. + :raises: (PlotlyError) If y is not evenly spaced. + """ + if np is False: + raise ImportError("FigureFactory.create_streamline requires numpy") + for index in range(len(x) - 1): + if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001: + raise exceptions.PlotlyError("x must be a 1 dimensional, " + "evenly spaced array") + for index in range(len(y) - 1): + if ((y[index + 1] - y[index]) - + (y[1] - y[0])) > .0001: + raise exceptions.PlotlyError("y must be a 1 dimensional, " + "evenly spaced array") + + +def create_streamline(x, y, u, v, density=1, angle=math.pi / 9, + arrow_scale=.09, **kwargs): + """ + Returns data for a streamline plot. + + :param (list|ndarray) x: 1 dimensional, evenly spaced list or array + :param (list|ndarray) y: 1 dimensional, evenly spaced list or array + :param (ndarray) u: 2 dimensional array + :param (ndarray) v: 2 dimensional array + :param (float|int) density: controls the density of streamlines in + plot. This is multiplied by 30 to scale similiarly to other + available streamline functions such as matplotlib. + Default = 1 + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :param kwargs: kwargs passed through plotly.graph_objs.Scatter + for more information on valid kwargs call + help(plotly.graph_objs.Scatter) + + :rtype (dict): returns a representation of streamline figure. + + Example 1: Plot simple streamline and increase arrow size + ``` + import plotly.plotly as py + from plotly.figure_factory import create_streamline + + import numpy as np + import math + + # Add data + x = np.linspace(-3, 3, 100) + y = np.linspace(-3, 3, 100) + Y, X = np.meshgrid(x, y) + u = -1 - X**2 + Y + v = 1 + X - Y**2 + u = u.T # Transpose + v = v.T # Transpose + + # Create streamline + fig = create_streamline(x, y, u, v, arrow_scale=.1) + + # Plot + py.plot(fig, filename='streamline') + ``` + + Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython + ``` + import plotly.plotly as py + from plotly.figure_factory import create_streamline + + import numpy as np + import math + + # Add data + N = 50 + x_start, x_end = -2.0, 2.0 + y_start, y_end = -1.0, 1.0 + x = np.linspace(x_start, x_end, N) + y = np.linspace(y_start, y_end, N) + X, Y = np.meshgrid(x, y) + ss = 5.0 + x_s, y_s = -1.0, 0.0 + + # Compute the velocity field on the mesh grid + u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2) + v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2) + + # Create streamline + fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline') + + # Add source point + point = Scatter(x=[x_s], y=[y_s], mode='markers', + marker=Marker(size=14), name='source point') + + # Plot + fig['data'].append(point) + py.plot(fig, filename='streamline') + ``` + """ + utils.validate_equal_length(x, y) + utils.validate_equal_length(u, v) + validate_streamline(x, y) + utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale) + + streamline_x, streamline_y = _Streamline(x, y, u, v, + density, angle, + arrow_scale).sum_streamlines() + arrow_x, arrow_y = _Streamline(x, y, u, v, + density, angle, + arrow_scale).get_streamline_arrows() + + streamline = graph_objs.Scatter(x=streamline_x + arrow_x, + y=streamline_y + arrow_y, + mode='lines', **kwargs) + + data = [streamline] + layout = graph_objs.Layout(hovermode='closest') + + return graph_objs.Figure(data=data, layout=layout) + + +class _Streamline(object): + """ + Refer to FigureFactory.create_streamline() for docstring + """ + def __init__(self, x, y, u, v, + density, angle, + arrow_scale, **kwargs): + self.x = np.array(x) + self.y = np.array(y) + self.u = np.array(u) + self.v = np.array(v) + self.angle = angle + self.arrow_scale = arrow_scale + self.density = int(30 * density) # Scale similarly to other functions + self.delta_x = self.x[1] - self.x[0] + self.delta_y = self.y[1] - self.y[0] + self.val_x = self.x + self.val_y = self.y + + # Set up spacing + self.blank = np.zeros((self.density, self.density)) + self.spacing_x = len(self.x) / float(self.density - 1) + self.spacing_y = len(self.y) / float(self.density - 1) + self.trajectories = [] + + # Rescale speed onto axes-coordinates + self.u = self.u / (self.x[-1] - self.x[0]) + self.v = self.v / (self.y[-1] - self.y[0]) + self.speed = np.sqrt(self.u ** 2 + self.v ** 2) + + # Rescale u and v for integrations. + self.u *= len(self.x) + self.v *= len(self.y) + self.st_x = [] + self.st_y = [] + self.get_streamlines() + streamline_x, streamline_y = self.sum_streamlines() + arrows_x, arrows_y = self.get_streamline_arrows() + + def blank_pos(self, xi, yi): + """ + Set up positions for trajectories to be used with rk4 function. + """ + return (int((xi / self.spacing_x) + 0.5), + int((yi / self.spacing_y) + 0.5)) + + def value_at(self, a, xi, yi): + """ + Set up for RK4 function, based on Bokeh's streamline code + """ + if isinstance(xi, np.ndarray): + self.x = xi.astype(np.int) + self.y = yi.astype(np.int) + else: + self.val_x = np.int(xi) + self.val_y = np.int(yi) + a00 = a[self.val_y, self.val_x] + a01 = a[self.val_y, self.val_x + 1] + a10 = a[self.val_y + 1, self.val_x] + a11 = a[self.val_y + 1, self.val_x + 1] + xt = xi - self.val_x + yt = yi - self.val_y + a0 = a00 * (1 - xt) + a01 * xt + a1 = a10 * (1 - xt) + a11 * xt + return a0 * (1 - yt) + a1 * yt + + def rk4_integrate(self, x0, y0): + """ + RK4 forward and back trajectories from the initial conditions. + + Adapted from Bokeh's streamline -uses Runge-Kutta method to fill + x and y trajectories then checks length of traj (s in units of axes) + """ + def f(xi, yi): + dt_ds = 1. / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return ui * dt_ds, vi * dt_ds + + def g(xi, yi): + dt_ds = 1. / self.value_at(self.speed, xi, yi) + ui = self.value_at(self.u, xi, yi) + vi = self.value_at(self.v, xi, yi) + return -ui * dt_ds, -vi * dt_ds + + check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and + 0 <= yi < len(self.y) - 1) + xb_changes = [] + yb_changes = [] + + def rk4(x0, y0, f): + ds = 0.01 + stotal = 0 + xi = x0 + yi = y0 + xb, yb = self.blank_pos(xi, yi) + xf_traj = [] + yf_traj = [] + while check(xi, yi): + xf_traj.append(xi) + yf_traj.append(yi) + try: + k1x, k1y = f(xi, yi) + k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y) + k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y) + k4x, k4y = f(xi + ds * k3x, yi + ds * k3y) + except IndexError: + break + xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6. + yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6. + if not check(xi, yi): + break + stotal += ds + new_xb, new_yb = self.blank_pos(xi, yi) + if new_xb != xb or new_yb != yb: + if self.blank[new_yb, new_xb] == 0: + self.blank[new_yb, new_xb] = 1 + xb_changes.append(new_xb) + yb_changes.append(new_yb) + xb = new_xb + yb = new_yb + else: + break + if stotal > 2: + break + return stotal, xf_traj, yf_traj + + sf, xf_traj, yf_traj = rk4(x0, y0, f) + sb, xb_traj, yb_traj = rk4(x0, y0, g) + stotal = sf + sb + x_traj = xb_traj[::-1] + xf_traj[1:] + y_traj = yb_traj[::-1] + yf_traj[1:] + + if len(x_traj) < 1: + return None + if stotal > .2: + initxb, inityb = self.blank_pos(x0, y0) + self.blank[inityb, initxb] = 1 + return x_traj, y_traj + else: + for xb, yb in zip(xb_changes, yb_changes): + self.blank[yb, xb] = 0 + return None + + def traj(self, xb, yb): + """ + Integrate trajectories + + :param (int) xb: results of passing xi through self.blank_pos + :param (int) xy: results of passing yi through self.blank_pos + + Calculate each trajectory based on rk4 integrate method. + """ + + if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density: + return + if self.blank[yb, xb] == 0: + t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y) + if t is not None: + self.trajectories.append(t) + + def get_streamlines(self): + """ + Get streamlines by building trajectory set. + """ + for indent in range(self.density // 2): + for xi in range(self.density - 2 * indent): + self.traj(xi + indent, indent) + self.traj(xi + indent, self.density - 1 - indent) + self.traj(indent, xi + indent) + self.traj(self.density - 1 - indent, xi + indent) + + self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in + self.trajectories] + self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in + self.trajectories] + + for index in range(len(self.st_x)): + self.st_x[index] = self.st_x[index].tolist() + self.st_x[index].append(np.nan) + + for index in range(len(self.st_y)): + self.st_y[index] = self.st_y[index].tolist() + self.st_y[index].append(np.nan) + + def get_streamline_arrows(self): + """ + Makes an arrow for each streamline. + + Gets angle of streamline at 1/3 mark and creates arrow coordinates + based off of user defined angle and arrow_scale. + + :param (array) st_x: x-values for all streamlines + :param (array) st_y: y-values for all streamlines + :param (angle in radians) angle: angle of arrowhead. Default = pi/9 + :param (float in [0,1]) arrow_scale: value to scale length of arrowhead + Default = .09 + :rtype (list, list) arrows_x: x-values to create arrowhead and + arrows_y: y-values to create arrowhead + """ + arrow_end_x = np.empty((len(self.st_x))) + arrow_end_y = np.empty((len(self.st_y))) + arrow_start_x = np.empty((len(self.st_x))) + arrow_start_y = np.empty((len(self.st_y))) + for index in range(len(self.st_x)): + arrow_end_x[index] = (self.st_x[index] + [int(len(self.st_x[index]) / 3)]) + arrow_start_x[index] = (self.st_x[index] + [(int(len(self.st_x[index]) / 3)) - 1]) + arrow_end_y[index] = (self.st_y[index] + [int(len(self.st_y[index]) / 3)]) + arrow_start_y[index] = (self.st_y[index] + [(int(len(self.st_y[index]) / 3)) - 1]) + + dif_x = arrow_end_x - arrow_start_x + dif_y = arrow_end_y - arrow_start_y + + orig_err = np.geterr() + np.seterr(divide='ignore', invalid='ignore') + streamline_ang = np.arctan(dif_y / dif_x) + np.seterr(**orig_err) + + + ang1 = streamline_ang + (self.angle) + ang2 = streamline_ang - (self.angle) + + seg1_x = np.cos(ang1) * self.arrow_scale + seg1_y = np.sin(ang1) * self.arrow_scale + seg2_x = np.cos(ang2) * self.arrow_scale + seg2_y = np.sin(ang2) * self.arrow_scale + + point1_x = np.empty((len(dif_x))) + point1_y = np.empty((len(dif_y))) + point2_x = np.empty((len(dif_x))) + point2_y = np.empty((len(dif_y))) + + for index in range(len(dif_x)): + if dif_x[index] >= 0: + point1_x[index] = arrow_end_x[index] - seg1_x[index] + point1_y[index] = arrow_end_y[index] - seg1_y[index] + point2_x[index] = arrow_end_x[index] - seg2_x[index] + point2_y[index] = arrow_end_y[index] - seg2_y[index] + else: + point1_x[index] = arrow_end_x[index] + seg1_x[index] + point1_y[index] = arrow_end_y[index] + seg1_y[index] + point2_x[index] = arrow_end_x[index] + seg2_x[index] + point2_y[index] = arrow_end_y[index] + seg2_y[index] + + space = np.empty((len(point1_x))) + space[:] = np.nan + + # Combine arrays into matrix + arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space]) + arrows_x = np.array(arrows_x) + arrows_x = arrows_x.flatten('F') + arrows_x = arrows_x.tolist() + + # Combine arrays into matrix + arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space]) + arrows_y = np.array(arrows_y) + arrows_y = arrows_y.flatten('F') + arrows_y = arrows_y.tolist() + + return arrows_x, arrows_y + + def sum_streamlines(self): + """ + Makes all streamlines readable as a single trace. + + :rtype (list, list): streamline_x: all x values for each streamline + combined into single list and streamline_y: all y values for each + streamline combined into single list + """ + streamline_x = sum(self.st_x, []) + streamline_y = sum(self.st_y, []) + return streamline_x, streamline_y diff --git a/plotly/figure_factory/figure_factory/_table.py b/plotly/figure_factory/figure_factory/_table.py new file mode 100644 index 00000000000..3bd48db905d --- /dev/null +++ b/plotly/figure_factory/figure_factory/_table.py @@ -0,0 +1,233 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.graph_objs import graph_objs + +pd = optional_imports.get_module('pandas') + + +def validate_table(table_text, font_colors): + """ + Table-specific validations + + Check that font_colors is supplied correctly (1, 3, or len(text) + colors). + + :raises: (PlotlyError) If font_colors is supplied incorretly. + + See FigureFactory.create_table() for params + """ + font_colors_len_options = [1, 3, len(table_text)] + if len(font_colors) not in font_colors_len_options: + raise exceptions.PlotlyError("Oops, font_colors should be a list " + "of length 1, 3 or len(text)") + + +def create_table(table_text, colorscale=None, font_colors=None, + index=False, index_title='', annotation_offset=.45, + height_constant=30, hoverinfo='none', **kwargs): + """ + BETA function that creates data tables + + :param (pandas.Dataframe | list[list]) text: data for table. + :param (str|list[list]) colorscale: Colorscale for table where the + color at value 0 is the header color, .5 is the first table color + and 1 is the second table color. (Set .5 and 1 to avoid the striped + table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'], + [1, '#ffffff']] + :param (list) font_colors: Color for fonts in table. Can be a single + color, three colors, or a color for each row in the table. + Default=['#000000'] (black text for the entire table) + :param (int) height_constant: Constant multiplied by # of rows to + create table height. Default=30. + :param (bool) index: Create (header-colored) index column index from + Pandas dataframe or list[0] for each list in text. Default=False. + :param (string) index_title: Title for index column. Default=''. + :param kwargs: kwargs passed through plotly.graph_objs.Heatmap. + These kwargs describe other attributes about the annotated Heatmap + trace such as the colorscale. For more information on valid kwargs + call help(plotly.graph_objs.Heatmap) + + Example 1: Simple Plotly Table + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + text = [['Country', 'Year', 'Population'], + ['US', 2000, 282200000], + ['Canada', 2000, 27790000], + ['US', 2010, 309000000], + ['Canada', 2010, 34000000]] + + table = create_table(text) + py.iplot(table) + ``` + + Example 2: Table with Custom Coloring + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + text = [['Country', 'Year', 'Population'], + ['US', 2000, 282200000], + ['Canada', 2000, 27790000], + ['US', 2010, 309000000], + ['Canada', 2010, 34000000]] + + table = create_table(text, + colorscale=[[0, '#000000'], + [.5, '#80beff'], + [1, '#cce5ff']], + font_colors=['#ffffff', '#000000', + '#000000']) + py.iplot(table) + ``` + Example 3: Simple Plotly Table with Pandas + ``` + import plotly.plotly as py + from plotly.figure_factory import create_table + + import pandas as pd + + df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t') + df_p = df[0:25] + + table_simple = create_table(df_p) + py.iplot(table_simple) + ``` + """ + + # Avoiding mutables in the call signature + colorscale = \ + colorscale if colorscale is not None else [[0, '#00083e'], + [.5, '#ededee'], + [1, '#ffffff']] + font_colors = font_colors if font_colors is not None else ['#ffffff', + '#000000', + '#000000'] + + validate_table(table_text, font_colors) + table_matrix = _Table(table_text, colorscale, font_colors, index, + index_title, annotation_offset, + **kwargs).get_table_matrix() + annotations = _Table(table_text, colorscale, font_colors, index, + index_title, annotation_offset, + **kwargs).make_table_annotations() + + trace = dict(type='heatmap', z=table_matrix, opacity=.75, + colorscale=colorscale, showscale=False, + hoverinfo=hoverinfo, **kwargs) + + data = [trace] + layout = dict(annotations=annotations, + height=len(table_matrix) * height_constant + 50, + margin=dict(t=0, b=0, r=0, l=0), + yaxis=dict(autorange='reversed', zeroline=False, + gridwidth=2, ticks='', dtick=1, tick0=.5, + showticklabels=False), + xaxis=dict(zeroline=False, gridwidth=2, ticks='', + dtick=1, tick0=-0.5, showticklabels=False)) + return graph_objs.Figure(data=data, layout=layout) + + +class _Table(object): + """ + Refer to TraceFactory.create_table() for docstring + """ + def __init__(self, table_text, colorscale, font_colors, index, + index_title, annotation_offset, **kwargs): + if pd and isinstance(table_text, pd.DataFrame): + headers = table_text.columns.tolist() + table_text_index = table_text.index.tolist() + table_text = table_text.values.tolist() + table_text.insert(0, headers) + if index: + table_text_index.insert(0, index_title) + for i in range(len(table_text)): + table_text[i].insert(0, table_text_index[i]) + self.table_text = table_text + self.colorscale = colorscale + self.font_colors = font_colors + self.index = index + self.annotation_offset = annotation_offset + self.x = range(len(table_text[0])) + self.y = range(len(table_text)) + + def get_table_matrix(self): + """ + Create z matrix to make heatmap with striped table coloring + + :rtype (list[list]) table_matrix: z matrix to make heatmap with striped + table coloring. + """ + header = [0] * len(self.table_text[0]) + odd_row = [.5] * len(self.table_text[0]) + even_row = [1] * len(self.table_text[0]) + table_matrix = [None] * len(self.table_text) + table_matrix[0] = header + for i in range(1, len(self.table_text), 2): + table_matrix[i] = odd_row + for i in range(2, len(self.table_text), 2): + table_matrix[i] = even_row + if self.index: + for array in table_matrix: + array[0] = 0 + return table_matrix + + def get_table_font_color(self): + """ + Fill font-color array. + + Table text color can vary by row so this extends a single color or + creates an array to set a header color and two alternating colors to + create the striped table pattern. + + :rtype (list[list]) all_font_colors: list of font colors for each row + in table. + """ + if len(self.font_colors) == 1: + all_font_colors = self.font_colors*len(self.table_text) + elif len(self.font_colors) == 3: + all_font_colors = list(range(len(self.table_text))) + all_font_colors[0] = self.font_colors[0] + for i in range(1, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[1] + for i in range(2, len(self.table_text), 2): + all_font_colors[i] = self.font_colors[2] + elif len(self.font_colors) == len(self.table_text): + all_font_colors = self.font_colors + else: + all_font_colors = ['#000000']*len(self.table_text) + return all_font_colors + + def make_table_annotations(self): + """ + Generate annotations to fill in table text + + :rtype (list) annotations: list of annotations for each cell of the + table. + """ + table_matrix = _Table.get_table_matrix(self) + all_font_colors = _Table.get_table_font_color(self) + annotations = [] + for n, row in enumerate(self.table_text): + for m, val in enumerate(row): + # Bold text in header and index + format_text = ('' + str(val) + '' if n == 0 or + self.index and m < 1 else str(val)) + # Match font color of index to font color of header + font_color = (self.font_colors[0] if self.index and m == 0 + else all_font_colors[n]) + annotations.append( + graph_objs.layout.Annotation( + text=format_text, + x=self.x[m] - self.annotation_offset, + y=self.y[n], + xref='x1', + yref='y1', + align="left", + xanchor="left", + font=dict(color=font_color), + showarrow=False) + ) + return annotations diff --git a/plotly/figure_factory/figure_factory/_trisurf.py b/plotly/figure_factory/figure_factory/_trisurf.py new file mode 100644 index 00000000000..6899a8484b7 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_trisurf.py @@ -0,0 +1,489 @@ +from __future__ import absolute_import + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs + +np = optional_imports.get_module('numpy') + + +def map_face2color(face, colormap, scale, vmin, vmax): + """ + Normalize facecolor values by vmin/vmax and return rgb-color strings + + This function takes a tuple color along with a colormap and a minimum + (vmin) and maximum (vmax) range of possible mean distances for the + given parametrized surface. It returns an rgb color based on the mean + distance between vmin and vmax + + """ + if vmin >= vmax: + raise exceptions.PlotlyError("Incorrect relation between vmin " + "and vmax. The vmin value cannot be " + "bigger than or equal to the value " + "of vmax.") + if len(colormap) == 1: + # color each triangle face with the same color in colormap + face_color = colormap[0] + face_color = utils.convert_to_RGB_255(face_color) + face_color = utils.label_rgb(face_color) + return face_color + if face == vmax: + # pick last color in colormap + face_color = colormap[-1] + face_color = utils.convert_to_RGB_255(face_color) + face_color = utils.label_rgb(face_color) + return face_color + else: + if scale is None: + # find the normalized distance t of a triangle face between + # vmin and vmax where the distance is between 0 and 1 + t = (face - vmin) / float((vmax - vmin)) + low_color_index = int(t / (1./(len(colormap) - 1))) + + face_color = utils.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + t * (len(colormap) - 1) - low_color_index + ) + + face_color = utils.convert_to_RGB_255(face_color) + face_color = utils.label_rgb(face_color) + else: + # find the face color for a non-linearly interpolated scale + t = (face - vmin) / float((vmax - vmin)) + + low_color_index = 0 + for k in range(len(scale) - 1): + if scale[k] <= t < scale[k+1]: + break + low_color_index += 1 + + low_scale_val = scale[low_color_index] + high_scale_val = scale[low_color_index + 1] + + face_color = utils.find_intermediate_color( + colormap[low_color_index], + colormap[low_color_index + 1], + (t - low_scale_val)/(high_scale_val - low_scale_val) + ) + + face_color = utils.convert_to_RGB_255(face_color) + face_color = utils.label_rgb(face_color) + return face_color + + +def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale, + colormap=None, color_func=None, plot_edges=False, x_edge=None, + y_edge=None, z_edge=None, facecolor=None): + """ + Refer to FigureFactory.create_trisurf() for docstring + """ + # numpy import check + if not np: + raise ImportError("FigureFactory._trisurf() requires " + "numpy imported.") + points3D = np.vstack((x, y, z)).T + simplices = np.atleast_2d(simplices) + + # vertices of the surface triangles + tri_vertices = points3D[simplices] + + # Define colors for the triangle faces + if color_func is None: + # mean values of z-coordinates of triangle vertices + mean_dists = tri_vertices[:, :, 2].mean(-1) + elif isinstance(color_func, (list, np.ndarray)): + # Pre-computed list / array of values to map onto color + if len(color_func) != len(simplices): + raise ValueError("If color_func is a list/array, it must " + "be the same length as simplices.") + + # convert all colors in color_func to rgb + for index in range(len(color_func)): + if isinstance(color_func[index], str): + if '#' in color_func[index]: + foo = utils.hex_to_rgb(color_func[index]) + color_func[index] = utils.label_rgb(foo) + + if isinstance(color_func[index], tuple): + foo = utils.convert_to_RGB_255(color_func[index]) + color_func[index] = utils.label_rgb(foo) + + mean_dists = np.asarray(color_func) + else: + # apply user inputted function to calculate + # custom coloring for triangle vertices + mean_dists = [] + for triangle in tri_vertices: + dists = [] + for vertex in triangle: + dist = color_func(vertex[0], vertex[1], vertex[2]) + dists.append(dist) + mean_dists.append(np.mean(dists)) + mean_dists = np.asarray(mean_dists) + + # Check if facecolors are already strings and can be skipped + if isinstance(mean_dists[0], str): + facecolor = mean_dists + else: + min_mean_dists = np.min(mean_dists) + max_mean_dists = np.max(mean_dists) + + if facecolor is None: + facecolor = [] + for index in range(len(mean_dists)): + color = map_face2color(mean_dists[index], colormap, scale, + min_mean_dists, max_mean_dists) + facecolor.append(color) + + # Make sure facecolor is a list so output is consistent across Pythons + facecolor = np.asarray(facecolor) + ii, jj, kk = simplices.T + + triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor, + i=ii, j=jj, k=kk, name='') + + mean_dists_are_numbers = not isinstance(mean_dists[0], str) + + if mean_dists_are_numbers and show_colorbar is True: + # make a colorscale from the colors + colorscale = utils.make_colorscale(colormap, scale) + colorscale = utils.convert_colorscale_to_rgb(colorscale) + + colorbar = graph_objs.Scatter3d( + x=x[:1], + y=y[:1], + z=z[:1], + mode='markers', + marker=dict( + size=0.1, + color=[min_mean_dists, max_mean_dists], + colorscale=colorscale, + showscale=True), + hoverinfo='none', + showlegend=False + ) + + # the triangle sides are not plotted + if plot_edges is False: + if mean_dists_are_numbers and show_colorbar is True: + return [triangles, colorbar] + else: + return [triangles] + + # define the lists x_edge, y_edge and z_edge, of x, y, resp z + # coordinates of edge end points for each triangle + # None separates data corresponding to two consecutive triangles + is_none = [ii is None for ii in [x_edge, y_edge, z_edge]] + if any(is_none): + if not all(is_none): + raise ValueError("If any (x_edge, y_edge, z_edge) is None, " + "all must be None") + else: + x_edge = [] + y_edge = [] + z_edge = [] + + # Pull indices we care about, then add a None column to separate tris + ixs_triangles = [0, 1, 2, 0] + pull_edges = tri_vertices[:, ixs_triangles, :] + x_edge_pull = np.hstack([pull_edges[:, :, 0], + np.tile(None, [pull_edges.shape[0], 1])]) + y_edge_pull = np.hstack([pull_edges[:, :, 1], + np.tile(None, [pull_edges.shape[0], 1])]) + z_edge_pull = np.hstack([pull_edges[:, :, 2], + np.tile(None, [pull_edges.shape[0], 1])]) + + # Now unravel the edges into a 1-d vector for plotting + x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]]) + y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]]) + z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]]) + + if not (len(x_edge) == len(y_edge) == len(z_edge)): + raise exceptions.PlotlyError("The lengths of x_edge, y_edge and " + "z_edge are not the same.") + + # define the lines for plotting + lines = graph_objs.Scatter3d( + x=x_edge, y=y_edge, z=z_edge, mode='lines', + line=graph_objs.scatter3d.Line( + color=edges_color, + width=1.5 + ), + showlegend=False + ) + + if mean_dists_are_numbers and show_colorbar is True: + return [triangles, lines, colorbar] + else: + return [triangles, lines] + + +def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True, + scale=None, color_func=None, title='Trisurf Plot', + plot_edges=True, showbackground=True, + backgroundcolor='rgb(230, 230, 230)', + gridcolor='rgb(255, 255, 255)', + zerolinecolor='rgb(255, 255, 255)', + edges_color='rgb(50, 50, 50)', + height=800, width=800, + aspectratio=None): + """ + Returns figure for a triangulated surface plot + + :param (array) x: data values of x in a 1D array + :param (array) y: data values of y in a 1D array + :param (array) z: data values of z in a 1D array + :param (array) simplices: an array of shape (ntri, 3) where ntri is + the number of triangles in the triangularization. Each row of the + array contains the indicies of the verticies of each triangle + :param (str|tuple|list) colormap: either a plotly scale name, an rgb + or hex color, a color tuple or a list of colors. An rgb color is + of the form 'rgb(x, y, z)' where x, y, z belong to the interval + [0, 255] and a color tuple is a tuple of the form (a, b, c) where + a, b and c belong to [0, 1]. If colormap is a list, it must + contain the valid color types aforementioned as its members + :param (bool) show_colorbar: determines if colorbar is visible + :param (list|array) scale: sets the scale values to be used if a non- + linearly interpolated colormap is desired. If left as None, a + linear interpolation between the colors will be excecuted + :param (function|list) color_func: The parameter that determines the + coloring of the surface. Takes either a function with 3 arguments + x, y, z or a list/array of color values the same length as + simplices. If None, coloring will only depend on the z axis + :param (str) title: title of the plot + :param (bool) plot_edges: determines if the triangles on the trisurf + are visible + :param (bool) showbackground: makes background in plot visible + :param (str) backgroundcolor: color of background. Takes a string of + the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) gridcolor: color of the gridlines besides the axes. Takes + a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255 + inclusive + :param (str) zerolinecolor: color of the axes. Takes a string of the + form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive + :param (str) edges_color: color of the edges, if plot_edges is True + :param (int|float) height: the height of the plot (in pixels) + :param (int|float) width: the width of the plot (in pixels) + :param (dict) aspectratio: a dictionary of the aspect ratio values for + the x, y and z axes. 'x', 'y' and 'z' take (int|float) values + + Example 1: Sphere + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = np.sin(v)*np.cos(u) + y = np.sin(v)*np.sin(u) + z = np.cos(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-sphere') + ``` + + Example 2: Torus + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 20) + v = np.linspace(0, 2*np.pi, 20) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + x = (3 + (np.cos(v)))*np.cos(u) + y = (3 + (np.cos(v)))*np.sin(u) + z = np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis", + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-torus') + ``` + + Example 3: Mobius Band + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u = np.linspace(0, 2*np.pi, 24) + v = np.linspace(-1, 1, 8) + u,v = np.meshgrid(u,v) + u = u.flatten() + v = v.flatten() + + tp = 1 + 0.5*v*np.cos(u/2.) + x = tp*np.cos(u) + y = tp*np.sin(u) + z = 0.5*v*np.sin(u/2.) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)], + simplices=simplices) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-mobius-band') + ``` + + Example 4: Using a Custom Colormap Function with Light Cone + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u=np.linspace(-np.pi, np.pi, 30) + v=np.linspace(-np.pi, np.pi, 30) + u,v=np.meshgrid(u,v) + u=u.flatten() + v=v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + # Define distance function + def dist_origin(x, y, z): + return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2) + + # Create a figure + fig1 = create_trisurf(x=x, y=y, z=z, + colormap=['#FFFFFF', '#E4FFFE', + '#A4F6F9', '#FF99FE', + '#BA52ED'], + scale=[0, 0.6, 0.71, 0.89, 1], + simplices=simplices, + color_func=dist_origin) + # Plot the data + py.iplot(fig1, filename='trisurf-plot-custom-coloring') + ``` + + Example 5: Enter color_func as a list of colors + ``` + # Necessary Imports for Trisurf + import numpy as np + from scipy.spatial import Delaunay + import random + + import plotly.plotly as py + from plotly.figure_factory import create_trisurf + from plotly.graph_objs import graph_objs + + # Make data for plot + u=np.linspace(-np.pi, np.pi, 30) + v=np.linspace(-np.pi, np.pi, 30) + u,v=np.meshgrid(u,v) + u=u.flatten() + v=v.flatten() + + x = u + y = u*np.cos(v) + z = u*np.sin(v) + + points2D = np.vstack([u,v]).T + tri = Delaunay(points2D) + simplices = tri.simplices + + + colors = [] + color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd'] + + for index in range(len(simplices)): + colors.append(random.choice(color_choices)) + + fig = create_trisurf( + x, y, z, simplices, + color_func=colors, + show_colorbar=True, + edges_color='rgb(2, 85, 180)', + title=' Modern Art' + ) + + py.iplot(fig, filename="trisurf-plot-modern-art") + ``` + """ + if aspectratio is None: + aspectratio = {'x': 1, 'y': 1, 'z': 1} + + # Validate colormap + utils.validate_colors(colormap) + colormap, scale = utils.convert_colors_to_same_type( + colormap, colortype='tuple', + return_default_colors=True, scale=scale + ) + + data1 = trisurf(x, y, z, simplices, show_colorbar=show_colorbar, + color_func=color_func, colormap=colormap, scale=scale, + edges_color=edges_color, plot_edges=plot_edges) + + axis = dict( + showbackground=showbackground, + backgroundcolor=backgroundcolor, + gridcolor=gridcolor, + zerolinecolor=zerolinecolor, + ) + layout = graph_objs.Layout( + title=title, + width=width, + height=height, + scene=graph_objs.layout.Scene( + xaxis=graph_objs.layout.scene.XAxis(**axis), + yaxis=graph_objs.layout.scene.YAxis(**axis), + zaxis=graph_objs.layout.scene.ZAxis(**axis), + aspectratio=dict( + x=aspectratio['x'], + y=aspectratio['y'], + z=aspectratio['z']), + ) + ) + + return graph_objs.Figure(data=data1, layout=layout) diff --git a/plotly/figure_factory/figure_factory/_violin.py b/plotly/figure_factory/figure_factory/_violin.py new file mode 100644 index 00000000000..384f3497a15 --- /dev/null +++ b/plotly/figure_factory/figure_factory/_violin.py @@ -0,0 +1,644 @@ +from __future__ import absolute_import + +from numbers import Number + +from plotly import exceptions, optional_imports +from plotly.figure_factory import utils +from plotly.graph_objs import graph_objs +from plotly.tools import make_subplots + +pd = optional_imports.get_module('pandas') +np = optional_imports.get_module('numpy') +scipy_stats = optional_imports.get_module('scipy.stats') + + +def calc_stats(data): + """ + Calculate statistics for use in violin plot. + """ + x = np.asarray(data, np.float) + vals_min = np.min(x) + vals_max = np.max(x) + q2 = np.percentile(x, 50, interpolation='linear') + q1 = np.percentile(x, 25, interpolation='lower') + q3 = np.percentile(x, 75, interpolation='higher') + iqr = q3 - q1 + whisker_dist = 1.5 * iqr + + # in order to prevent drawing whiskers outside the interval + # of data one defines the whisker positions as: + d1 = np.min(x[x >= (q1 - whisker_dist)]) + d2 = np.max(x[x <= (q3 + whisker_dist)]) + return { + 'min': vals_min, + 'max': vals_max, + 'q1': q1, + 'q2': q2, + 'q3': q3, + 'd1': d1, + 'd2': d2 + } + + +def make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'): + """ + Produces a sideways probability distribution fig violin plot. + """ + text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) + + ', ' + '{:0.2f}'.format(y[i]) + ')' + for i in range(len(x))] + + return graph_objs.Scatter( + x=x, + y=y, + mode='lines', + name='', + text=text, + fill='tonextx', + fillcolor=fillcolor, + line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape='spline'), + hoverinfo='text', + opacity=0.5 + ) + + +def make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'): + """ + Returns a rugplot fig for a violin plot. + """ + return graph_objs.Scatter( + y=vals, + x=[-pdf_max-distance]*len(vals), + marker=graph_objs.scatter.Marker( + color=color, + symbol='line-ew-open' + ), + mode='markers', + name='', + showlegend=False, + hoverinfo='y' + ) + + +def make_non_outlier_interval(d1, d2): + """ + Returns the scatterplot fig of most of a violin plot. + """ + return graph_objs.Scatter( + x=[0, 0], + y=[d1, d2], + name='', + mode='lines', + line=graph_objs.scatter.Line(width=1.5, + color='rgb(0,0,0)') + ) + + +def make_quartiles(q1, q3): + """ + Makes the upper and lower quartiles for a violin plot. + """ + return graph_objs.Scatter( + x=[0, 0], + y=[q1, q3], + text=['lower-quartile: ' + '{:0.2f}'.format(q1), + 'upper-quartile: ' + '{:0.2f}'.format(q3)], + mode='lines', + line=graph_objs.scatter.Line( + width=4, + color='rgb(0,0,0)' + ), + hoverinfo='text' + ) + + +def make_median(q2): + """ + Formats the 'median' hovertext for a violin plot. + """ + return graph_objs.Scatter( + x=[0], + y=[q2], + text=['median: ' + '{:0.2f}'.format(q2)], + mode='markers', + marker=dict(symbol='square', + color='rgb(255,255,255)'), + hoverinfo='text' + ) + + +def make_XAxis(xaxis_title, xaxis_range): + """ + Makes the x-axis for a violin plot. + """ + xaxis = graph_objs.layout.XAxis(title=xaxis_title, + range=xaxis_range, + showgrid=False, + zeroline=False, + showline=False, + mirror=False, + ticks='', + showticklabels=False) + return xaxis + + +def make_YAxis(yaxis_title): + """ + Makes the y-axis for a violin plot. + """ + yaxis = graph_objs.layout.YAxis(title=yaxis_title, + showticklabels=True, + autorange=True, + ticklen=4, + showline=True, + zeroline=False, + showgrid=False, + mirror=False) + return yaxis + + +def violinplot(vals, fillcolor='#1f77b4', rugplot=True): + """ + Refer to FigureFactory.create_violin() for docstring. + """ + vals = np.asarray(vals, np.float) + # summary statistics + vals_min = calc_stats(vals)['min'] + vals_max = calc_stats(vals)['max'] + q1 = calc_stats(vals)['q1'] + q2 = calc_stats(vals)['q2'] + q3 = calc_stats(vals)['q3'] + d1 = calc_stats(vals)['d1'] + d2 = calc_stats(vals)['d2'] + + # kernel density estimation of pdf + pdf = scipy_stats.gaussian_kde(vals) + # grid over the data interval + xx = np.linspace(vals_min, vals_max, 100) + # evaluate the pdf at the grid xx + yy = pdf(xx) + max_pdf = np.max(yy) + # distance from the violin plot to rugplot + distance = (2.0 * max_pdf)/10 if rugplot else 0 + # range for x values in the plot + plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1] + plot_data = [make_half_violin(-yy, xx, fillcolor=fillcolor), + make_half_violin(yy, xx, fillcolor=fillcolor), + make_non_outlier_interval(d1, d2), + make_quartiles(q1, q3), + make_median(q2)] + if rugplot: + plot_data.append(make_violin_rugplot(vals, max_pdf, distance=distance, + color=fillcolor)) + return plot_data, plot_xrange + + +def violin_no_colorscale(data, data_header, group_header, colors, + use_colorscale, group_stats, rugplot, sort, + height, width, title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot without colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + if sort: + group_name.sort() + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + color_index = 0 + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + if color_index >= len(colors): + color_index = 0 + plot_data, plot_xrange = violinplot(vals, + fillcolor=colors[color_index], + rugplot=rugplot) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + color_index += 1 + + # add violin plot labels + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def violin_colorscale(data, data_header, group_header, colors, use_colorscale, + group_stats, rugplot, sort, height, width, + title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot with colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + if sort: + group_name.sort() + + # make sure all group names are keys in group_stats + for group in group_name: + if group not in group_stats: + raise exceptions.PlotlyError("All values/groups in the index " + "column must be represented " + "as a key in group_stats.") + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + + # prepare low and high color for colorscale + lowcolor = utils.color_parser(colors[0], utils.unlabel_rgb) + highcolor = utils.color_parser(colors[1], utils.unlabel_rgb) + + # find min and max values in group_stats + group_stats_values = [] + for key in group_stats: + group_stats_values.append(group_stats[key]) + + max_value = max(group_stats_values) + min_value = min(group_stats_values) + + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + + # find intermediate color from colorscale + intermed = (group_stats[gr] - min_value) / (max_value - min_value) + intermed_color = utils.find_intermediate_color( + lowcolor, highcolor, intermed + ) + + plot_data, plot_xrange = violinplot( + vals, + fillcolor='rgb{}'.format(intermed_color), + rugplot=rugplot + ) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + # add colorbar to plot + trace_dummy = graph_objs.Scatter( + x=[0], + y=[0], + mode='markers', + marker=dict( + size=2, + cmin=min_value, + cmax=max_value, + colorscale=[[0, colors[0]], + [1, colors[1]]], + showscale=True), + showlegend=False, + ) + fig.append_trace(trace_dummy, 1, L) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def violin_dict(data, data_header, group_header, colors, use_colorscale, + group_stats, rugplot, sort, height, width, title): + """ + Refer to FigureFactory.create_violin() for docstring. + + Returns fig for violin plot without colorscale. + + """ + + # collect all group names + group_name = [] + for name in data[group_header]: + if name not in group_name: + group_name.append(name) + + if sort: + group_name.sort() + + # check if all group names appear in colors dict + for group in group_name: + if group not in colors: + raise exceptions.PlotlyError("If colors is a dictionary, all " + "the group names must appear as " + "keys in colors.") + + gb = data.groupby([group_header]) + L = len(group_name) + + fig = make_subplots(rows=1, cols=L, + shared_yaxes=True, + horizontal_spacing=0.025, + print_grid=False) + + for k, gr in enumerate(group_name): + vals = np.asarray(gb.get_group(gr)[data_header], np.float) + plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr], + rugplot=rugplot) + layout = graph_objs.Layout() + + for item in plot_data: + fig.append_trace(item, 1, k + 1) + + # add violin plot labels + fig['layout'].update( + {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)} + ) + + # set the sharey axis style + fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')}) + fig['layout'].update( + title=title, + showlegend=False, + hovermode='closest', + autosize=False, + height=height, + width=width + ) + + return fig + + +def create_violin(data, data_header=None, group_header=None, colors=None, + use_colorscale=False, group_stats=None, rugplot=True, + sort=False, height=450, width=600, + title='Violin and Rug Plot'): + """ + Returns figure for a violin plot + + :param (list|array) data: accepts either a list of numerical values, + a list of dictionaries all with identical keys and at least one + column of numeric values, or a pandas dataframe with at least one + column of numbers. + :param (str) data_header: the header of the data column to be used + from an inputted pandas dataframe. Not applicable if 'data' is + a list of numeric values. + :param (str) group_header: applicable if grouping data by a variable. + 'group_header' must be set to the name of the grouping variable. + :param (str|tuple|list|dict) colors: either a plotly scale name, + an rgb or hex color, a color tuple, a list of colors or a + dictionary. An rgb color is of the form 'rgb(x, y, z)' where + x, y and z belong to the interval [0, 255] and a color tuple is a + tuple of the form (a, b, c) where a, b and c belong to [0, 1]. + If colors is a list, it must contain valid color types as its + members. + :param (bool) use_colorscale: only applicable if grouping by another + variable. Will implement a colorscale based on the first 2 colors + of param colors. This means colors must be a list with at least 2 + colors in it (Plotly colorscales are accepted since they map to a + list of two rgb colors). Default = False + :param (dict) group_stats: a dictioanry where each key is a unique + value from the group_header column in data. Each value must be a + number and will be used to color the violin plots if a colorscale + is being used. + :param (bool) rugplot: determines if a rugplot is draw on violin plot. + Default = True + :param (bool) sort: determines if violins are sorted + alphabetically (True) or by input order (False). Default = False + :param (float) height: the height of the violin plot. + :param (float) width: the width of the violin plot. + :param (str) title: the title of the violin plot. + + Example 1: Single Violin Plot + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + from scipy import stats + + # create list of random values + data_list = np.random.randn(100) + data_list.tolist() + + # create violin fig + fig = create_violin(data_list, colors='#604d9e') + + # plot + py.iplot(fig, filename='Violin Plot') + ``` + + Example 2: Multiple Violin Plots with Qualitative Coloring + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + import pandas as pd + from scipy import stats + + # create dataframe + np.random.seed(619517) + Nr=250 + y = np.random.randn(Nr) + gr = np.random.choice(list("ABCDE"), Nr) + norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] + + for i, letter in enumerate("ABCDE"): + y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] + df = pd.DataFrame(dict(Score=y, Group=gr)) + + # create violin fig + fig = create_violin(df, data_header='Score', group_header='Group', + sort=True, height=600, width=1000) + + # plot + py.iplot(fig, filename='Violin Plot with Coloring') + ``` + + Example 3: Violin Plots with Colorscale + ``` + import plotly.plotly as py + from plotly.figure_factory import create_violin + from plotly.graph_objs import graph_objs + + import numpy as np + import pandas as pd + from scipy import stats + + # create dataframe + np.random.seed(619517) + Nr=250 + y = np.random.randn(Nr) + gr = np.random.choice(list("ABCDE"), Nr) + norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)] + + for i, letter in enumerate("ABCDE"): + y[gr == letter] *=norm_params[i][1]+ norm_params[i][0] + df = pd.DataFrame(dict(Score=y, Group=gr)) + + # define header params + data_header = 'Score' + group_header = 'Group' + + # make groupby object with pandas + group_stats = {} + groupby_data = df.groupby([group_header]) + + for group in "ABCDE": + data_from_group = groupby_data.get_group(group)[data_header] + # take a stat of the grouped data + stat = np.median(data_from_group) + # add to dictionary + group_stats[group] = stat + + # create violin fig + fig = create_violin(df, data_header='Score', group_header='Group', + height=600, width=1000, use_colorscale=True, + group_stats=group_stats) + + # plot + py.iplot(fig, filename='Violin Plot with Colorscale') + ``` + """ + + # Validate colors + if isinstance(colors, dict): + valid_colors = utils.validate_colors_dict(colors, 'rgb') + else: + valid_colors = utils.validate_colors(colors, 'rgb') + + # validate data and choose plot type + if group_header is None: + if isinstance(data, list): + if len(data) <= 0: + raise exceptions.PlotlyError("If data is a list, it must be " + "nonempty and contain either " + "numbers or dictionaries.") + + if not all(isinstance(element, Number) for element in data): + raise exceptions.PlotlyError("If data is a list, it must " + "contain only numbers.") + + if pd and isinstance(data, pd.core.frame.DataFrame): + if data_header is None: + raise exceptions.PlotlyError("data_header must be the " + "column name with the " + "desired numeric data for " + "the violin plot.") + + data = data[data_header].values.tolist() + + # call the plotting functions + plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0], + rugplot=rugplot) + + layout = graph_objs.Layout( + title=title, + autosize=False, + font=graph_objs.layout.Font(size=11), + height=height, + showlegend=False, + width=width, + xaxis=make_XAxis('', plot_xrange), + yaxis=make_YAxis(''), + hovermode='closest' + ) + layout['yaxis'].update(dict(showline=False, + showticklabels=False, + ticks='')) + + fig = graph_objs.Figure(data=plot_data, + layout=layout) + + return fig + + else: + if not isinstance(data, pd.core.frame.DataFrame): + raise exceptions.PlotlyError("Error. You must use a pandas " + "DataFrame if you are using a " + "group header.") + + if data_header is None: + raise exceptions.PlotlyError("data_header must be the column " + "name with the desired numeric " + "data for the violin plot.") + + if use_colorscale is False: + if isinstance(valid_colors, dict): + # validate colors dict choice below + fig = violin_dict( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, rugplot, sort, + height, width, title + ) + return fig + else: + fig = violin_no_colorscale( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, rugplot, sort, + height, width, title + ) + return fig + else: + if isinstance(valid_colors, dict): + raise exceptions.PlotlyError("The colors param cannot be " + "a dictionary if you are " + "using a colorscale.") + + if len(valid_colors) < 2: + raise exceptions.PlotlyError("colors must be a list with " + "at least 2 colors. A " + "Plotly scale is allowed.") + + if not isinstance(group_stats, dict): + raise exceptions.PlotlyError("Your group_stats param " + "must be a dictionary.") + + fig = violin_colorscale( + data, data_header, group_header, valid_colors, + use_colorscale, group_stats, rugplot, sort, height, + width, title + ) + return fig diff --git a/plotly/figure_factory/utils.py b/plotly/figure_factory/utils.py index ee19e17cef9..ea27270c46b 100644 --- a/plotly/figure_factory/utils.py +++ b/plotly/figure_factory/utils.py @@ -2,6 +2,8 @@ import collections import decimal +import six +from numbers import Number from plotly import exceptions @@ -12,26 +14,136 @@ 'rgb(227, 119, 194)', 'rgb(127, 127, 127)', 'rgb(188, 189, 34)', 'rgb(23, 190, 207)'] -# TODO: make PLOTLY_SCALES below like version in plotly.colors -# requires rewritting scatterplot_matrix code PLOTLY_SCALES = { - 'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], - 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], - 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], - 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], - 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], - 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], - 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], - 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], - 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], - 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], - 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], - 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], - 'Viridis': ['#440154', '#fde725'] + 'Greys': [ + [0, 'rgb(0,0,0)'], [1, 'rgb(255,255,255)'] + ], + + 'YlGnBu': [ + [0, 'rgb(8,29,88)'], [0.125, 'rgb(37,52,148)'], + [0.25, 'rgb(34,94,168)'], [0.375, 'rgb(29,145,192)'], + [0.5, 'rgb(65,182,196)'], [0.625, 'rgb(127,205,187)'], + [0.75, 'rgb(199,233,180)'], [0.875, 'rgb(237,248,217)'], + [1, 'rgb(255,255,217)'] + ], + + 'Greens': [ + [0, 'rgb(0,68,27)'], [0.125, 'rgb(0,109,44)'], + [0.25, 'rgb(35,139,69)'], [0.375, 'rgb(65,171,93)'], + [0.5, 'rgb(116,196,118)'], [0.625, 'rgb(161,217,155)'], + [0.75, 'rgb(199,233,192)'], [0.875, 'rgb(229,245,224)'], + [1, 'rgb(247,252,245)'] + ], + + 'YlOrRd': [ + [0, 'rgb(128,0,38)'], [0.125, 'rgb(189,0,38)'], + [0.25, 'rgb(227,26,28)'], [0.375, 'rgb(252,78,42)'], + [0.5, 'rgb(253,141,60)'], [0.625, 'rgb(254,178,76)'], + [0.75, 'rgb(254,217,118)'], [0.875, 'rgb(255,237,160)'], + [1, 'rgb(255,255,204)'] + ], + + 'Bluered': [ + [0, 'rgb(0,0,255)'], [1, 'rgb(255,0,0)'] + ], + + # modified RdBu based on + # www.sandia.gov/~kmorel/documents/ColorMaps/ColorMapsExpanded.pdf + 'RdBu': [ + [0, 'rgb(5,10,172)'], [0.35, 'rgb(106,137,247)'], + [0.5, 'rgb(190,190,190)'], [0.6, 'rgb(220,170,132)'], + [0.7, 'rgb(230,145,90)'], [1, 'rgb(178,10,28)'] + ], + + # Scale for non-negative numeric values + 'Reds': [ + [0, 'rgb(220,220,220)'], [0.2, 'rgb(245,195,157)'], + [0.4, 'rgb(245,160,105)'], [1, 'rgb(178,10,28)'] + ], + + # Scale for non-positive numeric values + 'Blues': [ + [0, 'rgb(5,10,172)'], [0.35, 'rgb(40,60,190)'], + [0.5, 'rgb(70,100,245)'], [0.6, 'rgb(90,120,245)'], + [0.7, 'rgb(106,137,247)'], [1, 'rgb(220,220,220)'] + ], + + 'Picnic': [ + [0, 'rgb(0,0,255)'], [0.1, 'rgb(51,153,255)'], + [0.2, 'rgb(102,204,255)'], [0.3, 'rgb(153,204,255)'], + [0.4, 'rgb(204,204,255)'], [0.5, 'rgb(255,255,255)'], + [0.6, 'rgb(255,204,255)'], [0.7, 'rgb(255,153,255)'], + [0.8, 'rgb(255,102,204)'], [0.9, 'rgb(255,102,102)'], + [1, 'rgb(255,0,0)'] + ], + + 'Rainbow': [ + [0, 'rgb(150,0,90)'], [0.125, 'rgb(0,0,200)'], + [0.25, 'rgb(0,25,255)'], [0.375, 'rgb(0,152,255)'], + [0.5, 'rgb(44,255,150)'], [0.625, 'rgb(151,255,0)'], + [0.75, 'rgb(255,234,0)'], [0.875, 'rgb(255,111,0)'], + [1, 'rgb(255,0,0)'] + ], + + 'Portland': [ + [0, 'rgb(12,51,131)'], [0.25, 'rgb(10,136,186)'], + [0.5, 'rgb(242,211,56)'], [0.75, 'rgb(242,143,56)'], + [1, 'rgb(217,30,30)'] + ], + + 'Jet': [ + [0, 'rgb(0,0,131)'], [0.125, 'rgb(0,60,170)'], + [0.375, 'rgb(5,255,255)'], [0.625, 'rgb(255,255,0)'], + [0.875, 'rgb(250,0,0)'], [1, 'rgb(128,0,0)'] + ], + + 'Hot': [ + [0, 'rgb(0,0,0)'], [0.3, 'rgb(230,0,0)'], + [0.6, 'rgb(255,210,0)'], [1, 'rgb(255,255,255)'] + ], + + 'Blackbody': [ + [0, 'rgb(0,0,0)'], [0.2, 'rgb(230,0,0)'], + [0.4, 'rgb(230,210,0)'], [0.7, 'rgb(255,255,255)'], + [1, 'rgb(160,200,255)'] + ], + + 'Earth': [ + [0, 'rgb(0,0,130)'], [0.1, 'rgb(0,180,180)'], + [0.2, 'rgb(40,210,40)'], [0.4, 'rgb(230,230,50)'], + [0.6, 'rgb(120,70,20)'], [1, 'rgb(255,255,255)'] + ], + + 'Electric': [ + [0, 'rgb(0,0,0)'], [0.15, 'rgb(30,0,100)'], + [0.4, 'rgb(120,0,100)'], [0.6, 'rgb(160,90,0)'], + [0.8, 'rgb(230,200,0)'], [1, 'rgb(255,250,220)'] + ], + + 'Viridis': [ + [0, '#440154'], [0.06274509803921569, '#48186a'], + [0.12549019607843137, '#472d7b'], [0.18823529411764706, '#424086'], + [0.25098039215686274, '#3b528b'], [0.3137254901960784, '#33638d'], + [0.3764705882352941, '#2c728e'], [0.4392156862745098, '#26828e'], + [0.5019607843137255, '#21918c'], [0.5647058823529412, '#1fa088'], + [0.6274509803921569, '#28ae80'], [0.6901960784313725, '#3fbc73'], + [0.7529411764705882, '#5ec962'], [0.8156862745098039, '#84d44b'], + [0.8784313725490196, '#addc30'], [0.9411764705882353, '#d8e219'], + [1, '#fde725'] + ], + + 'Cividis': [ + [0.000000, 'rgb(0,32,76)'], [0.058824, 'rgb(0,42,102)'], + [0.117647, 'rgb(0,52,110)'], [0.176471, 'rgb(39,63,108)'], + [0.235294, 'rgb(60,74,107)'], [0.294118, 'rgb(76,85,107)'], + [0.352941, 'rgb(91,95,109)'], [0.411765, 'rgb(104,106,112)'], + [0.470588, 'rgb(117,117,117)'], [0.529412, 'rgb(131,129,120)'], + [0.588235, 'rgb(146,140,120)'], [0.647059, 'rgb(161,152,118)'], + [0.705882, 'rgb(176,165,114)'], [0.764706, 'rgb(192,177,109)'], + [0.823529, 'rgb(209,191,102)'], [0.882353, 'rgb(225,204,92)'], + [0.941176, 'rgb(243,219,79)'], [1.000000, 'rgb(255,233,69)'] + ] + } @@ -47,7 +159,6 @@ def validate_index(index_vals): :raises: (PlotlyError) If there are any two items in the list whose types differ """ - from numbers import Number if isinstance(index_vals[0], Number): if not all(isinstance(item, Number) for item in index_vals): raise exceptions.PlotlyError("Error in indexing column. " @@ -70,7 +181,6 @@ def validate_dataframe(array): :raises: (PlotlyError) If there are any two items in any list whose types differ """ - from numbers import Number for vector in array: if isinstance(vector[0], Number): if not all(isinstance(item, Number) for item in vector): @@ -149,30 +259,39 @@ def find_intermediate_color(lowcolor, highcolor, intermed): lowcolor[2] + intermed * diff_2) -def n_colors(lowcolor, highcolor, n_colors): +def n_colors(lowcolor, highcolor, n_colors, colortype='tuple'): """ Splits a low and high color into a list of n_colors colors in it Accepts two color tuples and returns a list of n_colors colors which form the intermediate colors between lowcolor and highcolor - from linearly interpolating through RGB space - + from linearly interpolating through RGB space. If colortype is 'rgb' + the function will return a list of colors in the same form. """ + if colortype == 'rgb': + # convert to tuple + lowcolor = unlabel_rgb(lowcolor) + highcolor = unlabel_rgb(highcolor) + diff_0 = float(highcolor[0] - lowcolor[0]) incr_0 = diff_0/(n_colors - 1) diff_1 = float(highcolor[1] - lowcolor[1]) incr_1 = diff_1/(n_colors - 1) diff_2 = float(highcolor[2] - lowcolor[2]) incr_2 = diff_2/(n_colors - 1) - color_tuples = [] + list_of_colors = [] for index in range(n_colors): new_tuple = (lowcolor[0] + (index * incr_0), lowcolor[1] + (index * incr_1), lowcolor[2] + (index * incr_2)) - color_tuples.append(new_tuple) + list_of_colors.append(new_tuple) + + if colortype == 'rgb': + # back to an rgb string + list_of_colors = color_parser(list_of_colors, label_rgb) - return color_tuples + return list_of_colors def label_rgb(colors): @@ -274,7 +393,6 @@ def color_parser(colors, function): - rgb string, hex string or tuple """ - from numbers import Number if isinstance(colors, str): return function(colors) @@ -295,13 +413,13 @@ def validate_colors(colors, colortype='tuple'): """ Validates color(s) and returns a list of color(s) of a specified type """ - from numbers import Number if colors is None: colors = DEFAULT_PLOTLY_COLORS if isinstance(colors, str): if colors in PLOTLY_SCALES: - colors = PLOTLY_SCALES[colors] + colors_list = colorscale_to_colors(PLOTLY_SCALES[colors]) + colors = [colors_list[0]] + [colors_list[-1]] elif 'rgb' in colors or '#' in colors: colors = [colors] else: @@ -343,7 +461,7 @@ def validate_colors(colors, colortype='tuple'): ) colors[j] = each_color - if colortype == 'rgb': + if colortype == 'rgb' and not isinstance(colors, six.string_types): for j, each_color in enumerate(colors): rgb_color = color_parser(each_color, convert_to_RGB_255) colors[j] = color_parser(rgb_color, label_rgb) @@ -386,6 +504,173 @@ def validate_colors_dict(colors, colortype='tuple'): return colors + +def convert_colors_to_same_type(colors, colortype='rgb', scale=None, + return_default_colors=False, + num_of_defualt_colors=2): + """ + Converts color(s) to the specified color type + + Takes a single color or an iterable of colors, as well as a list of scale + values, and outputs a 2-pair of the list of color(s) converted all to an + rgb or tuple color type, aswell as the scale as the second element. If + colors is a Plotly Scale name, then 'scale' will be forced to the scale + from the respective colorscale and the colors in that colorscale will also + be coverted to the selected colortype. If colors is None, then there is an + option to return portion of the DEFAULT_PLOTLY_COLORS + + :param (str|tuple|list) colors: either a plotly scale name, an rgb or hex + color, a color tuple or a list/tuple of colors + :param (list) scale: see docs for validate_scale_values() + + :rtype (tuple) (colors_list, scale) if scale is None in the function call, + then scale will remain None in the returned tuple + """ + #if colors_list is None: + colors_list = [] + + if colors is None and return_default_colors is True: + colors_list = DEFAULT_PLOTLY_COLORS[0:num_of_defualt_colors] + + if isinstance(colors, str): + if colors in PLOTLY_SCALES: + colors_list = colorscale_to_colors(PLOTLY_SCALES[colors]) + if scale is None: + scale = colorscale_to_scale(PLOTLY_SCALES[colors]) + + elif 'rgb' in colors or '#' in colors: + colors_list = [colors] + + elif isinstance(colors, tuple): + if isinstance(colors[0], Number): + colors_list = [colors] + else: + colors_list = list(colors) + + elif isinstance(colors, list): + colors_list = colors + + # validate scale + if scale is not None: + validate_scale_values(scale) + + if len(colors_list) != len(scale): + raise exceptions.PlotlyError( + 'Make sure that the length of your scale matches the length ' + 'of your list of colors which is {}.'.format(len(colors_list)) + ) + + # convert all colors to rgb + for j, each_color in enumerate(colors_list): + if '#' in each_color: + each_color = color_parser( + each_color, hex_to_rgb + ) + each_color = color_parser( + each_color, label_rgb + ) + colors_list[j] = each_color + + elif isinstance(each_color, tuple): + each_color = color_parser( + each_color, convert_to_RGB_255 + ) + each_color = color_parser( + each_color, label_rgb + ) + colors_list[j] = each_color + + if colortype == 'rgb': + return (colors_list, scale) + elif colortype == 'tuple': + for j, each_color in enumerate(colors_list): + each_color = color_parser( + each_color, unlabel_rgb + ) + each_color = color_parser( + each_color, unconvert_from_RGB_255 + ) + colors_list[j] = each_color + return (colors_list, scale) + else: + raise exceptions.PlotlyError('You must select either rgb or tuple ' + 'for your colortype variable.') + + +def convert_dict_colors_to_same_type(colors_dict, colortype='rgb'): + """ + Converts a colors in a dictioanry of colors to the specified color type + + :param (dict) colors_dict: a dictioanry whose values are single colors + """ + for key in colors_dict: + if '#' in colors_dict[key]: + colors_dict[key] = color_parser( + colors_dict[key], hex_to_rgb + ) + colors_dict[key] = color_parser( + colors_dict[key], label_rgb + ) + + elif isinstance(colors_dict[key], tuple): + colors_dict[key] = color_parser( + colors_dict[key], convert_to_RGB_255 + ) + colors_dict[key] = color_parser( + colors_dict[key], label_rgb + ) + + if colortype == 'rgb': + return colors_dict + elif colortype == 'tuple': + for key in colors_dict: + colors_dict[key] = color_parser( + colors_dict[key], unlabel_rgb + ) + colors_dict[key] = color_parser( + colors_dict[key], unconvert_from_RGB_255 + ) + return colors_dict + else: + raise exceptions.PlotlyError('You must select either rgb or tuple ' + 'for your colortype variable.') + + +def make_colorscale(colors, scale=None): + """ + Makes a colorscale from a list of colors and a scale + + Takes a list of colors and scales and constructs a colorscale based + on the colors in sequential order. If 'scale' is left empty, a linear- + interpolated colorscale will be generated. If 'scale' is a specificed + list, it must be the same legnth as colors and must contain all floats + For documentation regarding to the form of the output, see + https://plot.ly/python/reference/#mesh3d-colorscale + + :param (list) colors: a list of single colors + """ + colorscale = [] + + # validate minimum colors length of 2 + if len(colors) < 2: + raise exceptions.PlotlyError('You must input a list of colors that ' + 'has at least two colors.') + + if scale is None: + scale_incr = 1./(len(colors) - 1) + return [[i * scale_incr, color] for i, color in enumerate(colors)] + + else: + if len(colors) != len(scale): + raise exceptions.PlotlyError('The length of colors and scale ' + 'must be the same.') + + validate_scale_values(scale) + + colorscale = [list(tup) for tup in zip(scale, colors)] + return colorscale + + def colorscale_to_colors(colorscale): """ Extracts the colors from colorscale as a list @@ -406,6 +691,23 @@ def colorscale_to_scale(colorscale): return scale_list +def convert_colorscale_to_rgb(colorscale): + """ + Converts the colors in a colorscale to rgb colors + + A colorscale is an array of arrays, each with a numeric value as the + first item and a color as the second. This function specifically is + converting a colorscale with tuple colors (each coordinate between 0 + and 1) into a colorscale with the colors transformed into rgb colors + """ + for color in colorscale: + color[1] = convert_to_RGB_255(color[1]) + + for color in colorscale: + color[1] = label_rgb(color[1]) + return colorscale + + def validate_scale_values(scale): """ Validates scale values from a colorscale diff --git a/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py b/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py new file mode 100644 index 00000000000..7e3399e96e0 --- /dev/null +++ b/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py @@ -0,0 +1,244 @@ +from unittest import TestCase + +from nose.tools import raises +import plotly.figure_factory.utils as utils +import plotly.tools as tls +from plotly.exceptions import PlotlyError + + +class TestFigureFactoryUtils(TestCase): + + def test_validate_index(self): + pattern = ( + "Error in indexing column. " + "Make sure all entries of each " + "column are all numbers or " + "all strings." + ) + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, + [13, 'foo']) + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, + ['number', 42]) + + def validate_dataframe(self): + pattern = ( + "Error in dataframe. " + "Make sure all entries of " + "each column are either " + "numbers or strings." + ) + + df = [ + [8, 'foo'], + ['amazing', 64] + ] + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, df) + + df = [ + ['amazing', 64], + [8, 'foo'] + ] + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, df) + + + def validate_equal_length(self): + pattern = ( + "Oops! Your data lists or ndarrays should be the same length." + ) + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, + (0,0,0), (0,0,0,0,0)) + + + def test_validate_positive_scalars(self): + # only one negative number + self.assertRaises(ValueError, utils.validate_positive_scalars, + number0=1, number1=0.001, number2=-2) + + def test_flatten(self): + + self.assertRaises( + PlotlyError, utils.flatten, + array=0 + ) + + def test_validate_colors(self): + + # test string input + color_string = 'foo' + + pattern = ("If your colors variable is a string, it must be a " + "Plotly scale, an rgb color or a hex color.") + + self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_colors, + color_string) + + # test rgb color + color_string2 = 'rgb(265, 0, 0)' + + pattern2 = ("Whoops! The elements in your rgb colors tuples cannot " + "exceed 255.0.") + + self.assertRaisesRegexp(PlotlyError, pattern2, utils.validate_colors, + color_string2) + + # test dictionary + colors_dict = { + 'apple': 'rgb(300, 0, 0)', + 'pear': (0, 0.5, 1) + } + + self.assertRaisesRegexp(PlotlyError, pattern2, utils.validate_colors_dict, + colors_dict) + + # test tuple color + color_tuple = (1, 1, 2) + + pattern3 = ("Whoops! The elements in your colors tuples cannot " + "exceed 1.0.") + + self.assertRaisesRegexp(PlotlyError, pattern3, utils.validate_colors, + color_tuple) + + colors_dict2 = { + 'apple': 'rgb(255, 100, 50)', + 'pear': (0, 0.5, 2) + } + + self.assertRaisesRegexp(PlotlyError, pattern3, utils.validate_colors_dict, + colors_dict2) + + def test_convert_colors_to_same_type(self): + + # test colortype + color_tuple = ['#aaaaaa', '#bbbbbb', '#cccccc'] + scale = [0, 1] + + self.assertRaises(PlotlyError, utils.convert_colors_to_same_type, + color_tuple, scale=scale) + + # test colortype + color_tuple = (1, 1, 1) + colortype = 2 + + pattern2 = ("You must select either rgb or tuple for your colortype " + "variable.") + + self.assertRaisesRegexp(PlotlyError, pattern2, + utils.convert_colors_to_same_type, + color_tuple, colortype) + + def test_convert_dict_colors_to_same_type(self): + + # test colortype + color_dict = dict(apple='rgb(1, 1, 1)') + colortype = 2 + + pattern = ("You must select either rgb or tuple for your colortype " + "variable.") + + self.assertRaisesRegexp(PlotlyError, pattern, + utils.convert_dict_colors_to_same_type, + color_dict, colortype) + + def test_validate_scale_values(self): + + # test that scale length is at least 2 + scale = [0] + + pattern = ("You must input a list of scale values that has at least " + "two values.") + + self.assertRaisesRegexp(PlotlyError, pattern, + utils.validate_scale_values, + scale) + + # test if first and last number is 0 and 1 respectively + scale = [0, 1.1] + + pattern = ("The first and last number in your scale must be 0.0 and " + "1.0 respectively.") + + self.assertRaisesRegexp(PlotlyError, pattern, + utils.validate_scale_values, + scale) + + # test numbers increase + scale = [0, 2, 1] + + pattern = ("'scale' must be a list that contains a strictly " + "increasing sequence of numbers.") + + self.assertRaisesRegexp(PlotlyError, pattern, + utils.validate_scale_values, + scale) + + def test_make_colorscale(self): + + # test minimum colors length + color_list = [(0, 0, 0)] + + pattern = ( + "You must input a list of colors that has at least two colors." + ) + + self.assertRaisesRegexp(PlotlyError, pattern, utils.make_colorscale, + color_list) + + # test length of colors and scale + color_list2 = [(0, 0, 0), (1, 1, 1)] + scale = [0] + + pattern2 = ("The length of colors and scale must be the same.") + + self.assertRaisesRegexp(PlotlyError, pattern2, utils.make_colorscale, + color_list2, scale) + + + def test_endpts_to_intervals(self): + + pattern = ("The intervals_endpts argument must " + "be a list or tuple of a sequence " + "of increasing numbers.") + + endpts = 'foo' + self.assertRaisesRegexp(PlotlyError, pattern, + utils.endpts_to_intervals, endpts) + + endpts = ['foo'] + self.assertRaisesRegexp(PlotlyError, pattern, + utils.endpts_to_intervals, endpts) + + endpts = [1, 0] + self.assertRaisesRegexp(PlotlyError, pattern, + utils.endpts_to_intervals, endpts) + + def test_validate_colorscale(self): + + pattern = "A valid colorscale must be a list." + + colorscale = 55 + self.assertRaisesRegexp(PlotlyError, pattern, + utils.validate_colorscale, colorscale) + + pattern2 = "A valid colorscale must be a list of lists." + + colorscale = [[], [], 'foo', []] + self.assertRaisesRegexp(PlotlyError, pattern, + utils.validate_colorscale, colorscale) + + def test_list_of_options(self): + + pattern = 'Your list or tuple must contain at least 2 items.' + + self.assertRaisesRegexp(PlotlyError, pattern, + utils.list_of_options, ['item #1']) + + + + + diff --git a/plotly/tools.py b/plotly/tools.py index 3f6fd488069..95e87efe4e8 100644 --- a/plotly/tools.py +++ b/plotly/tools.py @@ -27,23 +27,6 @@ REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish'] -PLOTLY_SCALES = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'], - 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'], - 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'], - 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'], - 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'], - 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'], - 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'], - 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'], - 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'], - 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'], - 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'], - 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'], - 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'], - 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'], - 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']} # color constants for violin plot DEFAULT_FILLCOLOR = '#1f77b4'