diff --git a/plotly/figure_factory/_bullet.py b/plotly/figure_factory/_bullet.py
index bcb5b2541d1..e2c42a7855a 100644
--- a/plotly/figure_factory/_bullet.py
+++ b/plotly/figure_factory/_bullet.py
@@ -78,7 +78,7 @@ def _bullet(df, markers, measures, ranges, subtitles, titles, orientation,
for row in range(num_of_lanes):
# ranges bars
for idx in range(len(df.iloc[row]['ranges'])):
- inter_colors = colors.n_colors(
+ inter_colors = utils.n_colors(
range_colors[0], range_colors[1],
len(df.iloc[row]['ranges']), 'rgb'
)
@@ -104,7 +104,7 @@ def _bullet(df, markers, measures, ranges, subtitles, titles, orientation,
# measures bars
for idx in range(len(df.iloc[row]['measures'])):
- inter_colors = colors.n_colors(
+ inter_colors = utils.n_colors(
measure_colors[0], measure_colors[1],
len(df.iloc[row]['measures']), 'rgb'
)
@@ -318,7 +318,7 @@ def create_bullet(data, markers=None, measures=None, ranges=None,
"of two valid colors."
)
colors.validate_colors(colors_list)
- colors_list = colors.convert_colors_to_same_type(colors_list,
+ colors_list = utils.convert_colors_to_same_type(colors_list,
'rgb')[0]
# default scatter options
diff --git a/plotly/figure_factory/_county_choropleth.py b/plotly/figure_factory/_county_choropleth.py
index 8ee9fef7c2c..5c8674ce07a 100644
--- a/plotly/figure_factory/_county_choropleth.py
+++ b/plotly/figure_factory/_county_choropleth.py
@@ -642,7 +642,7 @@ def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None,
if not colorscale:
colorscale = []
viridis_colors = colors.colorscale_to_colors(
- colors.PLOTLY_SCALES['Viridis']
+ utils.PLOTLY_SCALES['Viridis']
)
viridis_colors = colors.color_parser(
viridis_colors, colors.hex_to_rgb
@@ -674,7 +674,7 @@ def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None,
# make R,G,B into int values
float_color = colors.unlabel_rgb(float_color)
- float_color = colors.unconvert_from_RGB_255(float_color)
+ float_color = utils.unconvert_from_RGB_255(float_color)
int_rgb = colors.convert_to_RGB_255(float_color)
int_rgb = colors.label_rgb(int_rgb)
diff --git a/plotly/figure_factory/_facet_grid.py b/plotly/figure_factory/_facet_grid.py
index a202599d747..20b6ede8326 100644
--- a/plotly/figure_factory/_facet_grid.py
+++ b/plotly/figure_factory/_facet_grid.py
@@ -935,13 +935,13 @@ def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None,
marker_color, kwargs_trace, kwargs_marker
)
elif isinstance(colormap, str):
- if colormap in colors.PLOTLY_SCALES.keys():
- colorscale_list = colors.PLOTLY_SCALES[colormap]
+ if colormap in utils.PLOTLY_SCALES.keys():
+ colorscale_list = utils.PLOTLY_SCALES[colormap]
else:
raise exceptions.PlotlyError(
"If 'colormap' is a string, it must be the name "
"of a Plotly Colorscale. The available colorscale "
- "names are {}".format(colors.PLOTLY_SCALES.keys())
+ "names are {}".format(utils.PLOTLY_SCALES.keys())
)
fig, annotations = _facet_grid_color_numerical(
df, x, y, facet_row, facet_col, color_name,
@@ -951,7 +951,7 @@ def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None,
marker_color, kwargs_trace, kwargs_marker
)
else:
- colorscale_list = colors.PLOTLY_SCALES['Reds']
+ colorscale_list = utils.PLOTLY_SCALES['Reds']
fig, annotations = _facet_grid_color_numerical(
df, x, y, facet_row, facet_col, color_name,
colorscale_list, num_of_rows, num_of_cols,
diff --git a/plotly/figure_factory/_scatterplot.py b/plotly/figure_factory/_scatterplot.py
index 3c2de2452f7..d56af50620d 100644
--- a/plotly/figure_factory/_scatterplot.py
+++ b/plotly/figure_factory/_scatterplot.py
@@ -1,5 +1,7 @@
from __future__ import absolute_import
+import six
+
from plotly import colors, exceptions, optional_imports
from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
@@ -1080,6 +1082,19 @@ def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter',
# Validate colormap
if isinstance(colormap, dict):
colormap = utils.validate_colors_dict(colormap, 'rgb')
+ elif isinstance(colormap, six.string_types) and 'rgb' not in colormap and '#' not in colormap:
+ if colormap not in utils.PLOTLY_SCALES.keys():
+ raise exceptions.PlotlyError(
+ "If 'colormap' is a string, it must be the name "
+ "of a Plotly Colorscale. The available colorscale "
+ "names are {}".format(utils.PLOTLY_SCALES.keys())
+ )
+ else:
+ # TODO change below to allow the correct Plotly colorscale
+ colormap = utils.colorscale_to_colors(utils.PLOTLY_SCALES[colormap])
+ # keep only first and last item - fix later
+ colormap = [colormap[0]] + [colormap[-1]]
+ colormap = utils.validate_colors(colormap, 'rgb')
else:
colormap = utils.validate_colors(colormap, 'rgb')
diff --git a/plotly/figure_factory/_trisurf.py b/plotly/figure_factory/_trisurf.py
index 06e9cc04fee..9b84b61cf46 100644
--- a/plotly/figure_factory/_trisurf.py
+++ b/plotly/figure_factory/_trisurf.py
@@ -1,6 +1,7 @@
from __future__ import absolute_import
from plotly import colors, exceptions, optional_imports
+from plotly.figure_factory import utils
from plotly.graph_objs import graph_objs
np = optional_imports.get_module('numpy')
@@ -147,8 +148,8 @@ def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale,
if mean_dists_are_numbers and show_colorbar is True:
# make a colorscale from the colors
- colorscale = colors.make_colorscale(colormap, scale)
- colorscale = colors.convert_colorscale_to_rgb(colorscale)
+ colorscale = utils.make_colorscale(colormap, scale)
+ colorscale = utils.convert_colorscale_to_rgb(colorscale)
colorbar = graph_objs.Scatter3d(
x=x[:1],
@@ -455,7 +456,7 @@ def dist_origin(x, y, z):
# Validate colormap
colors.validate_colors(colormap)
- colormap, scale = colors.convert_colors_to_same_type(
+ colormap, scale = utils.convert_colors_to_same_type(
colormap, colortype='tuple',
return_default_colors=True, scale=scale
)
diff --git a/plotly/figure_factory/figure_factory/_2d_density.py b/plotly/figure_factory/figure_factory/_2d_density.py
new file mode 100644
index 00000000000..8d004bff147
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_2d_density.py
@@ -0,0 +1,166 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+
+def make_linear_colorscale(colors):
+ """
+ Makes a list of colors into a colorscale-acceptable form
+
+ For documentation regarding to the form of the output, see
+ https://plot.ly/python/reference/#mesh3d-colorscale
+ """
+ scale = 1. / (len(colors) - 1)
+ return [[i * scale, color] for i, color in enumerate(colors)]
+
+
+def create_2d_density(x, y, colorscale='Earth', ncontours=20,
+ hist_color=(0, 0, 0.5), point_color=(0, 0, 0.5),
+ point_size=2, title='2D Density Plot',
+ height=600, width=600):
+ """
+ Returns figure for a 2D density plot
+
+ :param (list|array) x: x-axis data for plot generation
+ :param (list|array) y: y-axis data for plot generation
+ :param (str|tuple|list) colorscale: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list or tuple of colors. An rgb
+ color is of the form 'rgb(x, y, z)' where x, y, z belong to the
+ interval [0, 255] and a color tuple is a tuple of the form
+ (a, b, c) where a, b and c belong to [0, 1]. If colormap is a
+ list, it must contain the valid color types aforementioned as its
+ members.
+ :param (int) ncontours: the number of 2D contours to draw on the plot
+ :param (str) hist_color: the color of the plotted histograms
+ :param (str) point_color: the color of the scatter points
+ :param (str) point_size: the color of the scatter points
+ :param (str) title: set the title for the plot
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Example 1: Simple 2D Density Plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory create_2d_density
+
+ import numpy as np
+
+ # Make data points
+ t = np.linspace(-1,1.2,2000)
+ x = (t**3)+(0.3*np.random.randn(2000))
+ y = (t**6)+(0.3*np.random.randn(2000))
+
+ # Create a figure
+ fig = create_2D_density(x, y)
+
+ # Plot the data
+ py.iplot(fig, filename='simple-2d-density')
+ ```
+
+ Example 2: Using Parameters
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory create_2d_density
+
+ import numpy as np
+
+ # Make data points
+ t = np.linspace(-1,1.2,2000)
+ x = (t**3)+(0.3*np.random.randn(2000))
+ y = (t**6)+(0.3*np.random.randn(2000))
+
+ # Create custom colorscale
+ colorscale = ['#7A4579', '#D56073', 'rgb(236,158,105)',
+ (1, 1, 0.2), (0.98,0.98,0.98)]
+
+ # Create a figure
+ fig = create_2D_density(
+ x, y, colorscale=colorscale,
+ hist_color='rgb(255, 237, 222)', point_size=3)
+
+ # Plot the data
+ py.iplot(fig, filename='use-parameters')
+ ```
+ """
+
+ # validate x and y are filled with numbers only
+ for array in [x, y]:
+ if not all(isinstance(element, Number) for element in array):
+ raise exceptions.PlotlyError(
+ "All elements of your 'x' and 'y' lists must be numbers."
+ )
+
+ # validate x and y are the same length
+ if len(x) != len(y):
+ raise exceptions.PlotlyError(
+ "Both lists 'x' and 'y' must be the same length."
+ )
+
+ colorscale = utils.validate_colors(colorscale, 'rgb')
+ colorscale = make_linear_colorscale(colorscale)
+
+ # validate hist_color and point_color
+ hist_color = utils.validate_colors(hist_color, 'rgb')
+ point_color = utils.validate_colors(point_color, 'rgb')
+
+ trace1 = graph_objs.Scatter(
+ x=x, y=y, mode='markers', name='points',
+ marker=dict(
+ color=point_color[0],
+ size=point_size,
+ opacity=0.4
+ )
+ )
+ trace2 = graph_objs.Histogram2dContour(
+ x=x, y=y, name='density', ncontours=ncontours,
+ colorscale=colorscale, reversescale=True, showscale=False
+ )
+ trace3 = graph_objs.Histogram(
+ x=x, name='x density',
+ marker=dict(color=hist_color[0]), yaxis='y2'
+ )
+ trace4 = graph_objs.Histogram(
+ y=y, name='y density',
+ marker=dict(color=hist_color[0]), xaxis='x2'
+ )
+ data = [trace1, trace2, trace3, trace4]
+
+ layout = graph_objs.Layout(
+ showlegend=False,
+ autosize=False,
+ title=title,
+ height=height,
+ width=width,
+ xaxis=dict(
+ domain=[0, 0.85],
+ showgrid=False,
+ zeroline=False
+ ),
+ yaxis=dict(
+ domain=[0, 0.85],
+ showgrid=False,
+ zeroline=False
+ ),
+ margin=dict(
+ t=50
+ ),
+ hovermode='closest',
+ bargap=0,
+ xaxis2=dict(
+ domain=[0.85, 1],
+ showgrid=False,
+ zeroline=False
+ ),
+ yaxis2=dict(
+ domain=[0.85, 1],
+ showgrid=False,
+ zeroline=False
+ )
+ )
+
+ fig = graph_objs.Figure(data=data, layout=layout)
+ return fig
diff --git a/plotly/figure_factory/figure_factory/__init__.py b/plotly/figure_factory/figure_factory/__init__.py
new file mode 100644
index 00000000000..a8be19872e1
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/__init__.py
@@ -0,0 +1,24 @@
+from __future__ import absolute_import
+
+from plotly import optional_imports
+
+# Require that numpy exists for figure_factory
+import numpy
+
+from plotly.figure_factory._2d_density import create_2d_density
+from plotly.figure_factory._annotated_heatmap import create_annotated_heatmap
+from plotly.figure_factory._bullet import create_bullet
+from plotly.figure_factory._candlestick import create_candlestick
+from plotly.figure_factory._dendrogram import create_dendrogram
+from plotly.figure_factory._distplot import create_distplot
+from plotly.figure_factory._facet_grid import create_facet_grid
+from plotly.figure_factory._gantt import create_gantt
+from plotly.figure_factory._ohlc import create_ohlc
+from plotly.figure_factory._quiver import create_quiver
+from plotly.figure_factory._scatterplot import create_scatterplotmatrix
+from plotly.figure_factory._streamline import create_streamline
+from plotly.figure_factory._table import create_table
+from plotly.figure_factory._trisurf import create_trisurf
+from plotly.figure_factory._violin import create_violin
+if optional_imports.get_module('pandas') is not None:
+ from plotly.figure_factory._county_choropleth import create_choropleth
diff --git a/plotly/figure_factory/figure_factory/_annotated_heatmap.py b/plotly/figure_factory/figure_factory/_annotated_heatmap.py
new file mode 100644
index 00000000000..3400aba0e48
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_annotated_heatmap.py
@@ -0,0 +1,263 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.validators.heatmap import ColorscaleValidator
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+
+
+def validate_annotated_heatmap(z, x, y, annotation_text):
+ """
+ Annotated-heatmap-specific validations
+
+ Check that if a text matrix is supplied, it has the same
+ dimensions as the z matrix.
+
+ See FigureFactory.create_annotated_heatmap() for params
+
+ :raises: (PlotlyError) If z and text matrices do not have the same
+ dimensions.
+ """
+ if annotation_text is not None and isinstance(annotation_text, list):
+ utils.validate_equal_length(z, annotation_text)
+ for lst in range(len(z)):
+ if len(z[lst]) != len(annotation_text[lst]):
+ raise exceptions.PlotlyError("z and text should have the "
+ "same dimensions")
+
+ if x:
+ if len(x) != len(z[0]):
+ raise exceptions.PlotlyError("oops, the x list that you "
+ "provided does not match the "
+ "width of your z matrix ")
+
+ if y:
+ if len(y) != len(z):
+ raise exceptions.PlotlyError("oops, the y list that you "
+ "provided does not match the "
+ "length of your z matrix ")
+
+
+def create_annotated_heatmap(z, x=None, y=None, annotation_text=None,
+ colorscale='RdBu', font_colors=None,
+ showscale=False, reversescale=False,
+ **kwargs):
+ """
+ BETA function that creates annotated heatmaps
+
+ This function adds annotations to each cell of the heatmap.
+
+ :param (list[list]|ndarray) z: z matrix to create heatmap.
+ :param (list) x: x axis labels.
+ :param (list) y: y axis labels.
+ :param (list[list]|ndarray) annotation_text: Text strings for
+ annotations. Should have the same dimensions as the z matrix. If no
+ text is added, the values of the z matrix are annotated. Default =
+ z matrix values.
+ :param (list|str) colorscale: heatmap colorscale.
+ :param (list) font_colors: List of two color strings: [min_text_color,
+ max_text_color] where min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. If font_colors is not
+ defined, the colors are defined logically as black or white
+ depending on the heatmap's colorscale.
+ :param (bool) showscale: Display colorscale. Default = False
+ :param (bool) reversescale: Reverse colorscale. Default = False
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple annotated heatmap with default configuration
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as FF
+
+ z = [[0.300000, 0.00000, 0.65, 0.300000],
+ [1, 0.100005, 0.45, 0.4300],
+ [0.300000, 0.00000, 0.65, 0.300000],
+ [1, 0.100005, 0.45, 0.00000]]
+
+ figure = FF.create_annotated_heatmap(z)
+ py.iplot(figure)
+ ```
+ """
+
+ # Avoiding mutables in the call signature
+ font_colors = font_colors if font_colors is not None else []
+ validate_annotated_heatmap(z, x, y, annotation_text)
+
+ # validate colorscale
+ colorscale_validator = ColorscaleValidator()
+ colorscale = colorscale_validator.validate_coerce(colorscale)
+
+ annotations = _AnnotatedHeatmap(z, x, y, annotation_text,
+ colorscale, font_colors, reversescale,
+ **kwargs).make_annotations()
+
+ if x or y:
+ trace = dict(type='heatmap', z=z, x=x, y=y, colorscale=colorscale,
+ showscale=showscale, reversescale=reversescale, **kwargs)
+ layout = dict(annotations=annotations,
+ xaxis=dict(ticks='', dtick=1, side='top',
+ gridcolor='rgb(0, 0, 0)'),
+ yaxis=dict(ticks='', dtick=1, ticksuffix=' '))
+ else:
+ trace = dict(type='heatmap', z=z, colorscale=colorscale,
+ showscale=showscale, reversescale=reversescale, **kwargs)
+ layout = dict(annotations=annotations,
+ xaxis=dict(ticks='', side='top',
+ gridcolor='rgb(0, 0, 0)',
+ showticklabels=False),
+ yaxis=dict(ticks='', ticksuffix=' ',
+ showticklabels=False))
+
+ data = [trace]
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+def to_rgb_color_list(color_str, default):
+ if 'rgb' in color_str:
+ return [int(v) for v in color_str.strip('rgb()').split(',')]
+ elif '#' in color_str:
+ return utils.hex_to_rgb(color_str)
+ else:
+ return default
+
+
+def should_use_black_text(background_color):
+ return (background_color[0] * 0.299 +
+ background_color[1] * 0.587 +
+ background_color[2] * 0.114) > 186
+
+
+class _AnnotatedHeatmap(object):
+ """
+ Refer to TraceFactory.create_annotated_heatmap() for docstring
+ """
+ def __init__(self, z, x, y, annotation_text, colorscale,
+ font_colors, reversescale, **kwargs):
+
+ self.z = z
+ if x:
+ self.x = x
+ else:
+ self.x = range(len(z[0]))
+ if y:
+ self.y = y
+ else:
+ self.y = range(len(z))
+ if annotation_text is not None:
+ self.annotation_text = annotation_text
+ else:
+ self.annotation_text = self.z
+ self.colorscale = colorscale
+ self.reversescale = reversescale
+ self.font_colors = font_colors
+
+ def get_text_color(self):
+ """
+ Get font color for annotations.
+
+ The annotated heatmap can feature two text colors: min_text_color and
+ max_text_color. The min_text_color is applied to annotations for
+ heatmap values < (max_value - min_value)/2. The user can define these
+ two colors. Otherwise the colors are defined logically as black or
+ white depending on the heatmap's colorscale.
+
+ :rtype (string, string) min_text_color, max_text_color: text
+ color for annotations for heatmap values <
+ (max_value - min_value)/2 and text color for annotations for
+ heatmap values >= (max_value - min_value)/2
+ """
+ # Plotly colorscales ranging from a lighter shade to a darker shade
+ colorscales = ['Greys', 'Greens', 'Blues',
+ 'YIGnBu', 'YIOrRd', 'RdBu',
+ 'Picnic', 'Jet', 'Hot', 'Blackbody',
+ 'Earth', 'Electric', 'Viridis', 'Cividis']
+ # Plotly colorscales ranging from a darker shade to a lighter shade
+ colorscales_reverse = ['Reds']
+
+ white = '#FFFFFF'
+ black = '#000000'
+ if self.font_colors:
+ min_text_color = self.font_colors[0]
+ max_text_color = self.font_colors[-1]
+ elif self.colorscale in colorscales and self.reversescale:
+ min_text_color = black
+ max_text_color = white
+ elif self.colorscale in colorscales:
+ min_text_color = white
+ max_text_color = black
+ elif self.colorscale in colorscales_reverse and self.reversescale:
+ min_text_color = white
+ max_text_color = black
+ elif self.colorscale in colorscales_reverse:
+ min_text_color = black
+ max_text_color = white
+ elif isinstance(self.colorscale, list):
+
+ min_col = to_rgb_color_list(self.colorscale[0][1],
+ [255, 255, 255])
+ max_col = to_rgb_color_list(self.colorscale[-1][1],
+ [255, 255, 255])
+
+ # swap min/max colors if reverse scale
+ if self.reversescale:
+ min_col, max_col = max_col, min_col
+
+ if should_use_black_text(min_col):
+ min_text_color = black
+ else:
+ min_text_color = white
+
+ if should_use_black_text(max_col):
+ max_text_color = black
+ else:
+ max_text_color = white
+ else:
+ min_text_color = black
+ max_text_color = black
+ return min_text_color, max_text_color
+
+ def get_z_mid(self):
+ """
+ Get the mid value of z matrix
+
+ :rtype (float) z_avg: average val from z matrix
+ """
+ if np and isinstance(self.z, np.ndarray):
+ z_min = np.amin(self.z)
+ z_max = np.amax(self.z)
+ else:
+ z_min = min(min(self.z))
+ z_max = max(max(self.z))
+ z_mid = (z_max+z_min) / 2
+ return z_mid
+
+ def make_annotations(self):
+ """
+ Get annotations for each cell of the heatmap with graph_objs.Annotation
+
+ :rtype (list[dict]) annotations: list of annotations for each cell of
+ the heatmap
+ """
+ min_text_color, max_text_color = _AnnotatedHeatmap.get_text_color(self)
+ z_mid = _AnnotatedHeatmap.get_z_mid(self)
+ annotations = []
+ for n, row in enumerate(self.z):
+ for m, val in enumerate(row):
+ font_color = min_text_color if val < z_mid else max_text_color
+ annotations.append(
+ graph_objs.layout.Annotation(
+ text=str(self.annotation_text[n][m]),
+ x=self.x[m],
+ y=self.y[n],
+ xref='x1',
+ yref='y1',
+ font=dict(color=font_color),
+ showarrow=False))
+ return annotations
diff --git a/plotly/figure_factory/figure_factory/_bullet.py b/plotly/figure_factory/figure_factory/_bullet.py
new file mode 100644
index 00000000000..046dea3224a
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_bullet.py
@@ -0,0 +1,340 @@
+from __future__ import absolute_import
+
+import collections
+import math
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+
+import plotly
+import plotly.graph_objs as go
+
+pd = optional_imports.get_module('pandas')
+
+
+def _bullet(df, markers, measures, ranges, subtitles, titles, orientation,
+ range_colors, measure_colors, horizontal_spacing,
+ vertical_spacing, scatter_options, layout_options):
+
+ num_of_lanes = len(df)
+ num_of_rows = num_of_lanes if orientation == 'h' else 1
+ num_of_cols = 1 if orientation == 'h' else num_of_lanes
+ if not horizontal_spacing:
+ horizontal_spacing = 1./num_of_lanes
+ if not vertical_spacing:
+ vertical_spacing = 1./num_of_lanes
+ fig = plotly.tools.make_subplots(
+ num_of_rows, num_of_cols, print_grid=False,
+ horizontal_spacing=horizontal_spacing,
+ vertical_spacing=vertical_spacing
+ )
+
+ # layout
+ fig['layout'].update(
+ dict(shapes=[]),
+ title='Bullet Chart',
+ height=600,
+ width=1000,
+ showlegend=False,
+ barmode='stack',
+ annotations=[],
+ margin=dict(l=120 if orientation == 'h' else 80),
+ )
+
+ # update layout
+ fig['layout'].update(layout_options)
+
+ if orientation == 'h':
+ width_axis = 'yaxis'
+ length_axis = 'xaxis'
+ else:
+ width_axis = 'xaxis'
+ length_axis = 'yaxis'
+
+ for key in fig['layout']:
+ if 'xaxis' in key or 'yaxis' in key:
+ fig['layout'][key]['showgrid'] = False
+ fig['layout'][key]['zeroline'] = False
+ if length_axis in key:
+ fig['layout'][key]['tickwidth'] = 1
+ if width_axis in key:
+ fig['layout'][key]['showticklabels'] = False
+ fig['layout'][key]['range'] = [0, 1]
+
+ # narrow domain if 1 bar
+ if num_of_lanes <= 1:
+ fig['layout'][width_axis + '1']['domain'] = [0.4, 0.6]
+
+ if not range_colors:
+ range_colors = ['rgb(200, 200, 200)', 'rgb(245, 245, 245)']
+ if not measure_colors:
+ measure_colors = ['rgb(31, 119, 180)', 'rgb(176, 196, 221)']
+
+ for row in range(num_of_lanes):
+ # ranges bars
+ for idx in range(len(df.iloc[row]['ranges'])):
+ inter_colors = utils.n_colors(
+ range_colors[0], range_colors[1],
+ len(df.iloc[row]['ranges']), 'rgb'
+ )
+ x = ([sorted(df.iloc[row]['ranges'])[-1 - idx]] if
+ orientation == 'h' else [0])
+ y = ([0] if orientation == 'h' else
+ [sorted(df.iloc[row]['ranges'])[-1 - idx]])
+ bar = go.Bar(
+ x=x,
+ y=y,
+ marker=dict(
+ color=inter_colors[-1 - idx]
+ ),
+ name='ranges',
+ hoverinfo='x' if orientation == 'h' else 'y',
+ orientation=orientation,
+ width=2,
+ base=0,
+ xaxis='x{}'.format(row + 1),
+ yaxis='y{}'.format(row + 1)
+ )
+ fig.add_trace(bar)
+
+ # measures bars
+ for idx in range(len(df.iloc[row]['measures'])):
+ inter_colors = utils.n_colors(
+ measure_colors[0], measure_colors[1],
+ len(df.iloc[row]['measures']), 'rgb'
+ )
+ x = ([sorted(df.iloc[row]['measures'])[-1 - idx]] if
+ orientation == 'h' else [0.5])
+ y = ([0.5] if orientation == 'h'
+ else [sorted(df.iloc[row]['measures'])[-1 - idx]])
+ bar = go.Bar(
+ x=x,
+ y=y,
+ marker=dict(
+ color=inter_colors[-1 - idx]
+ ),
+ name='measures',
+ hoverinfo='x' if orientation == 'h' else 'y',
+ orientation=orientation,
+ width=0.4,
+ base=0,
+ xaxis='x{}'.format(row + 1),
+ yaxis='y{}'.format(row + 1)
+ )
+ fig.add_trace(bar)
+
+ # markers
+ x = df.iloc[row]['markers'] if orientation == 'h' else [0.5]
+ y = [0.5] if orientation == 'h' else df.iloc[row]['markers']
+ markers = go.Scatter(
+ x=x,
+ y=y,
+ name='markers',
+ hoverinfo='x' if orientation == 'h' else 'y',
+ xaxis='x{}'.format(row + 1),
+ yaxis='y{}'.format(row + 1),
+ **scatter_options
+ )
+
+ fig.add_trace(markers)
+
+ # titles and subtitles
+ title = df.iloc[row]['titles']
+ if 'subtitles' in df:
+ subtitle = '
{}'.format(df.iloc[row]['subtitles'])
+ else:
+ subtitle = ''
+ label = '{}'.format(title) + subtitle
+ annot = utils.annotation_dict_for_label(
+ label,
+ (num_of_lanes - row if orientation == 'h' else row + 1),
+ num_of_lanes,
+ vertical_spacing if orientation == 'h' else horizontal_spacing,
+ 'row' if orientation == 'h' else 'col',
+ True if orientation == 'h' else False,
+ False
+ )
+ fig['layout']['annotations'] += (annot,)
+
+ return fig
+
+
+def create_bullet(data, markers=None, measures=None, ranges=None,
+ subtitles=None, titles=None, orientation='h',
+ range_colors=('rgb(200, 200, 200)', 'rgb(245, 245, 245)'),
+ measure_colors=('rgb(31, 119, 180)', 'rgb(176, 196, 221)'),
+ horizontal_spacing=None, vertical_spacing=None,
+ scatter_options={}, **layout_options):
+ """
+ Returns figure for bullet chart.
+
+ :param (pd.DataFrame | list | tuple) data: either a list/tuple of
+ dictionaries or a pandas DataFrame.
+ :param (str) markers: the column name or dictionary key for the markers in
+ each subplot.
+ :param (str) measures: the column name or dictionary key for the measure
+ bars in each subplot. This bar usually represents the quantitative
+ measure of performance, usually a list of two values [a, b] and are
+ the blue bars in the foreground of each subplot by default.
+ :param (str) ranges: the column name or dictionary key for the qualitative
+ ranges of performance, usually a 3-item list [bad, okay, good]. They
+ correspond to the grey bars in the background of each chart.
+ :param (str) subtitles: the column name or dictionary key for the subtitle
+ of each subplot chart. The subplots are displayed right underneath
+ each title.
+ :param (str) titles: the column name or dictionary key for the main label
+ of each subplot chart.
+ :param (bool) orientation: if 'h', the bars are placed horizontally as
+ rows. If 'v' the bars are placed vertically in the chart.
+ :param (list) range_colors: a tuple of two colors between which all
+ the rectangles for the range are drawn. These rectangles are meant to
+ be qualitative indicators against which the marker and measure bars
+ are compared.
+ Default=('rgb(200, 200, 200)', 'rgb(245, 245, 245)')
+ :param (list) measure_colors: a tuple of two colors which is used to color
+ the thin quantitative bars in the bullet chart.
+ Default=('rgb(31, 119, 180)', 'rgb(176, 196, 221)')
+ :param (float) horizontal_spacing: see the 'horizontal_spacing' param in
+ plotly.tools.make_subplots. Ranges between 0 and 1.
+ :param (float) vertical_spacing: see the 'vertical_spacing' param in
+ plotly.tools.make_subplots. Ranges between 0 and 1.
+ :param (dict) scatter_options: describes attributes for the scatter trace
+ in each subplot such as name and marker size. Call
+ help(plotly.graph_objs.Scatter) for more information on valid params.
+ :param layout_options: describes attributes for the layout of the figure
+ such as title, height and width. Call help(plotly.graph_objs.Layout)
+ for more information on valid params.
+
+ Example 1: Use a Dictionary
+ ```
+ import plotly
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ data = [
+ {"label": "Revenue", "sublabel": "US$, in thousands",
+ "range": [150, 225, 300], "performance": [220,270], "point": [250]},
+ {"label": "Profit", "sublabel": "%", "range": [20, 25, 30],
+ "performance": [21, 23], "point": [26]},
+ {"label": "Order Size", "sublabel":"US$, average","range": [350, 500, 600],
+ "performance": [100,320],"point": [550]},
+ {"label": "New Customers", "sublabel": "count", "range": [1400, 2000, 2500],
+ "performance": [1000, 1650],"point": [2100]},
+ {"label": "Satisfaction", "sublabel": "out of 5","range": [3.5, 4.25, 5],
+ "performance": [3.2, 4.7], "point": [4.4]}
+ ]
+
+ fig = ff.create_bullet(
+ data, titles='label', subtitles='sublabel', markers='point',
+ measures='performance', ranges='range', orientation='h',
+ title='my simple bullet chart'
+ )
+ py.iplot(fig)
+ ```
+
+ Example 2: Use a DataFrame with Custom Colors
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ data = pd.read_json('https://cdn.rawgit.com/plotly/datasets/master/BulletData.json')
+
+ fig = ff.create_bullet(
+ data, titles='title', markers='markers', measures='measures',
+ orientation='v', measure_colors=['rgb(14, 52, 75)', 'rgb(31, 141, 127)'],
+ scatter_options={'marker': {'symbol': 'circle'}}, width=700
+
+ )
+ py.iplot(fig)
+ ```
+ """
+ # validate df
+ if not pd:
+ raise exceptions.ImportError(
+ "'pandas' must be installed for this figure factory."
+ )
+
+ if utils.is_sequence(data):
+ if not all(isinstance(item, dict) for item in data):
+ raise exceptions.PlotlyError(
+ 'Every entry of the data argument list, tuple, etc must '
+ 'be a dictionary.'
+ )
+
+ elif not isinstance(data, pd.DataFrame):
+ raise exceptions.PlotlyError(
+ 'You must input a pandas DataFrame, or a list of dictionaries.'
+ )
+
+ # make DataFrame from data with correct column headers
+ col_names = ['titles', 'subtitle', 'markers', 'measures', 'ranges']
+ if utils.is_sequence(data):
+ df = pd.DataFrame(
+ [
+ [d[titles] for d in data] if titles else [''] * len(data),
+ [d[subtitles] for d in data] if subtitles else [''] * len(data),
+ [d[markers] for d in data] if markers else [[]] * len(data),
+ [d[measures] for d in data] if measures else [[]] * len(data),
+ [d[ranges] for d in data] if ranges else [[]] * len(data),
+ ],
+ index=col_names
+ )
+ elif isinstance(data, pd.DataFrame):
+ df = pd.DataFrame(
+ [
+ data[titles].tolist() if titles else [''] * len(data),
+ data[subtitles].tolist() if subtitles else [''] * len(data),
+ data[markers].tolist() if markers else [[]] * len(data),
+ data[measures].tolist() if measures else [[]] * len(data),
+ data[ranges].tolist() if ranges else [[]] * len(data),
+ ],
+ index=col_names
+ )
+ df = pd.DataFrame.transpose(df)
+
+ # make sure ranges, measures, 'markers' are not NAN or NONE
+ for needed_key in ['ranges', 'measures', 'markers']:
+ for idx, r in enumerate(df[needed_key]):
+ try:
+ r_is_nan = math.isnan(r)
+ if r_is_nan or r is None:
+ df[needed_key][idx] = []
+ except TypeError:
+ pass
+
+ # validate custom colors
+ for colors_list in [range_colors, measure_colors]:
+ if colors_list:
+ if len(colors_list) != 2:
+ raise exceptions.PlotlyError(
+ "Both 'range_colors' or 'measure_colors' must be a list "
+ "of two valid colors."
+ )
+ utils.validate_colors(colors_list)
+ colors_list = utils.convert_colors_to_same_type(colors_list,
+ 'rgb')[0]
+
+ # default scatter options
+ default_scatter = {
+ 'marker': {'size': 12,
+ 'symbol': 'diamond-tall',
+ 'color': 'rgb(0, 0, 0)'}
+ }
+
+ if scatter_options == {}:
+ scatter_options.update(default_scatter)
+ else:
+ # add default options to scatter_options if they are not present
+ for k in default_scatter['marker']:
+ if k not in scatter_options['marker']:
+ scatter_options['marker'][k] = default_scatter['marker'][k]
+
+ fig = _bullet(
+ df, markers, measures, ranges, subtitles, titles, orientation,
+ range_colors, measure_colors, horizontal_spacing, vertical_spacing,
+ scatter_options, layout_options,
+ )
+
+ return fig
diff --git a/plotly/figure_factory/figure_factory/_candlestick.py b/plotly/figure_factory/figure_factory/_candlestick.py
new file mode 100644
index 00000000000..925b4c1a62b
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_candlestick.py
@@ -0,0 +1,294 @@
+from __future__ import absolute_import
+
+from plotly.figure_factory import utils
+from plotly.figure_factory._ohlc import (_DEFAULT_INCREASING_COLOR,
+ _DEFAULT_DECREASING_COLOR,
+ validate_ohlc)
+from plotly.graph_objs import graph_objs
+
+
+def make_increasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for increasing candlesticks
+
+ _make_increasing_candle() and _make_decreasing_candle separate the
+ increasing traces from the decreasing traces so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_incr_data: list of the box trace for
+ increasing candlesticks.
+ """
+ increase_x, increase_y = _Candlestick(
+ open, high, low, close, dates, **kwargs).get_candle_increase()
+
+ if 'line' in kwargs:
+ kwargs.setdefault('fillcolor', kwargs['line']['color'])
+ else:
+ kwargs.setdefault('fillcolor', _DEFAULT_INCREASING_COLOR)
+ if 'name' in kwargs:
+ kwargs.setdefault('showlegend', True)
+ else:
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('name', 'Increasing')
+ kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR))
+
+ candle_incr_data = dict(type='box',
+ x=increase_x,
+ y=increase_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs)
+
+ return [candle_incr_data]
+
+
+def make_decreasing_candle(open, high, low, close, dates, **kwargs):
+ """
+ Makes boxplot trace for decreasing candlesticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to decreasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (list) candle_decr_data: list of the box trace for
+ decreasing candlesticks.
+ """
+
+ decrease_x, decrease_y = _Candlestick(
+ open, high, low, close, dates, **kwargs).get_candle_decrease()
+
+ if 'line' in kwargs:
+ kwargs.setdefault('fillcolor', kwargs['line']['color'])
+ else:
+ kwargs.setdefault('fillcolor', _DEFAULT_DECREASING_COLOR)
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR))
+ kwargs.setdefault('name', 'Decreasing')
+
+ candle_decr_data = dict(type='box',
+ x=decrease_x,
+ y=decrease_y,
+ whiskerwidth=0,
+ boxpoints=False,
+ **kwargs)
+
+ return [candle_decr_data]
+
+
+def create_candlestick(open, high, low, close, dates=None, direction='both',
+ **kwargs):
+ """
+ BETA function that creates a candlestick chart
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all candlesticks where the close value is greater than
+ the corresponding open value, and when the direction is
+ 'decreasing', the returned figure consists of all candlesticks
+ where the close value is less than or equal to the corresponding
+ open value. When the direction is 'both', both increasing and
+ decreasing candlesticks are returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of candlestick chart figure.
+
+ Example 1: Simple candlestick chart from a Pandas DataFrame
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2007, 10, 1), datetime(2009, 4, 1))
+ fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
+ py.plot(fig, filename='finance/aapl-candlestick', validate=False)
+ ```
+
+ Example 2: Add text and annotations to the candlestick chart
+ ```
+ fig = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index)
+ # Update the fig - all options here: https://plot.ly/python/reference/#Layout
+ fig['layout'].update({
+ 'title': 'The Great Recession',
+ 'yaxis': {'title': 'AAPL Stock'},
+ 'shapes': [{
+ 'x0': '2007-12-01', 'x1': '2007-12-01',
+ 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
+ 'line': {'color': 'rgb(30,30,30)', 'width': 1}
+ }],
+ 'annotations': [{
+ 'x': '2007-12-01', 'y': 0.05, 'xref': 'x', 'yref': 'paper',
+ 'showarrow': False, 'xanchor': 'left',
+ 'text': 'Official start of the recession'
+ }]
+ })
+ py.plot(fig, filename='finance/aapl-recession-candlestick', validate=False)
+ ```
+
+ Example 3: Customize the candlestick colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+ from plotly.graph_objs import Line, Marker
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1), datetime(2009, 4, 1))
+
+ # Make increasing candlesticks and customize their color and name
+ fig_increasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
+ direction='increasing', name='AAPL',
+ marker=Marker(color='rgb(150, 200, 250)'),
+ line=Line(color='rgb(150, 200, 250)'))
+
+ # Make decreasing candlesticks and customize their color and name
+ fig_decreasing = create_candlestick(df.Open, df.High, df.Low, df.Close, dates=df.index,
+ direction='decreasing',
+ marker=Marker(color='rgb(128, 128, 128)'),
+ line=Line(color='rgb(128, 128, 128)'))
+
+ # Initialize the figure
+ fig = fig_increasing
+
+ # Add decreasing data with .extend()
+ fig['data'].extend(fig_decreasing['data'])
+
+ py.iplot(fig, filename='finance/aapl-candlestick-custom', validate=False)
+ ```
+
+ Example 4: Candlestick chart with datetime objects
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_candlestick
+
+ from datetime import datetime
+
+ # Add data
+ open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
+ high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
+ low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
+ close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
+ dates = [datetime(year=2013, month=10, day=10),
+ datetime(year=2013, month=11, day=10),
+ datetime(year=2013, month=12, day=10),
+ datetime(year=2014, month=1, day=10),
+ datetime(year=2014, month=2, day=10)]
+
+ # Create ohlc
+ fig = create_candlestick(open_data, high_data,
+ low_data, close_data, dates=dates)
+
+ py.iplot(fig, filename='finance/simple-candlestick', validate=False)
+ ```
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction is 'increasing':
+ candle_incr_data = make_increasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_incr_data
+ elif direction is 'decreasing':
+ candle_decr_data = make_decreasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_decr_data
+ else:
+ candle_incr_data = make_increasing_candle(open, high, low, close,
+ dates, **kwargs)
+ candle_decr_data = make_decreasing_candle(open, high, low, close,
+ dates, **kwargs)
+ data = candle_incr_data + candle_decr_data
+
+ layout = graph_objs.Layout()
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Candlestick(object):
+ """
+ Refer to FigureFactory.create_candlestick() for docstring.
+ """
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ if dates is not None:
+ self.x = dates
+ else:
+ self.x = [x for x in range(len(self.open))]
+ self.get_candle_increase()
+
+ def get_candle_increase(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ increase_y = []
+ increase_x = []
+ for index in range(len(self.open)):
+ if self.close[index] > self.open[index]:
+ increase_y.append(self.low[index])
+ increase_y.append(self.open[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.close[index])
+ increase_y.append(self.high[index])
+ increase_x.append(self.x[index])
+
+ increase_x = [[x, x, x, x, x, x] for x in increase_x]
+ increase_x = utils.flatten(increase_x)
+
+ return increase_x, increase_y
+
+ def get_candle_decrease(self):
+ """
+ Separate increasing data from decreasing data.
+
+ The data is increasing when close value > open value
+ and decreasing when the close value <= open value.
+ """
+ decrease_y = []
+ decrease_x = []
+ for index in range(len(self.open)):
+ if self.close[index] <= self.open[index]:
+ decrease_y.append(self.low[index])
+ decrease_y.append(self.open[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.close[index])
+ decrease_y.append(self.high[index])
+ decrease_x.append(self.x[index])
+
+ decrease_x = [[x, x, x, x, x, x] for x in decrease_x]
+ decrease_x = utils.flatten(decrease_x)
+
+ return decrease_x, decrease_y
diff --git a/plotly/figure_factory/figure_factory/_county_choropleth.py b/plotly/figure_factory/figure_factory/_county_choropleth.py
new file mode 100644
index 00000000000..d2688439d86
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_county_choropleth.py
@@ -0,0 +1,940 @@
+from plotly import exceptions, optional_imports
+
+from plotly.figure_factory import utils
+
+import io
+import numpy as np
+import os
+import pandas as pd
+import warnings
+
+from math import log, floor
+from numbers import Number
+
+pd.options.mode.chained_assignment = None
+
+shapely = optional_imports.get_module('shapely')
+shapefile = optional_imports.get_module('shapefile')
+gp = optional_imports.get_module('geopandas')
+
+
+def _create_us_counties_df(st_to_state_name_dict, state_to_st_dict):
+ # URLS
+ abs_file_path = os.path.realpath(__file__)
+ abs_dir_path = os.path.dirname(abs_file_path)
+
+ abs_plotly_dir_path = os.path.dirname(abs_dir_path)
+
+ abs_package_data_dir_path = os.path.join(abs_plotly_dir_path,
+ 'package_data')
+
+ shape_pre2010 = 'gz_2010_us_050_00_500k.shp'
+ shape_pre2010 = os.path.join(abs_package_data_dir_path, shape_pre2010)
+
+ df_shape_pre2010 = gp.read_file(shape_pre2010)
+ df_shape_pre2010['FIPS'] = (df_shape_pre2010['STATE'] +
+ df_shape_pre2010['COUNTY'])
+ df_shape_pre2010['FIPS'] = pd.to_numeric(df_shape_pre2010['FIPS'])
+
+ states_path = 'cb_2016_us_state_500k.shp'
+ states_path = os.path.join(abs_package_data_dir_path, states_path)
+
+ # state df
+ df_state = gp.read_file(states_path)
+ df_state = df_state[['STATEFP', 'NAME', 'geometry']]
+ df_state = df_state.rename(columns={'NAME': 'STATE_NAME'})
+
+ filenames = ['cb_2016_us_county_500k.dbf',
+ 'cb_2016_us_county_500k.shp',
+ 'cb_2016_us_county_500k.shx']
+
+ for j in range(len(filenames)):
+ filenames[j] = os.path.join(abs_package_data_dir_path, filenames[j])
+
+ dbf = io.open(filenames[0], 'rb')
+ shp = io.open(filenames[1], 'rb')
+ shx = io.open(filenames[2], 'rb')
+
+ r = shapefile.Reader(shp=shp, shx=shx, dbf=dbf)
+
+ attributes, geometry = [], []
+ field_names = [field[0] for field in r.fields[1:]]
+ for row in r.shapeRecords():
+ geometry.append(shapely.geometry.shape(row.shape.__geo_interface__))
+ attributes.append(dict(zip(field_names, row.record)))
+
+ gdf = gp.GeoDataFrame(data=attributes, geometry=geometry)
+
+ gdf['FIPS'] = gdf['STATEFP'] + gdf['COUNTYFP']
+ gdf['FIPS'] = pd.to_numeric(gdf['FIPS'])
+
+ # add missing counties
+ f = 46113
+ singlerow = pd.DataFrame(
+ [
+ [st_to_state_name_dict['SD'], 'SD',
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0],
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0],
+ '46', 'Shannon']
+ ],
+ columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'],
+ index=[max(gdf.index) + 1]
+ )
+ gdf = gdf.append(singlerow)
+
+ f = 51515
+ singlerow = pd.DataFrame(
+ [
+ [st_to_state_name_dict['VA'], 'VA',
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0],
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0],
+ '51', 'Bedford City']
+ ],
+ columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'],
+ index=[max(gdf.index) + 1]
+ )
+ gdf = gdf.append(singlerow)
+
+ f = 2270
+ singlerow = pd.DataFrame(
+ [
+ [st_to_state_name_dict['AK'], 'AK',
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['geometry'].iloc[0],
+ df_shape_pre2010[df_shape_pre2010['FIPS'] == f]['FIPS'].iloc[0],
+ '02', 'Wade Hampton']
+ ],
+ columns=['State', 'ST', 'geometry', 'FIPS', 'STATEFP', 'NAME'],
+ index=[max(gdf.index) + 1]
+ )
+ gdf = gdf.append(singlerow)
+
+ row_2198 = gdf[gdf['FIPS'] == 2198]
+ row_2198.index = [max(gdf.index) + 1]
+ row_2198.loc[row_2198.index[0], 'FIPS'] = 2201
+ row_2198.loc[row_2198.index[0], 'STATEFP'] = '02'
+ gdf = gdf.append(row_2198)
+
+ row_2105 = gdf[gdf['FIPS'] == 2105]
+ row_2105.index = [max(gdf.index) + 1]
+ row_2105.loc[row_2105.index[0], 'FIPS'] = 2232
+ row_2105.loc[row_2105.index[0], 'STATEFP'] = '02'
+ gdf = gdf.append(row_2105)
+ gdf = gdf.rename(columns={'NAME': 'COUNTY_NAME'})
+
+ gdf_reduced = gdf[['FIPS', 'STATEFP', 'COUNTY_NAME', 'geometry']]
+ gdf_statefp = gdf_reduced.merge(df_state[['STATEFP', 'STATE_NAME']],
+ on='STATEFP')
+
+ ST = []
+ for n in gdf_statefp['STATE_NAME']:
+ ST.append(state_to_st_dict[n])
+
+ gdf_statefp['ST'] = ST
+ return gdf_statefp, df_state
+
+
+st_to_state_name_dict = {
+ 'AK': 'Alaska',
+ 'AL': 'Alabama',
+ 'AR': 'Arkansas',
+ 'AZ': 'Arizona',
+ 'CA': 'California',
+ 'CO': 'Colorado',
+ 'CT': 'Connecticut',
+ 'DC': 'District of Columbia',
+ 'DE': 'Delaware',
+ 'FL': 'Florida',
+ 'GA': 'Georgia',
+ 'HI': 'Hawaii',
+ 'IA': 'Iowa',
+ 'ID': 'Idaho',
+ 'IL': 'Illinois',
+ 'IN': 'Indiana',
+ 'KS': 'Kansas',
+ 'KY': 'Kentucky',
+ 'LA': 'Louisiana',
+ 'MA': 'Massachusetts',
+ 'MD': 'Maryland',
+ 'ME': 'Maine',
+ 'MI': 'Michigan',
+ 'MN': 'Minnesota',
+ 'MO': 'Missouri',
+ 'MS': 'Mississippi',
+ 'MT': 'Montana',
+ 'NC': 'North Carolina',
+ 'ND': 'North Dakota',
+ 'NE': 'Nebraska',
+ 'NH': 'New Hampshire',
+ 'NJ': 'New Jersey',
+ 'NM': 'New Mexico',
+ 'NV': 'Nevada',
+ 'NY': 'New York',
+ 'OH': 'Ohio',
+ 'OK': 'Oklahoma',
+ 'OR': 'Oregon',
+ 'PA': 'Pennsylvania',
+ 'RI': 'Rhode Island',
+ 'SC': 'South Carolina',
+ 'SD': 'South Dakota',
+ 'TN': 'Tennessee',
+ 'TX': 'Texas',
+ 'UT': 'Utah',
+ 'VA': 'Virginia',
+ 'VT': 'Vermont',
+ 'WA': 'Washington',
+ 'WI': 'Wisconsin',
+ 'WV': 'West Virginia',
+ 'WY': 'Wyoming'
+}
+
+state_to_st_dict = {
+ 'Alabama': 'AL',
+ 'Alaska': 'AK',
+ 'American Samoa': 'AS',
+ 'Arizona': 'AZ',
+ 'Arkansas': 'AR',
+ 'California': 'CA',
+ 'Colorado': 'CO',
+ 'Commonwealth of the Northern Mariana Islands': 'MP',
+ 'Connecticut': 'CT',
+ 'Delaware': 'DE',
+ 'District of Columbia': 'DC',
+ 'Florida': 'FL',
+ 'Georgia': 'GA',
+ 'Guam': 'GU',
+ 'Hawaii': 'HI',
+ 'Idaho': 'ID',
+ 'Illinois': 'IL',
+ 'Indiana': 'IN',
+ 'Iowa': 'IA',
+ 'Kansas': 'KS',
+ 'Kentucky': 'KY',
+ 'Louisiana': 'LA',
+ 'Maine': 'ME',
+ 'Maryland': 'MD',
+ 'Massachusetts': 'MA',
+ 'Michigan': 'MI',
+ 'Minnesota': 'MN',
+ 'Mississippi': 'MS',
+ 'Missouri': 'MO',
+ 'Montana': 'MT',
+ 'Nebraska': 'NE',
+ 'Nevada': 'NV',
+ 'New Hampshire': 'NH',
+ 'New Jersey': 'NJ',
+ 'New Mexico': 'NM',
+ 'New York': 'NY',
+ 'North Carolina': 'NC',
+ 'North Dakota': 'ND',
+ 'Ohio': 'OH',
+ 'Oklahoma': 'OK',
+ 'Oregon': 'OR',
+ 'Pennsylvania': 'PA',
+ 'Puerto Rico': '',
+ 'Rhode Island': 'RI',
+ 'South Carolina': 'SC',
+ 'South Dakota': 'SD',
+ 'Tennessee': 'TN',
+ 'Texas': 'TX',
+ 'United States Virgin Islands': 'VI',
+ 'Utah': 'UT',
+ 'Vermont': 'VT',
+ 'Virginia': 'VA',
+ 'Washington': 'WA',
+ 'West Virginia': 'WV',
+ 'Wisconsin': 'WI',
+ 'Wyoming': 'WY'
+}
+
+USA_XRANGE = [-125.0, -65.0]
+USA_YRANGE = [25.0, 49.0]
+
+
+def _human_format(number):
+ units = ['', 'K', 'M', 'G', 'T', 'P']
+ k = 1000.0
+ magnitude = int(floor(log(number, k)))
+ return '%.2f%s' % (number / k**magnitude, units[magnitude])
+
+
+def _intervals_as_labels(array_of_intervals, round_legend_values, exponent_format):
+ """
+ Transform an number interval to a clean string for legend
+
+ Example: [-inf, 30] to '< 30'
+ """
+ infs = [float('-inf'), float('inf')]
+ string_intervals = []
+ for interval in array_of_intervals:
+ # round to 2nd decimal place
+ if round_legend_values:
+ rnd_interval = [
+ (int(interval[i]) if interval[i] not in infs else
+ interval[i])
+ for i in range(2)
+ ]
+ else:
+ rnd_interval = [round(interval[0], 2),
+ round(interval[1], 2)]
+
+ num0 = rnd_interval[0]
+ num1 = rnd_interval[1]
+ if exponent_format:
+ if num0 not in infs:
+ num0 = _human_format(num0)
+ if num1 not in infs:
+ num1 = _human_format(num1)
+ else:
+ if num0 not in infs:
+ num0 = "{:,}".format(num0)
+ if num1 not in infs:
+ num1 = "{:,}".format(num1)
+
+ if num0 == float('-inf'):
+ as_str = '< {}'.format(num1)
+ elif num1 == float('inf'):
+ as_str = '> {}'.format(num0)
+ else:
+ as_str = '{} - {}'.format(num0, num1)
+ string_intervals.append(as_str)
+ return string_intervals
+
+
+def _calculations(df, fips, values, index, f, simplify_county, level,
+ x_centroids, y_centroids, centroid_text, x_traces,
+ y_traces, fips_polygon_map):
+ # 0-pad FIPS code to ensure exactly 5 digits
+ padded_f = str(f).zfill(5)
+ if fips_polygon_map[f].type == 'Polygon':
+ x = fips_polygon_map[f].simplify(
+ simplify_county
+ ).exterior.xy[0].tolist()
+ y = fips_polygon_map[f].simplify(
+ simplify_county
+ ).exterior.xy[1].tolist()
+
+ x_c, y_c = fips_polygon_map[f].centroid.xy
+ county_name_str = str(df[df['FIPS'] == f]['COUNTY_NAME'].iloc[0])
+ state_name_str = str(df[df['FIPS'] == f]['STATE_NAME'].iloc[0])
+
+ t_c = (
+ 'County: ' + county_name_str + '
' +
+ 'State: ' + state_name_str + '
' +
+ 'FIPS: ' + padded_f + '
Value: ' + str(values[index])
+ )
+
+ x_centroids.append(x_c[0])
+ y_centroids.append(y_c[0])
+ centroid_text.append(t_c)
+
+ x_traces[level] = x_traces[level] + x + [np.nan]
+ y_traces[level] = y_traces[level] + y + [np.nan]
+ elif fips_polygon_map[f].type == 'MultiPolygon':
+ x = ([poly.simplify(simplify_county).exterior.xy[0].tolist() for
+ poly in fips_polygon_map[f]])
+ y = ([poly.simplify(simplify_county).exterior.xy[1].tolist() for
+ poly in fips_polygon_map[f]])
+
+ x_c = [poly.centroid.xy[0].tolist() for poly in fips_polygon_map[f]]
+ y_c = [poly.centroid.xy[1].tolist() for poly in fips_polygon_map[f]]
+
+ county_name_str = str(df[df['FIPS'] == f]['COUNTY_NAME'].iloc[0])
+ state_name_str = str(df[df['FIPS'] == f]['STATE_NAME'].iloc[0])
+ text = (
+ 'County: ' + county_name_str + '
' +
+ 'State: ' + state_name_str + '
' +
+ 'FIPS: ' + padded_f + '
Value: ' + str(values[index])
+ )
+ t_c = [text for poly in fips_polygon_map[f]]
+ x_centroids = x_c + x_centroids
+ y_centroids = y_c + y_centroids
+ centroid_text = t_c + centroid_text
+ for x_y_idx in range(len(x)):
+ x_traces[level] = x_traces[level] + x[x_y_idx] + [np.nan]
+ y_traces[level] = y_traces[level] + y[x_y_idx] + [np.nan]
+
+ return x_traces, y_traces, x_centroids, y_centroids, centroid_text
+
+
+def create_choropleth(fips, values, scope=['usa'], binning_endpoints=None,
+ colorscale=None, order=None, simplify_county=0.02,
+ simplify_state=0.02, asp=None, show_hover=True,
+ show_state_data=True, state_outline=None,
+ county_outline=None, centroid_marker=None,
+ round_legend_values=False, exponent_format=False,
+ legend_title='', **layout_options):
+ """
+ Returns figure for county choropleth. Uses data from package_data.
+
+ :param (list) fips: list of FIPS values which correspond to the con
+ catination of state and county ids. An example is '01001'.
+ :param (list) values: list of numbers/strings which correspond to the
+ fips list. These are the values that will determine how the counties
+ are colored.
+ :param (list) scope: list of states and/or states abbreviations. Fits
+ all states in the camera tightly. Selecting ['usa'] is the equivalent
+ of appending all 50 states into your scope list. Selecting only 'usa'
+ does not include 'Alaska', 'Puerto Rico', 'American Samoa',
+ 'Commonwealth of the Northern Mariana Islands', 'Guam',
+ 'United States Virgin Islands'. These must be added manually to the
+ list.
+ Default = ['usa']
+ :param (list) binning_endpoints: ascending numbers which implicitly define
+ real number intervals which are used as bins. The colorscale used must
+ have the same number of colors as the number of bins and this will
+ result in a categorical colormap.
+ :param (list) colorscale: a list of colors with length equal to the
+ number of categories of colors. The length must match either all
+ unique numbers in the 'values' list or if endpoints is being used, the
+ number of categories created by the endpoints.\n
+ For example, if binning_endpoints = [4, 6, 8], then there are 4 bins:
+ [-inf, 4), [4, 6), [6, 8), [8, inf)
+ :param (list) order: a list of the unique categories (numbers/bins) in any
+ desired order. This is helpful if you want to order string values to
+ a chosen colorscale.
+ :param (float) simplify_county: determines the simplification factor
+ for the counties. The larger the number, the fewer vertices and edges
+ each polygon has. See
+ http://toblerity.org/shapely/manual.html#object.simplify for more
+ information.
+ Default = 0.02
+ :param (float) simplify_state: simplifies the state outline polygon.
+ See http://toblerity.org/shapely/manual.html#object.simplify for more
+ information.
+ Default = 0.02
+ :param (float) asp: the width-to-height aspect ratio for the camera.
+ Default = 2.5
+ :param (bool) show_hover: show county hover and centroid info
+ :param (bool) show_state_data: reveals state boundary lines
+ :param (dict) state_outline: dict of attributes of the state outline
+ including width and color. See
+ https://plot.ly/python/reference/#scatter-marker-line for all valid
+ params
+ :param (dict) county_outline: dict of attributes of the county outline
+ including width and color. See
+ https://plot.ly/python/reference/#scatter-marker-line for all valid
+ params
+ :param (dict) centroid_marker: dict of attributes of the centroid marker.
+ The centroid markers are invisible by default and appear visible on
+ selection. See https://plot.ly/python/reference/#scatter-marker for
+ all valid params
+ :param (bool) round_legend_values: automatically round the numbers that
+ appear in the legend to the nearest integer.
+ Default = False
+ :param (bool) exponent_format: if set to True, puts numbers in the K, M,
+ B number format. For example 4000.0 becomes 4.0K
+ Default = False
+ :param (str) legend_title: title that appears above the legend
+ :param **layout_options: a **kwargs argument for all layout parameters
+
+
+ Example 1: Florida
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import numpy as np
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'] == 'Florida']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+
+ binning_endpoints = list(np.mgrid[min(values):max(values):4j])
+ colorscale = ["#030512","#1d1d3b","#323268","#3d4b94","#3e6ab0",
+ "#4989bc","#60a7c7","#85c5d3","#b7e0e4","#eafcfd"]
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=['Florida'], show_state_data=True,
+ colorscale=colorscale, binning_endpoints=binning_endpoints,
+ round_legend_values=True, plot_bgcolor='rgb(229,229,229)',
+ paper_bgcolor='rgb(229,229,229)', legend_title='Florida Population',
+ county_outline={'color': 'rgb(255,255,255)', 'width': 0.5},
+ exponent_format=True,
+ )
+ py.iplot(fig, filename='choropleth_florida')
+ ```
+
+ Example 2: New England
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ NE_states = ['Connecticut', 'Maine', 'Massachusetts',
+ 'New Hampshire', 'Rhode Island']
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'].isin(NE_states)]
+ colorscale = ['rgb(68.0, 1.0, 84.0)',
+ 'rgb(66.0, 64.0, 134.0)',
+ 'rgb(38.0, 130.0, 142.0)',
+ 'rgb(63.0, 188.0, 115.0)',
+ 'rgb(216.0, 226.0, 25.0)']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=NE_states, show_state_data=True
+ )
+ py.iplot(fig, filename='choropleth_new_england')
+ ```
+
+ Example 3: California and Surrounding States
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/minoritymajority.csv'
+ )
+ df_sample_r = df_sample[df_sample['STNAME'] == 'California']
+
+ values = df_sample_r['TOT_POP'].tolist()
+ fips = df_sample_r['FIPS'].tolist()
+
+ colorscale = [
+ 'rgb(193, 193, 193)',
+ 'rgb(239,239,239)',
+ 'rgb(195, 196, 222)',
+ 'rgb(144,148,194)',
+ 'rgb(101,104,168)',
+ 'rgb(65, 53, 132)'
+ ]
+
+ fig = ff.create_choropleth(
+ fips=fips, values=values, colorscale=colorscale,
+ scope=['CA', 'AZ', 'Nevada', 'Oregon', ' Idaho'],
+ binning_endpoints=[14348, 63983, 134827, 426762, 2081313],
+ county_outline={'color': 'rgb(255,255,255)', 'width': 0.5},
+ legend_title='California Counties',
+ title='California and Nearby States'
+ )
+ py.iplot(fig, filename='choropleth_california_and_surr_states_outlines')
+ ```
+
+ Example 4: USA
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import numpy as np
+ import pandas as pd
+
+ df_sample = pd.read_csv(
+ 'https://raw.githubusercontent.com/plotly/datasets/master/laucnty16.csv'
+ )
+ df_sample['State FIPS Code'] = df_sample['State FIPS Code'].apply(
+ lambda x: str(x).zfill(2)
+ )
+ df_sample['County FIPS Code'] = df_sample['County FIPS Code'].apply(
+ lambda x: str(x).zfill(3)
+ )
+ df_sample['FIPS'] = (
+ df_sample['State FIPS Code'] + df_sample['County FIPS Code']
+ )
+
+ binning_endpoints = list(np.linspace(1, 12, len(colorscale) - 1))
+ colorscale = ["#f7fbff", "#ebf3fb", "#deebf7", "#d2e3f3", "#c6dbef",
+ "#b3d2e9", "#9ecae1", "#85bcdb", "#6baed6", "#57a0ce",
+ "#4292c6", "#3082be", "#2171b5", "#1361a9", "#08519c",
+ "#0b4083","#08306b"]
+ fips = df_sample['FIPS']
+ values = df_sample['Unemployment Rate (%)']
+ fig = ff.create_choropleth(
+ fips=fips, values=values, scope=['usa'],
+ binning_endpoints=binning_endpoints, colorscale=colorscale,
+ show_hover=True, centroid_marker={'opacity': 0},
+ asp=2.9, title='USA by Unemployment %',
+ legend_title='Unemployment %'
+ )
+
+ py.iplot(fig, filename='choropleth_full_usa')
+ ```
+ """
+ # ensure optional modules imported
+ if not gp or not shapefile or not shapely:
+ raise ImportError(
+ "geopandas, pyshp and shapely must be installed for this figure "
+ "factory.\n\nRun the following commands to install the correct "
+ "versions of the following modules:\n\n"
+ "```\n"
+ "pip install geopandas==0.3.0\n"
+ "pip install pyshp==1.2.10\n"
+ "pip install shapely==1.6.3\n"
+ "```\n"
+ "If you are using Windows, follow this post to properly "
+ "install geopandas and dependencies:"
+ "http://geoffboeing.com/2014/09/using-geopandas-windows/\n\n"
+ "If you are using Anaconda, do not use PIP to install the "
+ "packages above. Instead use conda to install them:\n\n"
+ "```\n"
+ "conda install plotly\n"
+ "conda install geopandas\n"
+ "```"
+ )
+
+ df, df_state = _create_us_counties_df(st_to_state_name_dict,
+ state_to_st_dict)
+
+ fips_polygon_map = dict(
+ zip(
+ df['FIPS'].tolist(),
+ df['geometry'].tolist()
+ )
+ )
+
+ if not state_outline:
+ state_outline = {'color': 'rgb(240, 240, 240)',
+ 'width': 1}
+ if not county_outline:
+ county_outline = {'color': 'rgb(0, 0, 0)',
+ 'width': 0}
+ if not centroid_marker:
+ centroid_marker = {'size': 3, 'color': 'white', 'opacity': 1}
+
+ # ensure centroid markers appear on selection
+ if 'opacity' not in centroid_marker:
+ centroid_marker.update({'opacity': 1})
+
+ if len(fips) != len(values):
+ raise exceptions.PlotlyError(
+ 'fips and values must be the same length'
+ )
+
+ # make fips, values into lists
+ if isinstance(fips, pd.core.series.Series):
+ fips = fips.tolist()
+ if isinstance(values, pd.core.series.Series):
+ values = values.tolist()
+
+ # make fips numeric
+ fips = map(lambda x: int(x), fips)
+
+ if binning_endpoints:
+ intervals = utils.endpts_to_intervals(binning_endpoints)
+ LEVELS = _intervals_as_labels(intervals, round_legend_values,
+ exponent_format)
+ else:
+ if not order:
+ LEVELS = sorted(list(set(values)))
+ else:
+ # check if order is permutation
+ # of unique color col values
+ same_sets = sorted(list(set(values))) == set(order)
+ no_duplicates = not any(order.count(x) > 1 for x in order)
+ if same_sets and no_duplicates:
+ LEVELS = order
+ else:
+ raise exceptions.PlotlyError(
+ 'if you are using a custom order of unique values from '
+ 'your color column, you must: have all the unique values '
+ 'in your order and have no duplicate items'
+ )
+
+ if not colorscale:
+ colorscale = []
+ viridis_colors = utils.colorscale_to_colors(
+ utils.PLOTLY_SCALES['Viridis']
+ )
+ viridis_colors = utils.color_parser(
+ viridis_colors, utils.hex_to_rgb
+ )
+ viridis_colors = utils.color_parser(
+ viridis_colors, utils.label_rgb
+ )
+ viri_len = len(viridis_colors) + 1
+ viri_intervals = utils.endpts_to_intervals(
+ list(np.linspace(0, 1, viri_len))
+ )[1:-1]
+
+ for L in np.linspace(0, 1, len(LEVELS)):
+ for idx, inter in enumerate(viri_intervals):
+ if L == 0:
+ break
+ elif inter[0] < L <= inter[1]:
+ break
+
+ intermed = ((L - viri_intervals[idx][0]) /
+ (viri_intervals[idx][1] - viri_intervals[idx][0]))
+
+ float_color = utils.find_intermediate_color(
+ viridis_colors[idx],
+ viridis_colors[idx],
+ intermed,
+ colortype='rgb'
+ )
+
+ # make R,G,B into int values
+ float_color = utils.unlabel_rgb(float_color)
+ float_color = utils.unconvert_from_RGB_255(float_color)
+ int_rgb = utils.convert_to_RGB_255(float_color)
+ int_rgb = utils.label_rgb(int_rgb)
+
+ colorscale.append(int_rgb)
+
+ if len(colorscale) < len(LEVELS):
+ raise exceptions.PlotlyError(
+ "You have {} LEVELS. Your number of colors in 'colorscale' must "
+ "be at least the number of LEVELS: {}. If you are "
+ "using 'binning_endpoints' then 'colorscale' must have at "
+ "least len(binning_endpoints) + 2 colors".format(
+ len(LEVELS), min(LEVELS, LEVELS[:20])
+ )
+ )
+
+ color_lookup = dict(zip(LEVELS, colorscale))
+ x_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))]))
+ y_traces = dict(zip(LEVELS, [[] for i in range(len(LEVELS))]))
+
+ # scope
+ if isinstance(scope, str):
+ raise exceptions.PlotlyError(
+ "'scope' must be a list/tuple/sequence"
+ )
+
+ scope_names = []
+ extra_states = ['Alaska', 'Commonwealth of the Northern Mariana Islands',
+ 'Puerto Rico', 'Guam', 'United States Virgin Islands',
+ 'American Samoa']
+ for state in scope:
+ if state.lower() == 'usa':
+ scope_names = df['STATE_NAME'].unique()
+ scope_names = list(scope_names)
+ for ex_st in extra_states:
+ try:
+ scope_names.remove(ex_st)
+ except ValueError:
+ pass
+ else:
+ if state in st_to_state_name_dict.keys():
+ state = st_to_state_name_dict[state]
+ scope_names.append(state)
+ df_state = df_state[df_state['STATE_NAME'].isin(scope_names)]
+
+ plot_data = []
+ x_centroids = []
+ y_centroids = []
+ centroid_text = []
+ fips_not_in_shapefile = []
+ if not binning_endpoints:
+ for index, f in enumerate(fips):
+ level = values[index]
+ try:
+ fips_polygon_map[f].type
+
+ (x_traces, y_traces, x_centroids,
+ y_centroids, centroid_text) = _calculations(
+ df, fips, values, index, f, simplify_county, level,
+ x_centroids, y_centroids, centroid_text, x_traces,
+ y_traces, fips_polygon_map
+ )
+ except KeyError:
+ fips_not_in_shapefile.append(f)
+
+ else:
+ for index, f in enumerate(fips):
+ for j, inter in enumerate(intervals):
+ if inter[0] < values[index] <= inter[1]:
+ break
+ level = LEVELS[j]
+
+ try:
+ fips_polygon_map[f].type
+
+ (x_traces, y_traces, x_centroids,
+ y_centroids, centroid_text) = _calculations(
+ df, fips, values, index, f, simplify_county, level,
+ x_centroids, y_centroids, centroid_text, x_traces,
+ y_traces, fips_polygon_map
+ )
+ except KeyError:
+ fips_not_in_shapefile.append(f)
+
+ if len(fips_not_in_shapefile) > 0:
+ msg = (
+ 'Unrecognized FIPS Values\n\nWhoops! It looks like you are '
+ 'trying to pass at least one FIPS value that is not in '
+ 'our shapefile of FIPS and data for the counties. Your '
+ 'choropleth will still show up but these counties cannot '
+ 'be shown.\nUnrecognized FIPS are: {}'.format(
+ fips_not_in_shapefile
+ )
+ )
+ warnings.warn(msg)
+
+ x_states = []
+ y_states = []
+ for index, row in df_state.iterrows():
+ if df_state['geometry'][index].type == 'Polygon':
+ x = row.geometry.simplify(simplify_state).exterior.xy[0].tolist()
+ y = row.geometry.simplify(simplify_state).exterior.xy[1].tolist()
+ x_states = x_states + x
+ y_states = y_states + y
+ elif df_state['geometry'][index].type == 'MultiPolygon':
+ x = ([poly.simplify(simplify_state).exterior.xy[0].tolist() for
+ poly in df_state['geometry'][index]])
+ y = ([poly.simplify(simplify_state).exterior.xy[1].tolist() for
+ poly in df_state['geometry'][index]])
+ for segment in range(len(x)):
+ x_states = x_states + x[segment]
+ y_states = y_states + y[segment]
+ x_states.append(np.nan)
+ y_states.append(np.nan)
+ x_states.append(np.nan)
+ y_states.append(np.nan)
+
+ for lev in LEVELS:
+ county_data = dict(
+ type='scatter',
+ mode='lines',
+ x=x_traces[lev],
+ y=y_traces[lev],
+ line=county_outline,
+ fill='toself',
+ fillcolor=color_lookup[lev],
+ name=lev,
+ hoverinfo='none',
+ )
+ plot_data.append(county_data)
+
+ if show_hover:
+ hover_points = dict(
+ type='scatter',
+ showlegend=False,
+ legendgroup='centroids',
+ x=x_centroids,
+ y=y_centroids,
+ text=centroid_text,
+ name='US Counties',
+ mode='markers',
+ marker={'color': 'white', 'opacity': 0},
+ hoverinfo='text'
+ )
+ centroids_on_select = dict(
+ selected=dict(marker=centroid_marker),
+ unselected=dict(marker=dict(opacity=0))
+ )
+ hover_points.update(centroids_on_select)
+ plot_data.append(hover_points)
+
+ if show_state_data:
+ state_data = dict(
+ type='scatter',
+ legendgroup='States',
+ line=state_outline,
+ x=x_states,
+ y=y_states,
+ hoverinfo='text',
+ showlegend=False,
+ mode='lines'
+ )
+ plot_data.append(state_data)
+
+ DEFAULT_LAYOUT = dict(
+ hovermode='closest',
+ xaxis=dict(
+ autorange=False,
+ range=USA_XRANGE,
+ showgrid=False,
+ zeroline=False,
+ fixedrange=True,
+ showticklabels=False
+ ),
+ yaxis=dict(
+ autorange=False,
+ range=USA_YRANGE,
+ showgrid=False,
+ zeroline=False,
+ fixedrange=True,
+ showticklabels=False
+ ),
+ margin=dict(t=40, b=20, r=20, l=20),
+ width=900,
+ height=450,
+ dragmode='select',
+ legend=dict(
+ traceorder='reversed',
+ xanchor='right',
+ yanchor='top',
+ x=1,
+ y=1
+ ),
+ annotations=[]
+ )
+ fig = dict(data=plot_data, layout=DEFAULT_LAYOUT)
+ fig['layout'].update(layout_options)
+ fig['layout']['annotations'].append(
+ dict(
+ x=1,
+ y=1.05,
+ xref='paper',
+ yref='paper',
+ xanchor='right',
+ showarrow=False,
+ text='' + legend_title + ''
+ )
+ )
+
+ if len(scope) == 1 and scope[0].lower() == 'usa':
+ xaxis_range_low = -125.0
+ xaxis_range_high = -55.0
+ yaxis_range_low = 25.0
+ yaxis_range_high = 49.0
+ else:
+ xaxis_range_low = float('inf')
+ xaxis_range_high = float('-inf')
+ yaxis_range_low = float('inf')
+ yaxis_range_high = float('-inf')
+ for trace in fig['data']:
+ if all(isinstance(n, Number) for n in trace['x']):
+ calc_x_min = min(trace['x'] or [float('inf')])
+ calc_x_max = max(trace['x'] or [float('-inf')])
+ if calc_x_min < xaxis_range_low:
+ xaxis_range_low = calc_x_min
+ if calc_x_max > xaxis_range_high:
+ xaxis_range_high = calc_x_max
+ if all(isinstance(n, Number) for n in trace['y']):
+ calc_y_min = min(trace['y'] or [float('inf')])
+ calc_y_max = max(trace['y'] or [float('-inf')])
+ if calc_y_min < yaxis_range_low:
+ yaxis_range_low = calc_y_min
+ if calc_y_max > yaxis_range_high:
+ yaxis_range_high = calc_y_max
+
+ # camera zoom
+ fig['layout']['xaxis']['range'] = [xaxis_range_low, xaxis_range_high]
+ fig['layout']['yaxis']['range'] = [yaxis_range_low, yaxis_range_high]
+
+ # aspect ratio
+ if asp is None:
+ usa_x_range = USA_XRANGE[1] - USA_XRANGE[0]
+ usa_y_range = USA_YRANGE[1] - USA_YRANGE[0]
+ asp = usa_x_range / usa_y_range
+
+ # based on your figure
+ width = float(fig['layout']['xaxis']['range'][1] -
+ fig['layout']['xaxis']['range'][0])
+ height = float(fig['layout']['yaxis']['range'][1] -
+ fig['layout']['yaxis']['range'][0])
+
+ center = (sum(fig['layout']['xaxis']['range']) / 2.,
+ sum(fig['layout']['yaxis']['range']) / 2.)
+
+ if height / width > (1 / asp):
+ new_width = asp * height
+ fig['layout']['xaxis']['range'][0] = center[0] - new_width * 0.5
+ fig['layout']['xaxis']['range'][1] = center[0] + new_width * 0.5
+ else:
+ new_height = (1 / asp) * width
+ fig['layout']['yaxis']['range'][0] = center[1] - new_height * 0.5
+ fig['layout']['yaxis']['range'][1] = center[1] + new_height * 0.5
+
+ return fig
diff --git a/plotly/figure_factory/figure_factory/_dendrogram.py b/plotly/figure_factory/figure_factory/_dendrogram.py
new file mode 100644
index 00000000000..4bafc976c1d
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_dendrogram.py
@@ -0,0 +1,330 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import absolute_import
+
+from collections import OrderedDict
+
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+scp = optional_imports.get_module('scipy')
+sch = optional_imports.get_module('scipy.cluster.hierarchy')
+scs = optional_imports.get_module('scipy.spatial')
+
+
+def create_dendrogram(X, orientation="bottom", labels=None,
+ colorscale=None, distfun=None,
+ linkagefun=lambda x: sch.linkage(x, 'complete'),
+ hovertext=None, color_threshold=None):
+ """
+ BETA function that returns a dendrogram Plotly figure object.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (str) orientation: 'top', 'right', 'bottom', or 'left'
+ :param (list) labels: List of axis category labels(observation labels)
+ :param (list) colorscale: Optional colorscale for dendrogram tree
+ :param (function) distfun: Function to compute the pairwise distance from
+ the observations
+ :param (function) linkagefun: Function to compute the linkage matrix from
+ the pairwise distances
+ :param (list[list]) hovertext: List of hovertext for constituent traces of dendrogram
+ clusters
+ :param (double) color_threshold: Value at which the separation of clusters will be made
+
+ Example 1: Simple bottom oriented dendrogram
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+
+ X = np.random.rand(10,10)
+ dendro = create_dendrogram(X)
+ plot_url = py.plot(dendro, filename='simple-dendrogram')
+
+ ```
+
+ Example 2: Dendrogram to put on the left of the heatmap
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+
+ X = np.random.rand(5,5)
+ names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
+ dendro = create_dendrogram(X, orientation='right', labels=names)
+ dendro['layout'].update({'width':700, 'height':500})
+
+ py.iplot(dendro, filename='vertical-dendrogram')
+ ```
+
+ Example 3: Dendrogram with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_dendrogram
+
+ import numpy as np
+ import pandas as pd
+
+ Index= ['A','B','C','D','E','F','G','H','I','J']
+ df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
+ fig = create_dendrogram(df, labels=Index)
+ url = py.plot(fig, filename='pandas-dendrogram')
+ ```
+ """
+ if not scp or not scs or not sch:
+ raise ImportError("FigureFactory.create_dendrogram requires scipy, \
+ scipy.spatial and scipy.hierarchy")
+
+ s = X.shape
+ if len(s) != 2:
+ exceptions.PlotlyError("X should be 2-dimensional array.")
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ dendrogram = _Dendrogram(X, orientation, labels, colorscale,
+ distfun=distfun, linkagefun=linkagefun,
+ hovertext=hovertext, color_threshold=color_threshold)
+
+ return graph_objs.Figure(data=dendrogram.data,
+ layout=dendrogram.layout)
+
+
+class _Dendrogram(object):
+ """Refer to FigureFactory.create_dendrogram() for docstring."""
+
+ def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
+ width=np.inf, height=np.inf, xaxis='xaxis', yaxis='yaxis',
+ distfun=None,
+ linkagefun=lambda x: sch.linkage(x, 'complete'),
+ hovertext=None, color_threshold=None):
+ self.orientation = orientation
+ self.labels = labels
+ self.xaxis = xaxis
+ self.yaxis = yaxis
+ self.data = []
+ self.leaves = []
+ self.sign = {self.xaxis: 1, self.yaxis: 1}
+ self.layout = {self.xaxis: {}, self.yaxis: {}}
+
+ if self.orientation in ['left', 'bottom']:
+ self.sign[self.xaxis] = 1
+ else:
+ self.sign[self.xaxis] = -1
+
+ if self.orientation in ['right', 'bottom']:
+ self.sign[self.yaxis] = 1
+ else:
+ self.sign[self.yaxis] = -1
+
+ if distfun is None:
+ distfun = scs.distance.pdist
+
+ (dd_traces, xvals, yvals,
+ ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale,
+ distfun,
+ linkagefun,
+ hovertext,
+ color_threshold)
+
+ self.labels = ordered_labels
+ self.leaves = leaves
+ yvals_flat = yvals.flatten()
+ xvals_flat = xvals.flatten()
+
+ self.zero_vals = []
+
+ for i in range(len(yvals_flat)):
+ if yvals_flat[i] == 0.0 and xvals_flat[i] not in self.zero_vals:
+ self.zero_vals.append(xvals_flat[i])
+
+ if len(self.zero_vals) > len(yvals) + 1:
+ # If the length of zero_vals is larger than the length of yvals,
+ # it means that there are wrong vals because of the identicial samples.
+ # Three and more identicial samples will make the yvals of spliting center into 0 and it will \
+ # accidentally take it as leaves.
+ l_border = int(min(self.zero_vals))
+ r_border = int(max(self.zero_vals))
+ correct_leaves_pos = range(l_border,
+ r_border + 1,
+ int((r_border - l_border) / len(yvals)))
+ # Regenerating the leaves pos from the self.zero_vals with equally intervals.
+ self.zero_vals = [v for v in correct_leaves_pos]
+
+ self.zero_vals.sort()
+ self.layout = self.set_figure_layout(width, height)
+ self.data = dd_traces
+
+ def get_color_dict(self, colorscale):
+ """
+ Returns colorscale used for dendrogram tree clusters.
+
+ :param (list) colorscale: Colors to use for the plot in rgb format.
+ :rtype (dict): A dict of default colors mapped to the user colorscale.
+
+ """
+
+ # These are the color codes returned for dendrograms
+ # We're replacing them with nicer colors
+ d = {'r': 'red',
+ 'g': 'green',
+ 'b': 'blue',
+ 'c': 'cyan',
+ 'm': 'magenta',
+ 'y': 'yellow',
+ 'k': 'black',
+ 'w': 'white'}
+ default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))
+
+ if colorscale is None:
+ colorscale = [
+ 'rgb(0,116,217)', # blue
+ 'rgb(35,205,205)', # cyan
+ 'rgb(61,153,112)', # green
+ 'rgb(40,35,35)', # black
+ 'rgb(133,20,75)', # magenta
+ 'rgb(255,65,54)', # red
+ 'rgb(255,255,255)', # white
+ 'rgb(255,220,0)'] # yellow
+
+ for i in range(len(default_colors.keys())):
+ k = list(default_colors.keys())[i] # PY3 won't index keys
+ if i < len(colorscale):
+ default_colors[k] = colorscale[i]
+
+ return default_colors
+
+ def set_axis_layout(self, axis_key):
+ """
+ Sets and returns default axis object for dendrogram figure.
+
+ :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
+ :rtype (dict): An axis_key dictionary with set parameters.
+
+ """
+ axis_defaults = {
+ 'type': 'linear',
+ 'ticks': 'outside',
+ 'mirror': 'allticks',
+ 'rangemode': 'tozero',
+ 'showticklabels': True,
+ 'zeroline': False,
+ 'showgrid': False,
+ 'showline': True,
+ }
+
+ if len(self.labels) != 0:
+ axis_key_labels = self.xaxis
+ if self.orientation in ['left', 'right']:
+ axis_key_labels = self.yaxis
+ if axis_key_labels not in self.layout:
+ self.layout[axis_key_labels] = {}
+ self.layout[axis_key_labels]['tickvals'] = \
+ [zv*self.sign[axis_key] for zv in self.zero_vals]
+ self.layout[axis_key_labels]['ticktext'] = self.labels
+ self.layout[axis_key_labels]['tickmode'] = 'array'
+
+ self.layout[axis_key].update(axis_defaults)
+
+ return self.layout[axis_key]
+
+ def set_figure_layout(self, width, height):
+ """
+ Sets and returns default layout object for dendrogram figure.
+
+ """
+ self.layout.update({
+ 'showlegend': False,
+ 'autosize': False,
+ 'hovermode': 'closest',
+ 'width': width,
+ 'height': height
+ })
+
+ self.set_axis_layout(self.xaxis)
+ self.set_axis_layout(self.yaxis)
+
+ return self.layout
+
+ def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun, hovertext, color_threshold):
+ """
+ Calculates all the elements needed for plotting a dendrogram.
+
+ :param (ndarray) X: Matrix of observations as array of arrays
+ :param (list) colorscale: Color scale for dendrogram tree clusters
+ :param (function) distfun: Function to compute the pairwise distance
+ from the observations
+ :param (function) linkagefun: Function to compute the linkage matrix
+ from the pairwise distances
+ :param (list) hovertext: List of hovertext for constituent traces of dendrogram
+ :rtype (tuple): Contains all the traces in the following order:
+ (a) trace_list: List of Plotly trace objects for dendrogram tree
+ (b) icoord: All X points of the dendrogram tree as array of arrays
+ with length 4
+ (c) dcoord: All Y points of the dendrogram tree as array of arrays
+ with length 4
+ (d) ordered_labels: leaf labels in the order they are going to
+ appear on the plot
+ (e) P['leaves']: left-to-right traversal of the leaves
+
+ """
+ d = distfun(X)
+ Z = linkagefun(d)
+ P = sch.dendrogram(Z, orientation=self.orientation,
+ labels=self.labels, no_plot=True,
+ color_threshold=color_threshold)
+
+ icoord = scp.array(P['icoord'])
+ dcoord = scp.array(P['dcoord'])
+ ordered_labels = scp.array(P['ivl'])
+ color_list = scp.array(P['color_list'])
+ colors = self.get_color_dict(colorscale)
+
+ trace_list = []
+
+ for i in range(len(icoord)):
+ # xs and ys are arrays of 4 points that make up the '∩' shapes
+ # of the dendrogram tree
+ if self.orientation in ['top', 'bottom']:
+ xs = icoord[i]
+ else:
+ xs = dcoord[i]
+
+ if self.orientation in ['top', 'bottom']:
+ ys = dcoord[i]
+ else:
+ ys = icoord[i]
+ color_key = color_list[i]
+ hovertext_label = None
+ if hovertext:
+ hovertext_label = hovertext[i]
+ trace = dict(
+ type='scatter',
+ x=np.multiply(self.sign[self.xaxis], xs),
+ y=np.multiply(self.sign[self.yaxis], ys),
+ mode='lines',
+ marker=dict(color=colors[color_key]),
+ text=hovertext_label,
+ hoverinfo='text'
+ )
+
+ try:
+ x_index = int(self.xaxis[-1])
+ except ValueError:
+ x_index = ''
+
+ try:
+ y_index = int(self.yaxis[-1])
+ except ValueError:
+ y_index = ''
+
+ trace['xaxis'] = 'x' + x_index
+ trace['yaxis'] = 'y' + y_index
+
+ trace_list.append(trace)
+
+ return trace_list, icoord, dcoord, ordered_labels, P['leaves']
diff --git a/plotly/figure_factory/figure_factory/_distplot.py b/plotly/figure_factory/figure_factory/_distplot.py
new file mode 100644
index 00000000000..0d88847eeb4
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_distplot.py
@@ -0,0 +1,390 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+# Optional imports, may be None for users that only use our core functionality.
+np = optional_imports.get_module('numpy')
+pd = optional_imports.get_module('pandas')
+scipy = optional_imports.get_module('scipy')
+scipy_stats = optional_imports.get_module('scipy.stats')
+
+
+DEFAULT_HISTNORM = 'probability density'
+ALTERNATIVE_HISTNORM = 'probability'
+
+
+def validate_distplot(hist_data, curve_type):
+ """
+ Distplot-specific validations
+
+ :raises: (PlotlyError) If hist_data is not a list of lists
+ :raises: (PlotlyError) If curve_type is not valid (i.e. not 'kde' or
+ 'normal').
+ """
+ hist_data_types = (list,)
+ if np:
+ hist_data_types += (np.ndarray,)
+ if pd:
+ hist_data_types += (pd.core.series.Series,)
+
+ if not isinstance(hist_data[0], hist_data_types):
+ raise exceptions.PlotlyError("Oops, this function was written "
+ "to handle multiple datasets, if "
+ "you want to plot just one, make "
+ "sure your hist_data variable is "
+ "still a list of lists, i.e. x = "
+ "[1, 2, 3] -> x = [[1, 2, 3]]")
+
+ curve_opts = ('kde', 'normal')
+ if curve_type not in curve_opts:
+ raise exceptions.PlotlyError("curve_type must be defined as "
+ "'kde' or 'normal'")
+
+ if not scipy:
+ raise ImportError("FigureFactory.create_distplot requires scipy")
+
+
+def create_distplot(hist_data, group_labels, bin_size=1., curve_type='kde',
+ colors=None, rug_text=None, histnorm=DEFAULT_HISTNORM,
+ show_hist=True, show_curve=True, show_rug=True):
+ """
+ BETA function that creates a distplot similar to seaborn.distplot
+
+ The distplot can be composed of all or any combination of the following
+ 3 components: (1) histogram, (2) curve: (a) kernel density estimation
+ or (b) normal curve, and (3) rug plot. Additionally, multiple distplots
+ (from multiple datasets) can be created in the same plot.
+
+ :param (list[list]) hist_data: Use list of lists to plot multiple data
+ sets on the same plot.
+ :param (list[str]) group_labels: Names for each data set.
+ :param (list[float]|float) bin_size: Size of histogram bins.
+ Default = 1.
+ :param (str) curve_type: 'kde' or 'normal'. Default = 'kde'
+ :param (str) histnorm: 'probability density' or 'probability'
+ Default = 'probability density'
+ :param (bool) show_hist: Add histogram to distplot? Default = True
+ :param (bool) show_curve: Add curve to distplot? Default = True
+ :param (bool) show_rug: Add rug to distplot? Default = True
+ :param (list[str]) colors: Colors for traces.
+ :param (list[list]) rug_text: Hovertext values for rug_plot,
+ :return (dict): Representation of a distplot figure.
+
+ Example 1: Simple distplot of 1 data set
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+
+ hist_data = [[1.1, 1.1, 2.5, 3.0, 3.5,
+ 3.5, 4.1, 4.4, 4.5, 4.5,
+ 5.0, 5.0, 5.2, 5.5, 5.5,
+ 5.5, 5.5, 5.5, 6.1, 7.0]]
+
+ group_labels = ['distplot example']
+
+ fig = create_distplot(hist_data, group_labels)
+
+ url = py.plot(fig, filename='Simple distplot', validate=False)
+ ```
+
+ Example 2: Two data sets and added rug text
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+
+ # Add histogram data
+ hist1_x = [0.8, 1.2, 0.2, 0.6, 1.6,
+ -0.9, -0.07, 1.95, 0.9, -0.2,
+ -0.5, 0.3, 0.4, -0.37, 0.6]
+ hist2_x = [0.8, 1.5, 1.5, 0.6, 0.59,
+ 1.0, 0.8, 1.7, 0.5, 0.8,
+ -0.3, 1.2, 0.56, 0.3, 2.2]
+
+ # Group data together
+ hist_data = [hist1_x, hist2_x]
+
+ group_labels = ['2012', '2013']
+
+ # Add text
+ rug_text_1 = ['a1', 'b1', 'c1', 'd1', 'e1',
+ 'f1', 'g1', 'h1', 'i1', 'j1',
+ 'k1', 'l1', 'm1', 'n1', 'o1']
+
+ rug_text_2 = ['a2', 'b2', 'c2', 'd2', 'e2',
+ 'f2', 'g2', 'h2', 'i2', 'j2',
+ 'k2', 'l2', 'm2', 'n2', 'o2']
+
+ # Group text together
+ rug_text_all = [rug_text_1, rug_text_2]
+
+ # Create distplot
+ fig = create_distplot(
+ hist_data, group_labels, rug_text=rug_text_all, bin_size=.2)
+
+ # Add title
+ fig['layout'].update(title='Dist Plot')
+
+ # Plot!
+ url = py.plot(fig, filename='Distplot with rug text', validate=False)
+ ```
+
+ Example 3: Plot with normal curve and hide rug plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+ import numpy as np
+
+ x1 = np.random.randn(190)
+ x2 = np.random.randn(200)+1
+ x3 = np.random.randn(200)-1
+ x4 = np.random.randn(210)+2
+
+ hist_data = [x1, x2, x3, x4]
+ group_labels = ['2012', '2013', '2014', '2015']
+
+ fig = create_distplot(
+ hist_data, group_labels, curve_type='normal',
+ show_rug=False, bin_size=.4)
+
+ url = py.plot(fig, filename='hist and normal curve', validate=False)
+
+ Example 4: Distplot with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_distplot
+ import numpy as np
+ import pandas as pd
+
+ df = pd.DataFrame({'2012': np.random.randn(200),
+ '2013': np.random.randn(200)+1})
+ py.iplot(create_distplot([df[c] for c in df.columns], df.columns),
+ filename='examples/distplot with pandas',
+ validate=False)
+ ```
+ """
+ if colors is None:
+ colors = []
+ if rug_text is None:
+ rug_text = []
+
+ validate_distplot(hist_data, curve_type)
+ utils.validate_equal_length(hist_data, group_labels)
+
+ if isinstance(bin_size, (float, int)):
+ bin_size = [bin_size] * len(hist_data)
+
+ hist = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_hist()
+
+ if curve_type == 'normal':
+ curve = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_normal()
+ else:
+ curve = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_kde()
+
+ rug = _Distplot(
+ hist_data, histnorm, group_labels, bin_size,
+ curve_type, colors, rug_text,
+ show_hist, show_curve).make_rug()
+
+ data = []
+ if show_hist:
+ data.append(hist)
+ if show_curve:
+ data.append(curve)
+ if show_rug:
+ data.append(rug)
+ layout = graph_objs.Layout(
+ barmode='overlay',
+ hovermode='closest',
+ legend=dict(traceorder='reversed'),
+ xaxis1=dict(domain=[0.0, 1.0],
+ anchor='y2',
+ zeroline=False),
+ yaxis1=dict(domain=[0.35, 1],
+ anchor='free',
+ position=0.0),
+ yaxis2=dict(domain=[0, 0.25],
+ anchor='x1',
+ dtick=1,
+ showticklabels=False))
+ else:
+ layout = graph_objs.Layout(
+ barmode='overlay',
+ hovermode='closest',
+ legend=dict(traceorder='reversed'),
+ xaxis1=dict(domain=[0.0, 1.0],
+ anchor='y2',
+ zeroline=False),
+ yaxis1=dict(domain=[0., 1],
+ anchor='free',
+ position=0.0))
+
+ data = sum(data, [])
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Distplot(object):
+ """
+ Refer to TraceFactory.create_distplot() for docstring
+ """
+ def __init__(self, hist_data, histnorm, group_labels,
+ bin_size, curve_type, colors,
+ rug_text, show_hist, show_curve):
+ self.hist_data = hist_data
+ self.histnorm = histnorm
+ self.group_labels = group_labels
+ self.bin_size = bin_size
+ self.show_hist = show_hist
+ self.show_curve = show_curve
+ self.trace_number = len(hist_data)
+ if rug_text:
+ self.rug_text = rug_text
+ else:
+ self.rug_text = [None] * self.trace_number
+
+ self.start = []
+ self.end = []
+ if colors:
+ self.colors = colors
+ else:
+ self.colors = [
+ "rgb(31, 119, 180)", "rgb(255, 127, 14)",
+ "rgb(44, 160, 44)", "rgb(214, 39, 40)",
+ "rgb(148, 103, 189)", "rgb(140, 86, 75)",
+ "rgb(227, 119, 194)", "rgb(127, 127, 127)",
+ "rgb(188, 189, 34)", "rgb(23, 190, 207)"]
+ self.curve_x = [None] * self.trace_number
+ self.curve_y = [None] * self.trace_number
+
+ for trace in self.hist_data:
+ self.start.append(min(trace) * 1.)
+ self.end.append(max(trace) * 1.)
+
+ def make_hist(self):
+ """
+ Makes the histogram(s) for FigureFactory.create_distplot().
+
+ :rtype (list) hist: list of histogram representations
+ """
+ hist = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ hist[index] = dict(type='histogram',
+ x=self.hist_data[index],
+ xaxis='x1',
+ yaxis='y1',
+ histnorm=self.histnorm,
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ marker=dict(color=self.colors[index % len(self.colors)]),
+ autobinx=False,
+ xbins=dict(start=self.start[index],
+ end=self.end[index],
+ size=self.bin_size[index]),
+ opacity=.7)
+ return hist
+
+ def make_kde(self):
+ """
+ Makes the kernel density estimation(s) for create_distplot().
+
+ This is called when curve_type = 'kde' in create_distplot().
+
+ :rtype (list) curve: list of kde representations
+ """
+ curve = [None] * self.trace_number
+ for index in range(self.trace_number):
+ self.curve_x[index] = [self.start[index] +
+ x * (self.end[index] - self.start[index])
+ / 500 for x in range(500)]
+ self.curve_y[index] = (scipy_stats.gaussian_kde
+ (self.hist_data[index])
+ (self.curve_x[index]))
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(type='scatter',
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis='x1',
+ yaxis='y1',
+ mode='lines',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index % len(self.colors)]))
+ return curve
+
+ def make_normal(self):
+ """
+ Makes the normal curve(s) for create_distplot().
+
+ This is called when curve_type = 'normal' in create_distplot().
+
+ :rtype (list) curve: list of normal curve representations
+ """
+ curve = [None] * self.trace_number
+ mean = [None] * self.trace_number
+ sd = [None] * self.trace_number
+
+ for index in range(self.trace_number):
+ mean[index], sd[index] = (scipy_stats.norm.fit
+ (self.hist_data[index]))
+ self.curve_x[index] = [self.start[index] +
+ x * (self.end[index] - self.start[index])
+ / 500 for x in range(500)]
+ self.curve_y[index] = scipy_stats.norm.pdf(
+ self.curve_x[index], loc=mean[index], scale=sd[index])
+
+ if self.histnorm == ALTERNATIVE_HISTNORM:
+ self.curve_y[index] *= self.bin_size[index]
+
+ for index in range(self.trace_number):
+ curve[index] = dict(type='scatter',
+ x=self.curve_x[index],
+ y=self.curve_y[index],
+ xaxis='x1',
+ yaxis='y1',
+ mode='lines',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=False if self.show_hist else True,
+ marker=dict(color=self.colors[index % len(self.colors)]))
+ return curve
+
+ def make_rug(self):
+ """
+ Makes the rug plot(s) for create_distplot().
+
+ :rtype (list) rug: list of rug plot representations
+ """
+ rug = [None] * self.trace_number
+ for index in range(self.trace_number):
+
+ rug[index] = dict(type='scatter',
+ x=self.hist_data[index],
+ y=([self.group_labels[index]] *
+ len(self.hist_data[index])),
+ xaxis='x1',
+ yaxis='y2',
+ mode='markers',
+ name=self.group_labels[index],
+ legendgroup=self.group_labels[index],
+ showlegend=(False if self.show_hist or
+ self.show_curve else True),
+ text=self.rug_text[index],
+ marker=dict(color=self.colors[index % len(self.colors)],
+ symbol='line-ns-open'))
+ return rug
diff --git a/plotly/figure_factory/figure_factory/_facet_grid.py b/plotly/figure_factory/figure_factory/_facet_grid.py
new file mode 100644
index 00000000000..8ef1825ebf8
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_facet_grid.py
@@ -0,0 +1,1111 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.tools import make_subplots
+
+import math
+from numbers import Number
+
+pd = optional_imports.get_module('pandas')
+
+TICK_COLOR = '#969696'
+AXIS_TITLE_COLOR = '#0f0f0f'
+AXIS_TITLE_SIZE = 12
+GRID_COLOR = '#ffffff'
+LEGEND_COLOR = '#efefef'
+PLOT_BGCOLOR = '#ededed'
+ANNOT_RECT_COLOR = '#d0d0d0'
+LEGEND_BORDER_WIDTH = 1
+LEGEND_ANNOT_X = 1.05
+LEGEND_ANNOT_Y = 0.5
+MAX_TICKS_PER_AXIS = 5
+THRES_FOR_FLIPPED_FACET_TITLES = 10
+GRID_WIDTH = 1
+
+VALID_TRACE_TYPES = ['scatter', 'scattergl', 'histogram', 'bar', 'box']
+
+CUSTOM_LABEL_ERROR = (
+ "If you are using a dictionary for custom labels for the facet row/col, "
+ "make sure each key in that column of the dataframe is in your facet "
+ "labels. The keys you need are {}"
+)
+
+
+def _is_flipped(num):
+ if num >= THRES_FOR_FLIPPED_FACET_TITLES:
+ flipped = True
+ else:
+ flipped = False
+ return flipped
+
+
+def _return_label(original_label, facet_labels, facet_var):
+ if isinstance(facet_labels, dict):
+ label = facet_labels[original_label]
+ elif isinstance(facet_labels, str):
+ label = '{}: {}'.format(facet_var, original_label)
+ else:
+ label = original_label
+ return label
+
+
+def _legend_annotation(color_name):
+ legend_title = dict(
+ textangle=0,
+ xanchor='left',
+ yanchor='middle',
+ x=LEGEND_ANNOT_X,
+ y=1.03,
+ showarrow=False,
+ xref='paper',
+ yref='paper',
+ text='factor({})'.format(color_name),
+ font=dict(
+ size=13,
+ color='#000000'
+ )
+ )
+ return legend_title
+
+
+def _annotation_dict(text, lane, num_of_lanes, SUBPLOT_SPACING, row_col='col',
+ flipped=True):
+ l = (1 - (num_of_lanes - 1) * SUBPLOT_SPACING) / (num_of_lanes)
+ if not flipped:
+ xanchor = 'center'
+ yanchor = 'middle'
+ if row_col == 'col':
+ x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
+ y = 1.03
+ textangle = 0
+ elif row_col == 'row':
+ y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
+ x = 1.03
+ textangle = 90
+ else:
+ if row_col == 'col':
+ xanchor = 'center'
+ yanchor = 'bottom'
+ x = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
+ y = 1.0
+ textangle = 270
+ elif row_col == 'row':
+ xanchor = 'left'
+ yanchor = 'middle'
+ y = (lane - 1) * (l + SUBPLOT_SPACING) + 0.5 * l
+ x = 1.0
+ textangle = 0
+
+ annotation_dict = dict(
+ textangle=textangle,
+ xanchor=xanchor,
+ yanchor=yanchor,
+ x=x,
+ y=y,
+ showarrow=False,
+ xref='paper',
+ yref='paper',
+ text=str(text),
+ font=dict(
+ size=13,
+ color=AXIS_TITLE_COLOR
+ )
+ )
+ return annotation_dict
+
+
+def _axis_title_annotation(text, x_or_y_axis):
+ if x_or_y_axis == 'x':
+ x_pos = 0.5
+ y_pos = -0.1
+ textangle = 0
+ elif x_or_y_axis == 'y':
+ x_pos = -0.1
+ y_pos = 0.5
+ textangle = 270
+
+ if not text:
+ text = ''
+
+ annot = {'font': {'color': '#000000', 'size': AXIS_TITLE_SIZE},
+ 'showarrow': False,
+ 'text': text,
+ 'textangle': textangle,
+ 'x': x_pos,
+ 'xanchor': 'center',
+ 'xref': 'paper',
+ 'y': y_pos,
+ 'yanchor': 'middle',
+ 'yref': 'paper'}
+ return annot
+
+
+def _add_shapes_to_fig(fig, annot_rect_color, flipped_rows=False,
+ flipped_cols=False):
+ shapes_list = []
+ for key in fig['layout'].to_plotly_json().keys():
+ if 'axis' in key and fig['layout'][key]['domain'] != [0.0, 1.0]:
+ shape = {
+ 'fillcolor': annot_rect_color,
+ 'layer': 'below',
+ 'line': {'color': annot_rect_color, 'width': 1},
+ 'type': 'rect',
+ 'xref': 'paper',
+ 'yref': 'paper'
+ }
+
+ if 'xaxis' in key:
+ shape['x0'] = fig['layout'][key]['domain'][0]
+ shape['x1'] = fig['layout'][key]['domain'][1]
+ shape['y0'] = 1.005
+ shape['y1'] = 1.05
+
+ if flipped_cols:
+ shape['y1'] += 0.5
+ shapes_list.append(shape)
+
+ elif 'yaxis' in key:
+ shape['x0'] = 1.005
+ shape['x1'] = 1.05
+ shape['y0'] = fig['layout'][key]['domain'][0]
+ shape['y1'] = fig['layout'][key]['domain'][1]
+
+ if flipped_rows:
+ shape['x1'] += 1
+ shapes_list.append(shape)
+
+ fig['layout']['shapes'] = shapes_list
+
+
+def _make_trace_for_scatter(trace, trace_type, color, **kwargs_marker):
+ if trace_type in ['scatter', 'scattergl']:
+ trace['mode'] = 'markers'
+ trace['marker'] = dict(color=color, **kwargs_marker)
+ return trace
+
+
+def _facet_grid_color_categorical(df, x, y, facet_row, facet_col, color_name,
+ colormap, num_of_rows, num_of_cols,
+ facet_row_labels, facet_col_labels,
+ trace_type, flipped_rows, flipped_cols,
+ show_boxes, SUBPLOT_SPACING, marker_color,
+ kwargs_trace, kwargs_marker):
+
+ fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
+ shared_xaxes=True, shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING, print_grid=False)
+
+ annotations = []
+ if not facet_row and not facet_col:
+ color_groups = list(df.groupby(color_name))
+ for group in color_groups:
+ trace = dict(
+ type=trace_type,
+ name=group[0],
+ marker=dict(
+ color=colormap[group[0]],
+ ),
+ **kwargs_trace
+ )
+ if x:
+ trace['x'] = group[1][x]
+ if y:
+ trace['y'] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[group[0]], **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ elif (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(
+ df.groupby(facet_row if facet_row else facet_col)
+ )
+ for j, group in enumerate(groups_by_facet):
+ for color_val in df[color_name].unique():
+ data_by_color = group[1][group[1][color_name] == color_val]
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(
+ color=colormap[color_val],
+ ),
+ **kwargs_trace
+ )
+ if x:
+ trace['x'] = data_by_color[x]
+ if y:
+ trace['y'] = data_by_color[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[color_val], **kwargs_marker
+ )
+
+ fig.append_trace(trace,
+ j + 1 if facet_row else 1,
+ 1 if facet_row else j + 1)
+
+ label = _return_label(
+ group[0],
+ facet_row_labels if facet_row else facet_col_labels,
+ facet_row if facet_row else facet_col
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ 'row' if facet_row else 'col',
+ flipped_rows)
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for
+ item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ color_vals = df[color_name].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame([[None, None, None]],
+ columns=[x, y, color_name])
+
+ for color_val in color_vals:
+ if group.values.tolist() != [[None, None, None]]:
+ group_filtered = group[group[color_name] == color_val]
+
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(
+ color=colormap[color_val],
+ ),
+ **kwargs_trace
+ )
+ new_x = group_filtered[x]
+ new_y = group_filtered[y]
+ else:
+ trace = dict(
+ type=trace_type,
+ name=color_val,
+ marker=dict(
+ color=colormap[color_val],
+ ),
+ showlegend=False,
+ **kwargs_trace
+ )
+ new_x = group[x]
+ new_y = group[y]
+
+ if x:
+ trace['x'] = new_x
+ if y:
+ trace['y'] = new_y
+ trace = _make_trace_for_scatter(
+ trace, trace_type, colormap[color_val],
+ **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(col_values[col_count],
+ facet_col_labels, facet_col)
+ annotations.append(
+ _annotation_dict(label, col_count + 1, num_of_cols,
+ SUBPLOT_SPACING,
+ row_col='col', flipped=flipped_cols)
+ )
+ label = _return_label(row_values[row_count],
+ facet_row_labels, facet_row)
+ annotations.append(
+ _annotation_dict(label, num_of_rows - row_count, num_of_rows,
+ SUBPLOT_SPACING,
+ row_col='row', flipped=flipped_rows)
+ )
+
+ return fig, annotations
+
+
+def _facet_grid_color_numerical(df, x, y, facet_row, facet_col, color_name,
+ colormap, num_of_rows,
+ num_of_cols, facet_row_labels,
+ facet_col_labels, trace_type,
+ flipped_rows, flipped_cols, show_boxes,
+ SUBPLOT_SPACING, marker_color, kwargs_trace,
+ kwargs_marker):
+
+ fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
+ shared_xaxes=True, shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING, print_grid=False)
+
+ annotations = []
+ if not facet_row and not facet_col:
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=df[color_name],
+ colorscale=colormap,
+ showscale=True,
+ ),
+ **kwargs_trace
+ )
+ if x:
+ trace['x'] = df[x]
+ if y:
+ trace['y'] = df[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ if (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(
+ df.groupby(facet_row if facet_row else facet_col)
+ )
+ for j, group in enumerate(groups_by_facet):
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=df[color_name],
+ colorscale=colormap,
+ showscale=True,
+ colorbar=dict(x=1.15),
+ ),
+ **kwargs_trace
+ )
+ if x:
+ trace['x'] = group[1][x]
+ if y:
+ trace['y'] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(
+ trace,
+ j + 1 if facet_row else 1,
+ 1 if facet_row else j + 1
+ )
+
+ labels = facet_row_labels if facet_row else facet_col_labels
+ label = _return_label(
+ group[0], labels, facet_row if facet_row else facet_col
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ 'row' if facet_row else 'col',
+ flipped=flipped_rows)
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for
+ item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame([[None, None, None]],
+ columns=[x, y, color_name])
+
+ if group.values.tolist() != [[None, None, None]]:
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=df[color_name],
+ colorscale=colormap,
+ showscale=(row_count == 0),
+ colorbar=dict(x=1.15),
+ ),
+ **kwargs_trace
+ )
+
+ else:
+ trace = dict(
+ type=trace_type,
+ showlegend=False,
+ **kwargs_trace
+ )
+
+ if x:
+ trace['x'] = group[x]
+ if y:
+ trace['y'] = group[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, df[color_name], **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(col_values[col_count],
+ facet_col_labels, facet_col)
+ annotations.append(
+ _annotation_dict(label, col_count + 1, num_of_cols,
+ SUBPLOT_SPACING,
+ row_col='col', flipped=flipped_cols)
+ )
+ label = _return_label(row_values[row_count],
+ facet_row_labels, facet_row)
+ annotations.append(
+ _annotation_dict(row_values[row_count],
+ num_of_rows - row_count, num_of_rows, SUBPLOT_SPACING,
+ row_col='row', flipped=flipped_rows)
+ )
+
+ return fig, annotations
+
+
+def _facet_grid(df, x, y, facet_row, facet_col, num_of_rows,
+ num_of_cols, facet_row_labels, facet_col_labels,
+ trace_type, flipped_rows, flipped_cols, show_boxes,
+ SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker):
+
+ fig = make_subplots(rows=num_of_rows, cols=num_of_cols,
+ shared_xaxes=True, shared_yaxes=True,
+ horizontal_spacing=SUBPLOT_SPACING,
+ vertical_spacing=SUBPLOT_SPACING, print_grid=False)
+ annotations = []
+ if not facet_row and not facet_col:
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=marker_color,
+ line=kwargs_marker['line'],
+ ),
+ **kwargs_trace
+ )
+
+ if x:
+ trace['x'] = df[x]
+ if y:
+ trace['y'] = df[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(trace, 1, 1)
+
+ elif (facet_row and not facet_col) or (not facet_row and facet_col):
+ groups_by_facet = list(
+ df.groupby(facet_row if facet_row else facet_col)
+ )
+ for j, group in enumerate(groups_by_facet):
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=marker_color,
+ line=kwargs_marker['line'],
+ ),
+ **kwargs_trace
+ )
+
+ if x:
+ trace['x'] = group[1][x]
+ if y:
+ trace['y'] = group[1][y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(trace,
+ j + 1 if facet_row else 1,
+ 1 if facet_row else j + 1)
+
+ label = _return_label(
+ group[0],
+ facet_row_labels if facet_row else facet_col_labels,
+ facet_row if facet_row else facet_col
+ )
+
+ annotations.append(
+ _annotation_dict(
+ label,
+ num_of_rows - j if facet_row else j + 1,
+ num_of_rows if facet_row else num_of_cols,
+ SUBPLOT_SPACING,
+ 'row' if facet_row else 'col',
+ flipped_rows
+ )
+ )
+
+ elif facet_row and facet_col:
+ groups_by_facets = list(df.groupby([facet_row, facet_col]))
+ tuple_to_facet_group = {item[0]: item[1] for
+ item in groups_by_facets}
+
+ row_values = df[facet_row].unique()
+ col_values = df[facet_col].unique()
+ for row_count, x_val in enumerate(row_values):
+ for col_count, y_val in enumerate(col_values):
+ try:
+ group = tuple_to_facet_group[(x_val, y_val)]
+ except KeyError:
+ group = pd.DataFrame([[None, None]], columns=[x, y])
+ trace = dict(
+ type=trace_type,
+ marker=dict(
+ color=marker_color,
+ line=kwargs_marker['line'],
+ ),
+ **kwargs_trace
+ )
+ if x:
+ trace['x'] = group[x]
+ if y:
+ trace['y'] = group[y]
+ trace = _make_trace_for_scatter(
+ trace, trace_type, marker_color, **kwargs_marker
+ )
+
+ fig.append_trace(trace, row_count + 1, col_count + 1)
+ if row_count == 0:
+ label = _return_label(col_values[col_count],
+ facet_col_labels,
+ facet_col)
+ annotations.append(
+ _annotation_dict(label, col_count + 1, num_of_cols, SUBPLOT_SPACING,
+ row_col='col', flipped=flipped_cols)
+ )
+
+ label = _return_label(row_values[row_count],
+ facet_row_labels,
+ facet_row)
+ annotations.append(
+ _annotation_dict(label, num_of_rows - row_count, num_of_rows, SUBPLOT_SPACING,
+ row_col='row', flipped=flipped_rows)
+ )
+
+ return fig, annotations
+
+
+def create_facet_grid(df, x=None, y=None, facet_row=None, facet_col=None,
+ color_name=None, colormap=None, color_is_cat=False,
+ facet_row_labels=None, facet_col_labels=None,
+ height=None, width=None, trace_type='scatter',
+ scales='fixed', dtick_x=None, dtick_y=None,
+ show_boxes=True, ggplot2=False, binsize=1, **kwargs):
+ """
+ Returns figure for facet grid.
+
+ :param (pd.DataFrame) df: the dataframe of columns for the facet grid.
+ :param (str) x: the name of the dataframe column for the x axis data.
+ :param (str) y: the name of the dataframe column for the y axis data.
+ :param (str) facet_row: the name of the dataframe column that is used to
+ facet the grid into row panels.
+ :param (str) facet_col: the name of the dataframe column that is used to
+ facet the grid into column panels.
+ :param (str) color_name: the name of your dataframe column that will
+ function as the colormap variable.
+ :param (str|list|dict) colormap: the param that determines how the
+ color_name column colors the data. If the dataframe contains numeric
+ data, then a dictionary of colors will group the data categorically
+ while a Plotly Colorscale name or a custom colorscale will treat it
+ numerically. To learn more about colors and types of colormap, run
+ `help(plotly.colors)`.
+ :param (bool) color_is_cat: determines whether a numerical column for the
+ colormap will be treated as categorical (True) or sequential (False).
+ Default = False.
+ :param (str|dict) facet_row_labels: set to either 'name' or a dictionary
+ of all the unique values in the faceting row mapped to some text to
+ show up in the label annotations. If None, labeling works like usual.
+ :param (str|dict) facet_col_labels: set to either 'name' or a dictionary
+ of all the values in the faceting row mapped to some text to show up
+ in the label annotations. If None, labeling works like usual.
+ :param (int) height: the height of the facet grid figure.
+ :param (int) width: the width of the facet grid figure.
+ :param (str) trace_type: decides the type of plot to appear in the
+ facet grid. The options are 'scatter', 'scattergl', 'histogram',
+ 'bar', and 'box'.
+ Default = 'scatter'.
+ :param (str) scales: determines if axes have fixed ranges or not. Valid
+ settings are 'fixed' (all axes fixed), 'free_x' (x axis free only),
+ 'free_y' (y axis free only) or 'free' (both axes free).
+ :param (float) dtick_x: determines the distance between each tick on the
+ x-axis. Default is None which means dtick_x is set automatically.
+ :param (float) dtick_y: determines the distance between each tick on the
+ y-axis. Default is None which means dtick_y is set automatically.
+ :param (bool) show_boxes: draws grey boxes behind the facet titles.
+ :param (bool) ggplot2: draws the facet grid in the style of `ggplot2`. See
+ http://ggplot2.tidyverse.org/reference/facet_grid.html for reference.
+ Default = False
+ :param (int) binsize: groups all data into bins of a given length.
+ :param (dict) kwargs: a dictionary of scatterplot arguments.
+
+ Examples 1: One Way Faceting
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')
+
+ fig = ff.create_facet_grid(
+ mpg,
+ x='displ',
+ y='cty',
+ facet_col='cyl',
+ )
+ py.iplot(fig, filename='facet_grid_mpg_one_way_facet')
+ ```
+
+ Example 2: Two Way Faceting
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')
+
+ fig = ff.create_facet_grid(
+ mpg,
+ x='displ',
+ y='cty',
+ facet_row='drv',
+ facet_col='cyl',
+ )
+ py.iplot(fig, filename='facet_grid_mpg_two_way_facet')
+ ```
+
+ Example 3: Categorical Coloring
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ mpg = pd.read_table('https://raw.githubusercontent.com/plotly/datasets/master/mpg_2017.txt')
+
+ fig = ff.create_facet_grid(
+ mtcars,
+ x='mpg',
+ y='wt',
+ facet_col='cyl',
+ color_name='cyl',
+ color_is_cat=True,
+ )
+ py.iplot(fig, filename='facet_grid_mpg_default_colors')
+ ```
+
+ Example 4: Sequential Coloring
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ tips = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/tips.csv')
+
+ fig = ff.create_facet_grid(
+ tips,
+ x='total_bill',
+ y='tip',
+ facet_row='sex',
+ facet_col='smoker',
+ color_name='size',
+ colormap='Viridis',
+ )
+ py.iplot(fig, filename='facet_grid_tips_sequential_colors')
+ ```
+
+ Example 5: Custom labels
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')
+
+ fig = ff.create_facet_grid(
+ mtcars,
+ x='wt',
+ y='mpg',
+ facet_col='cyl',
+ facet_col_labels={4: "$\\alpha$", 6: '$\\beta$', 8: '$\sqrt[y]{x}$'},
+ )
+
+ py.iplot(fig, filename='facet_grid_mtcars_custom_labels')
+ ```
+
+ Example 6: Other Trace Type
+ ```
+ import plotly.plotly as py
+ import plotly.figure_factory as ff
+
+ import pandas as pd
+
+ mtcars = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/mtcars.csv')
+
+ fig = ff.create_facet_grid(
+ mtcars,
+ x='wt',
+ facet_col='cyl',
+ trace_type='histogram',
+ )
+
+ py.iplot(fig, filename='facet_grid_mtcars_other_trace_type')
+ ```
+ """
+ if not pd:
+ raise exceptions.ImportError(
+ "'pandas' must be installed for this figure_factory."
+ )
+
+ if not isinstance(df, pd.DataFrame):
+ raise exceptions.PlotlyError(
+ "You must input a pandas DataFrame."
+ )
+
+ # make sure all columns are of homogenous datatype
+ utils.validate_dataframe(df)
+
+ if trace_type in ['scatter', 'scattergl']:
+ if not x or not y:
+ raise exceptions.PlotlyError(
+ "You need to input 'x' and 'y' if you are you are using a "
+ "trace_type of 'scatter' or 'scattergl'."
+ )
+
+ for key in [x, y, facet_row, facet_col, color_name]:
+ if key is not None:
+ try:
+ df[key]
+ except KeyError:
+ raise exceptions.PlotlyError(
+ "x, y, facet_row, facet_col and color_name must be keys "
+ "in your dataframe."
+ )
+ # autoscale histogram bars
+ if trace_type not in ['scatter', 'scattergl']:
+ scales = 'free'
+
+ # validate scales
+ if scales not in ['fixed', 'free_x', 'free_y', 'free']:
+ raise exceptions.PlotlyError(
+ "'scales' must be set to 'fixed', 'free_x', 'free_y' and 'free'."
+ )
+
+ if trace_type not in VALID_TRACE_TYPES:
+ raise exceptions.PlotlyError(
+ "'trace_type' must be in {}".format(VALID_TRACE_TYPES)
+ )
+
+ if trace_type == 'histogram':
+ SUBPLOT_SPACING = 0.06
+ else:
+ SUBPLOT_SPACING = 0.015
+
+ # seperate kwargs for marker and else
+ if 'marker' in kwargs:
+ kwargs_marker = kwargs['marker']
+ else:
+ kwargs_marker = {}
+ marker_color = kwargs_marker.pop('color', None)
+ kwargs.pop('marker', None)
+ kwargs_trace = kwargs
+
+ if 'size' not in kwargs_marker:
+ if ggplot2:
+ kwargs_marker['size'] = 5
+ else:
+ kwargs_marker['size'] = 8
+
+ if 'opacity' not in kwargs_marker:
+ if not ggplot2:
+ kwargs_trace['opacity'] = 0.6
+
+ if 'line' not in kwargs_marker:
+ if not ggplot2:
+ kwargs_marker['line'] = {'color': 'darkgrey', 'width': 1}
+ else:
+ kwargs_marker['line'] = {}
+
+ # default marker size
+ if not ggplot2:
+ if not marker_color:
+ marker_color = 'rgb(31, 119, 180)'
+ else:
+ marker_color = 'rgb(0, 0, 0)'
+
+ num_of_rows = 1
+ num_of_cols = 1
+ flipped_rows = False
+ flipped_cols = False
+ if facet_row:
+ num_of_rows = len(df[facet_row].unique())
+ flipped_rows = _is_flipped(num_of_rows)
+ if isinstance(facet_row_labels, dict):
+ for key in df[facet_row].unique():
+ if key not in facet_row_labels.keys():
+ unique_keys = df[facet_row].unique().tolist()
+ raise exceptions.PlotlyError(
+ CUSTOM_LABEL_ERROR.format(unique_keys)
+ )
+ if facet_col:
+ num_of_cols = len(df[facet_col].unique())
+ flipped_cols = _is_flipped(num_of_cols)
+ if isinstance(facet_col_labels, dict):
+ for key in df[facet_col].unique():
+ if key not in facet_col_labels.keys():
+ unique_keys = df[facet_col].unique().tolist()
+ raise exceptions.PlotlyError(
+ CUSTOM_LABEL_ERROR.format(unique_keys)
+ )
+ show_legend = False
+ if color_name:
+ if isinstance(df[color_name].iloc[0], str) or color_is_cat:
+ show_legend = True
+ if isinstance(colormap, dict):
+ utils.validate_colors_dict(colormap, 'rgb')
+
+ for val in df[color_name].unique():
+ if val not in colormap.keys():
+ raise exceptions.PlotlyError(
+ "If using 'colormap' as a dictionary, make sure "
+ "all the values of the colormap column are in "
+ "the keys of your dictionary."
+ )
+ else:
+ # use default plotly colors for dictionary
+ default_colors = utils.DEFAULT_PLOTLY_COLORS
+ colormap = {}
+ j = 0
+ for val in df[color_name].unique():
+ if j >= len(default_colors):
+ j = 0
+ colormap[val] = default_colors[j]
+ j += 1
+ fig, annotations = _facet_grid_color_categorical(
+ df, x, y, facet_row, facet_col, color_name, colormap,
+ num_of_rows, num_of_cols, facet_row_labels, facet_col_labels,
+ trace_type, flipped_rows, flipped_cols, show_boxes,
+ SUBPLOT_SPACING, marker_color, kwargs_trace, kwargs_marker
+ )
+
+ elif isinstance(df[color_name].iloc[0], Number):
+ if isinstance(colormap, dict):
+ show_legend = True
+ utils.validate_colors_dict(colormap, 'rgb')
+
+ for val in df[color_name].unique():
+ if val not in colormap.keys():
+ raise exceptions.PlotlyError(
+ "If using 'colormap' as a dictionary, make sure "
+ "all the values of the colormap column are in "
+ "the keys of your dictionary."
+ )
+ fig, annotations = _facet_grid_color_categorical(
+ df, x, y, facet_row, facet_col, color_name, colormap,
+ num_of_rows, num_of_cols, facet_row_labels,
+ facet_col_labels, trace_type, flipped_rows,
+ flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color,
+ kwargs_trace, kwargs_marker
+ )
+
+ elif isinstance(colormap, list):
+ colorscale_list = colormap
+ utils.validate_colorscale(colorscale_list)
+
+ fig, annotations = _facet_grid_color_numerical(
+ df, x, y, facet_row, facet_col, color_name,
+ colorscale_list, num_of_rows, num_of_cols,
+ facet_row_labels, facet_col_labels, trace_type,
+ flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
+ marker_color, kwargs_trace, kwargs_marker
+ )
+ elif isinstance(colormap, str):
+ if colormap in utils.PLOTLY_SCALES.keys():
+ colorscale_list = utils.PLOTLY_SCALES[colormap]
+ else:
+ raise exceptions.PlotlyError(
+ "If 'colormap' is a string, it must be the name "
+ "of a Plotly Colorscale. The available colorscale "
+ "names are {}".format(utils.PLOTLY_SCALES.keys())
+ )
+ fig, annotations = _facet_grid_color_numerical(
+ df, x, y, facet_row, facet_col, color_name,
+ colorscale_list, num_of_rows, num_of_cols,
+ facet_row_labels, facet_col_labels, trace_type,
+ flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
+ marker_color, kwargs_trace, kwargs_marker
+ )
+ else:
+ colorscale_list = utils.PLOTLY_SCALES['Reds']
+ fig, annotations = _facet_grid_color_numerical(
+ df, x, y, facet_row, facet_col, color_name,
+ colorscale_list, num_of_rows, num_of_cols,
+ facet_row_labels, facet_col_labels, trace_type,
+ flipped_rows, flipped_cols, show_boxes, SUBPLOT_SPACING,
+ marker_color, kwargs_trace, kwargs_marker
+ )
+
+ else:
+ fig, annotations = _facet_grid(
+ df, x, y, facet_row, facet_col, num_of_rows, num_of_cols,
+ facet_row_labels, facet_col_labels, trace_type, flipped_rows,
+ flipped_cols, show_boxes, SUBPLOT_SPACING, marker_color,
+ kwargs_trace, kwargs_marker
+ )
+
+ if not height:
+ height = max(600, 100 * num_of_rows)
+ if not width:
+ width = max(600, 100 * num_of_cols)
+
+ fig['layout'].update(height=height, width=width, title='',
+ paper_bgcolor='rgb(251, 251, 251)')
+ if ggplot2:
+ fig['layout'].update(plot_bgcolor=PLOT_BGCOLOR,
+ paper_bgcolor='rgb(255, 255, 255)',
+ hovermode='closest')
+
+ # axis titles
+ x_title_annot = _axis_title_annotation(x, 'x')
+ y_title_annot = _axis_title_annotation(y, 'y')
+
+ # annotations
+ annotations.append(x_title_annot)
+ annotations.append(y_title_annot)
+
+ # legend
+ fig['layout']['showlegend'] = show_legend
+ fig['layout']['legend']['bgcolor'] = LEGEND_COLOR
+ fig['layout']['legend']['borderwidth'] = LEGEND_BORDER_WIDTH
+ fig['layout']['legend']['x'] = 1.05
+ fig['layout']['legend']['y'] = 1
+ fig['layout']['legend']['yanchor'] = 'top'
+
+ if show_legend:
+ fig['layout']['showlegend'] = show_legend
+ if ggplot2:
+ if color_name:
+ legend_annot = _legend_annotation(color_name)
+ annotations.append(legend_annot)
+ fig['layout']['margin']['r'] = 150
+
+ # assign annotations to figure
+ fig['layout']['annotations'] = annotations
+
+ # add shaded boxes behind axis titles
+ if show_boxes and ggplot2:
+ _add_shapes_to_fig(fig, ANNOT_RECT_COLOR, flipped_rows, flipped_cols)
+
+ # all xaxis and yaxis labels
+ axis_labels = {'x': [], 'y': []}
+ for key in fig['layout']:
+ if 'xaxis' in key:
+ axis_labels['x'].append(key)
+ elif 'yaxis' in key:
+ axis_labels['y'].append(key)
+
+ string_number_in_data = False
+ for var in [v for v in [x, y] if v]:
+ if isinstance(df[var].tolist()[0], str):
+ for item in df[var]:
+ try:
+ int(item)
+ string_number_in_data = True
+ except ValueError:
+ pass
+
+ if string_number_in_data:
+ for x_y in axis_labels.keys():
+ for axis_name in axis_labels[x_y]:
+ fig['layout'][axis_name]['type'] = 'category'
+
+ if scales == 'fixed':
+ fixed_axes = ['x', 'y']
+ elif scales == 'free_x':
+ fixed_axes = ['y']
+ elif scales == 'free_y':
+ fixed_axes = ['x']
+ elif scales == 'free':
+ fixed_axes = []
+
+ # fixed ranges
+ for x_y in fixed_axes:
+ min_ranges = []
+ max_ranges = []
+ for trace in fig['data']:
+ if trace[x_y] is not None and len(trace[x_y]) > 0:
+ min_ranges.append(min(trace[x_y]))
+ max_ranges.append(max(trace[x_y]))
+ while None in min_ranges:
+ min_ranges.remove(None)
+ while None in max_ranges:
+ max_ranges.remove(None)
+
+ min_range = min(min_ranges)
+ max_range = max(max_ranges)
+
+ range_are_numbers = (isinstance(min_range, Number) and
+ isinstance(max_range, Number))
+
+ if range_are_numbers:
+ min_range = math.floor(min_range)
+ max_range = math.ceil(max_range)
+
+ # extend widen frame by 5% on each side
+ min_range -= 0.05 * (max_range - min_range)
+ max_range += 0.05 * (max_range - min_range)
+
+ if x_y == 'x':
+ if dtick_x:
+ dtick = dtick_x
+ else:
+ dtick = math.floor(
+ (max_range - min_range) / MAX_TICKS_PER_AXIS
+ )
+ elif x_y == 'y':
+ if dtick_y:
+ dtick = dtick_y
+ else:
+ dtick = math.floor(
+ (max_range - min_range) / MAX_TICKS_PER_AXIS
+ )
+ else:
+ dtick = 1
+
+ for axis_title in axis_labels[x_y]:
+ fig['layout'][axis_title]['dtick'] = dtick
+ fig['layout'][axis_title]['ticklen'] = 0
+ fig['layout'][axis_title]['zeroline'] = False
+ if ggplot2:
+ fig['layout'][axis_title]['tickwidth'] = 1
+ fig['layout'][axis_title]['ticklen'] = 4
+ fig['layout'][axis_title]['gridwidth'] = GRID_WIDTH
+
+ fig['layout'][axis_title]['gridcolor'] = GRID_COLOR
+ fig['layout'][axis_title]['gridwidth'] = 2
+ fig['layout'][axis_title]['tickfont'] = {
+ 'color': TICK_COLOR, 'size': 10
+ }
+
+ # insert ranges into fig
+ if x_y in fixed_axes:
+ for key in fig['layout']:
+ if '{}axis'.format(x_y) in key and range_are_numbers:
+ fig['layout'][key]['range'] = [min_range, max_range]
+
+ return fig
diff --git a/plotly/figure_factory/figure_factory/_gantt.py b/plotly/figure_factory/figure_factory/_gantt.py
new file mode 100644
index 00000000000..36024e6c1b5
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_gantt.py
@@ -0,0 +1,780 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+pd = optional_imports.get_module('pandas')
+
+REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish']
+
+
+def validate_gantt(df):
+ """
+ Validates the inputted dataframe or list
+ """
+ if pd and isinstance(df, pd.core.frame.DataFrame):
+ # validate that df has all the required keys
+ for key in REQUIRED_GANTT_KEYS:
+ if key not in df:
+ raise exceptions.PlotlyError(
+ "The columns in your dataframe must include the "
+ "following keys: {0}".format(
+ ', '.join(REQUIRED_GANTT_KEYS))
+ )
+
+ num_of_rows = len(df.index)
+ chart = []
+ for index in range(num_of_rows):
+ task_dict = {}
+ for key in df:
+ task_dict[key] = df.ix[index][key]
+ chart.append(task_dict)
+
+ return chart
+
+ # validate if df is a list
+ if not isinstance(df, list):
+ raise exceptions.PlotlyError("You must input either a dataframe "
+ "or a list of dictionaries.")
+
+ # validate if df is empty
+ if len(df) <= 0:
+ raise exceptions.PlotlyError("Your list is empty. It must contain "
+ "at least one dictionary.")
+ if not isinstance(df[0], dict):
+ raise exceptions.PlotlyError("Your list must only "
+ "include dictionaries.")
+ return df
+
+
+def gantt(chart, colors, title, bar_width, showgrid_x, showgrid_y, height,
+ width, tasks=None, task_names=None, data=None, group_tasks=False):
+ """
+ Refer to create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ color_index = 0
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ # check if colors need to be looped
+ if color_index >= len(colors):
+ color_index = 0
+ tasks[index]['fillcolor'] = colors[color_index]
+ # Add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+ color_index += 1
+
+ layout = dict(
+ title=title,
+ showlegend=False,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = graph_objs.Figure(data=data, layout=layout)
+ return fig
+
+
+def gantt_colorscale(chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width, tasks=None,
+ task_names=None, data=None, group_tasks=False):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+
+ # compute the color for task based on indexing column
+ if isinstance(chart[0][index_col], Number):
+ # check that colors has at least 2 colors
+ if len(colors) < 2:
+ raise exceptions.PlotlyError(
+ "You must use at least 2 colors in 'colors' if you "
+ "are using a colorscale. However only the first two "
+ "colors given will be used for the lower and upper "
+ "bounds on the colormap."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ # unlabel color
+ colors = utils.color_parser(colors, utils.unlabel_rgb)
+ lowcolor = colors[0]
+ highcolor = colors[1]
+
+ intermed = (chart[index][index_col]) / 100.0
+ intermed_color = utils.find_intermediate_color(
+ lowcolor, highcolor, intermed
+ )
+ intermed_color = utils.color_parser(
+ intermed_color, utils.label_rgb
+ )
+ tasks[index]['fillcolor'] = intermed_color
+ # relabel colors with 'rgb'
+ colors = utils.color_parser(colors, utils.label_rgb)
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data for colorscale visibility
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[index, index],
+ name='',
+ marker={'color': 'white',
+ 'colorscale': [[0, colors[0]], [1, colors[1]]],
+ 'showscale': True,
+ 'cmax': 100,
+ 'cmin': 0}
+ )
+ )
+
+ if isinstance(chart[0][index_col], str):
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ if len(colors) < len(index_vals):
+ raise exceptions.PlotlyError(
+ "Error. The number of colors in 'colors' must be no less "
+ "than the number of unique index values in your group "
+ "column."
+ )
+
+ # make a dictionary assignment to each index value
+ index_vals_dict = {}
+ # define color index
+ c_index = 0
+ for key in index_vals:
+ if c_index > len(colors) - 1:
+ c_index = 0
+ index_vals_dict[key] = colors[c_index]
+ c_index += 1
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted
+ # first are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ tasks[index]['fillcolor'] = index_vals_dict[
+ chart[index][index_col]
+ ]
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data to generate legend
+ showlegend = True
+ for k, index_value in enumerate(index_vals):
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[k, k],
+ showlegend=True,
+ name=str(index_value),
+ hoverinfo='none',
+ marker=dict(
+ color=colors[k],
+ size=1
+ )
+ )
+ )
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = dict(data=data, layout=layout)
+ return fig
+
+
+def gantt_dict(chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width, tasks=None,
+ task_names=None, data=None, group_tasks=False):
+ """
+ Refer to FigureFactory.create_gantt() for docstring
+ """
+ if tasks is None:
+ tasks = []
+ if task_names is None:
+ task_names = []
+ if data is None:
+ data = []
+ showlegend = False
+
+ for index in range(len(chart)):
+ task = dict(x0=chart[index]['Start'],
+ x1=chart[index]['Finish'],
+ name=chart[index]['Task'])
+ if 'Description' in chart[index]:
+ task['description'] = chart[index]['Description']
+ tasks.append(task)
+
+ shape_template = {
+ 'type': 'rect',
+ 'xref': 'x',
+ 'yref': 'y',
+ 'opacity': 1,
+ 'line': {
+ 'width': 0,
+ }
+ }
+
+ index_vals = []
+ for row in range(len(tasks)):
+ if chart[row][index_col] not in index_vals:
+ index_vals.append(chart[row][index_col])
+
+ index_vals.sort()
+
+ # verify each value in index column appears in colors dictionary
+ for key in index_vals:
+ if key not in colors:
+ raise exceptions.PlotlyError(
+ "If you are using colors as a dictionary, all of its "
+ "keys must be all the values in the index column."
+ )
+
+ # create the list of task names
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ # Is added to task_names if group_tasks is set to False,
+ # or if the option is used (True) it only adds them if the
+ # name is not already in the list
+ if not group_tasks or tn not in task_names:
+ task_names.append(tn)
+ # Guarantees that for grouped tasks the tasks that are inserted first
+ # are shown at the top
+ if group_tasks:
+ task_names.reverse()
+
+ for index in range(len(tasks)):
+ tn = tasks[index]['name']
+ del tasks[index]['name']
+ tasks[index].update(shape_template)
+
+ # If group_tasks is True, all tasks with the same name belong
+ # to the same row.
+ groupID = index
+ if group_tasks:
+ groupID = task_names.index(tn)
+ tasks[index]['y0'] = groupID - bar_width
+ tasks[index]['y1'] = groupID + bar_width
+
+ tasks[index]['fillcolor'] = colors[chart[index][index_col]]
+
+ # add a line for hover text and autorange
+ entry = dict(
+ x=[tasks[index]['x0'], tasks[index]['x1']],
+ y=[groupID, groupID],
+ showlegend=False,
+ name='',
+ marker={'color': 'white'}
+ )
+ if "description" in tasks[index]:
+ entry['text'] = tasks[index]['description']
+ del tasks[index]['description']
+ data.append(entry)
+
+ if show_colorbar is True:
+ # generate dummy data to generate legend
+ showlegend = True
+ for k, index_value in enumerate(index_vals):
+ data.append(
+ dict(
+ x=[tasks[index]['x0'], tasks[index]['x0']],
+ y=[k, k],
+ showlegend=True,
+ hoverinfo='none',
+ name=str(index_value),
+ marker=dict(
+ color=colors[index_value],
+ size=1
+ )
+ )
+ )
+
+ layout = dict(
+ title=title,
+ showlegend=showlegend,
+ height=height,
+ width=width,
+ shapes=[],
+ hovermode='closest',
+ yaxis=dict(
+ showgrid=showgrid_y,
+ ticktext=task_names,
+ tickvals=list(range(len(task_names))),
+ range=[-1, len(task_names) + 1],
+ autorange=False,
+ zeroline=False,
+ ),
+ xaxis=dict(
+ showgrid=showgrid_x,
+ zeroline=False,
+ rangeselector=dict(
+ buttons=list([
+ dict(count=7,
+ label='1w',
+ step='day',
+ stepmode='backward'),
+ dict(count=1,
+ label='1m',
+ step='month',
+ stepmode='backward'),
+ dict(count=6,
+ label='6m',
+ step='month',
+ stepmode='backward'),
+ dict(count=1,
+ label='YTD',
+ step='year',
+ stepmode='todate'),
+ dict(count=1,
+ label='1y',
+ step='year',
+ stepmode='backward'),
+ dict(step='all')
+ ])
+ ),
+ type='date'
+ )
+ )
+ layout['shapes'] = tasks
+
+ fig = dict(data=data, layout=layout)
+ return fig
+
+
+def create_gantt(df, colors=None, index_col=None, show_colorbar=False,
+ reverse_colors=False, title='Gantt Chart', bar_width=0.2,
+ showgrid_x=False, showgrid_y=False, height=600, width=900,
+ tasks=None, task_names=None, data=None, group_tasks=False):
+ """
+ Returns figure for a gantt chart
+
+ :param (array|list) df: input data for gantt chart. Must be either a
+ a dataframe or a list. If dataframe, the columns must include
+ 'Task', 'Start' and 'Finish'. Other columns can be included and
+ used for indexing. If a list, its elements must be dictionaries
+ with the same required column headers: 'Task', 'Start' and
+ 'Finish'.
+ :param (str|list|dict|tuple) colors: either a plotly scale name, an
+ rgb or hex color, a color tuple or a list of colors. An rgb color
+ is of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colors is a list, it must
+ contain the valid color types aforementioned as its members.
+ If a dictionary, all values of the indexing column must be keys in
+ colors.
+ :param (str|float) index_col: the column header (if df is a data
+ frame) that will function as the indexing column. If df is a list,
+ index_col must be one of the keys in all the items of df.
+ :param (bool) show_colorbar: determines if colorbar will be visible.
+ Only applies if values in the index column are numeric.
+ :param (bool) reverse_colors: reverses the order of selected colors
+ :param (str) title: the title of the chart
+ :param (float) bar_width: the width of the horizontal bars in the plot
+ :param (bool) showgrid_x: show/hide the x-axis grid
+ :param (bool) showgrid_y: show/hide the y-axis grid
+ :param (float) height: the height of the chart
+ :param (float) width: the width of the chart
+
+ Example 1: Simple Gantt Chart
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01', Finish='2009-02-30'),
+ dict(Task="Job B", Start='2009-03-05', Finish='2009-04-15'),
+ dict(Task="Job C", Start='2009-02-20', Finish='2009-05-30')]
+
+ # Create a figure
+ fig = create_gantt(df)
+
+ # Plot the data
+ py.iplot(fig, filename='Simple Gantt Chart', world_readable=True)
+ ```
+
+ Example 2: Index by Column with Numerical Entries
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Complete=10),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Complete=60),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Complete=95)]
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors='Blues', index_col='Complete',
+ show_colorbar=True, bar_width=0.5,
+ showgrid_x=True, showgrid_y=True)
+
+ # Plot the data
+ py.iplot(fig, filename='Numerical Entries', world_readable=True)
+ ```
+
+ Example 3: Index by Column with String Entries
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Resource='Apple'),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Resource='Grape'),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Resource='Banana')]
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors=['rgb(200, 50, 25)', (1, 0, 1), '#6c4774'],
+ index_col='Resource', reverse_colors=True,
+ show_colorbar=True)
+
+ # Plot the data
+ py.iplot(fig, filename='String Entries', world_readable=True)
+ ```
+
+ Example 4: Use a dictionary for colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ # Make data for chart
+ df = [dict(Task="Job A", Start='2009-01-01',
+ Finish='2009-02-30', Resource='Apple'),
+ dict(Task="Job B", Start='2009-03-05',
+ Finish='2009-04-15', Resource='Grape'),
+ dict(Task="Job C", Start='2009-02-20',
+ Finish='2009-05-30', Resource='Banana')]
+
+ # Make a dictionary of colors
+ colors = {'Apple': 'rgb(255, 0, 0)',
+ 'Grape': 'rgb(170, 14, 200)',
+ 'Banana': (1, 1, 0.2)}
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors=colors, index_col='Resource',
+ show_colorbar=True)
+
+ # Plot the data
+ py.iplot(fig, filename='dictioanry colors', world_readable=True)
+ ```
+
+ Example 5: Use a pandas dataframe
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_gantt
+
+ import pandas as pd
+
+ # Make data as a dataframe
+ df = pd.DataFrame([['Run', '2010-01-01', '2011-02-02', 10],
+ ['Fast', '2011-01-01', '2012-06-05', 55],
+ ['Eat', '2012-01-05', '2013-07-05', 94]],
+ columns=['Task', 'Start', 'Finish', 'Complete'])
+
+ # Create a figure with Plotly colorscale
+ fig = create_gantt(df, colors='Blues', index_col='Complete',
+ show_colorbar=True, bar_width=0.5,
+ showgrid_x=True, showgrid_y=True)
+
+ # Plot the data
+ py.iplot(fig, filename='data with dataframe', world_readable=True)
+ ```
+ """
+ # validate gantt input data
+ chart = validate_gantt(df)
+
+ if index_col:
+ if index_col not in chart[0]:
+ raise exceptions.PlotlyError(
+ "In order to use an indexing column and assign colors to "
+ "the values of the index, you must choose an actual "
+ "column name in the dataframe or key if a list of "
+ "dictionaries is being used.")
+
+ # validate gantt index column
+ index_list = []
+ for dictionary in chart:
+ index_list.append(dictionary[index_col])
+ utils.validate_index(index_list)
+
+ # Validate colors
+ if isinstance(colors, dict):
+ colors = utils.validate_colors_dict(colors, 'rgb')
+ else:
+ colors = utils.validate_colors(colors, 'rgb')
+
+ if reverse_colors is True:
+ colors.reverse()
+
+ if not index_col:
+ if isinstance(colors, dict):
+ raise exceptions.PlotlyError(
+ "Error. You have set colors to a dictionary but have not "
+ "picked an index. An index is required if you are "
+ "assigning colors to particular values in a dictioanry."
+ )
+ fig = gantt(
+ chart, colors, title, bar_width, showgrid_x, showgrid_y,
+ height, width, tasks=None, task_names=None, data=None,
+ group_tasks=group_tasks
+ )
+ return fig
+ else:
+ if not isinstance(colors, dict):
+ fig = gantt_colorscale(
+ chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width,
+ tasks=None, task_names=None, data=None, group_tasks=group_tasks
+ )
+ return fig
+ else:
+ fig = gantt_dict(
+ chart, colors, title, index_col, show_colorbar, bar_width,
+ showgrid_x, showgrid_y, height, width,
+ tasks=None, task_names=None, data=None, group_tasks=group_tasks
+ )
+ return fig
diff --git a/plotly/figure_factory/figure_factory/_ohlc.py b/plotly/figure_factory/figure_factory/_ohlc.py
new file mode 100644
index 00000000000..b5e84cd6d93
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_ohlc.py
@@ -0,0 +1,380 @@
+from __future__ import absolute_import
+
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+# Default colours for finance charts
+_DEFAULT_INCREASING_COLOR = '#3D9970' # http://clrs.cc
+_DEFAULT_DECREASING_COLOR = '#FF4136'
+
+
+def validate_ohlc(open, high, low, close, direction, **kwargs):
+ """
+ ohlc and candlestick specific validations
+
+ Specifically, this checks that the high value is the greatest value and
+ the low value is the lowest value in each unit.
+
+ See FigureFactory.create_ohlc() or FigureFactory.create_candlestick()
+ for params
+
+ :raises: (PlotlyError) If the high value is not the greatest value in
+ each unit.
+ :raises: (PlotlyError) If the low value is not the lowest value in each
+ unit.
+ :raises: (PlotlyError) If direction is not 'increasing' or 'decreasing'
+ """
+ for lst in [open, low, close]:
+ for index in range(len(high)):
+ if high[index] < lst[index]:
+ raise exceptions.PlotlyError("Oops! Looks like some of "
+ "your high values are less "
+ "the corresponding open, "
+ "low, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order")
+
+ for lst in [open, high, close]:
+ for index in range(len(low)):
+ if low[index] > lst[index]:
+ raise exceptions.PlotlyError("Oops! Looks like some of "
+ "your low values are greater "
+ "than the corresponding high"
+ ", open, or close values. "
+ "Double check that your data "
+ "is entered in O-H-L-C order")
+
+ direction_opts = ('increasing', 'decreasing', 'both')
+ if direction not in direction_opts:
+ raise exceptions.PlotlyError("direction must be defined as "
+ "'increasing', 'decreasing', or "
+ "'both'")
+
+
+def make_increasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes increasing ohlc sticks
+
+ _make_increasing_ohlc() and _make_decreasing_ohlc separate the
+ increasing trace from the decreasing trace so kwargs (such as
+ color) can be passed separately to increasing or decreasing traces
+ when direction is set to 'increasing' or 'decreasing' in
+ FigureFactory.create_candlestick()
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_incr_data: Scatter trace of all increasing ohlc
+ sticks.
+ """
+ (flat_increase_x,
+ flat_increase_y,
+ text_increase) = _OHLC(open, high, low, close, dates).get_increase()
+
+ if 'name' in kwargs:
+ showlegend = True
+ else:
+ kwargs.setdefault('name', 'Increasing')
+ showlegend = False
+
+ kwargs.setdefault('line', dict(color=_DEFAULT_INCREASING_COLOR,
+ width=1))
+ kwargs.setdefault('text', text_increase)
+
+ ohlc_incr = dict(type='scatter',
+ x=flat_increase_x,
+ y=flat_increase_y,
+ mode='lines',
+ showlegend=showlegend,
+ **kwargs)
+ return ohlc_incr
+
+
+def make_decreasing_ohlc(open, high, low, close, dates, **kwargs):
+ """
+ Makes decreasing ohlc sticks
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing values
+ :param (list) dates: list of datetime objects. Default: None
+ :param kwargs: kwargs to be passed to increasing trace via
+ plotly.graph_objs.Scatter.
+
+ :rtype (trace) ohlc_decr_data: Scatter trace of all decreasing ohlc
+ sticks.
+ """
+ (flat_decrease_x,
+ flat_decrease_y,
+ text_decrease) = _OHLC(open, high, low, close, dates).get_decrease()
+
+ kwargs.setdefault('line', dict(color=_DEFAULT_DECREASING_COLOR,
+ width=1))
+ kwargs.setdefault('text', text_decrease)
+ kwargs.setdefault('showlegend', False)
+ kwargs.setdefault('name', 'Decreasing')
+
+ ohlc_decr = dict(type='scatter',
+ x=flat_decrease_x,
+ y=flat_decrease_y,
+ mode='lines',
+ **kwargs)
+ return ohlc_decr
+
+
+def create_ohlc(open, high, low, close, dates=None, direction='both',
+ **kwargs):
+ """
+ BETA function that creates an ohlc chart
+
+ :param (list) open: opening values
+ :param (list) high: high values
+ :param (list) low: low values
+ :param (list) close: closing
+ :param (list) dates: list of datetime objects. Default: None
+ :param (string) direction: direction can be 'increasing', 'decreasing',
+ or 'both'. When the direction is 'increasing', the returned figure
+ consists of all units where the close value is greater than the
+ corresponding open value, and when the direction is 'decreasing',
+ the returned figure consists of all units where the close value is
+ less than or equal to the corresponding open value. When the
+ direction is 'both', both increasing and decreasing units are
+ returned. Default: 'both'
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter.
+ These kwargs describe other attributes about the ohlc Scatter trace
+ such as the color or the legend name. For more information on valid
+ kwargs call help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of an ohlc chart figure.
+
+ Example 1: Simple OHLC chart from a Pandas DataFrame
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
+ datetime(2008, 10, 15))
+ fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
+
+ py.plot(fig, filename='finance/aapl-ohlc')
+ ```
+
+ Example 2: Add text and annotations to the OHLC chart
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 8, 15),
+ datetime(2008, 10, 15))
+ fig = create_ohlc(df.Open, df.High, df.Low, df.Close, dates=df.index)
+
+ # Update the fig - options here: https://plot.ly/python/reference/#Layout
+ fig['layout'].update({
+ 'title': 'The Great Recession',
+ 'yaxis': {'title': 'AAPL Stock'},
+ 'shapes': [{
+ 'x0': '2008-09-15', 'x1': '2008-09-15', 'type': 'line',
+ 'y0': 0, 'y1': 1, 'xref': 'x', 'yref': 'paper',
+ 'line': {'color': 'rgb(40,40,40)', 'width': 0.5}
+ }],
+ 'annotations': [{
+ 'text': "the fall of Lehman Brothers",
+ 'x': '2008-09-15', 'y': 1.02,
+ 'xref': 'x', 'yref': 'paper',
+ 'showarrow': False, 'xanchor': 'left'
+ }]
+ })
+
+ py.plot(fig, filename='finance/aapl-recession-ohlc', validate=False)
+ ```
+
+ Example 3: Customize the OHLC colors
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+ from plotly.graph_objs import Line, Marker
+ from datetime import datetime
+
+ import pandas.io.data as web
+
+ df = web.DataReader("aapl", 'yahoo', datetime(2008, 1, 1),
+ datetime(2009, 4, 1))
+
+ # Make increasing ohlc sticks and customize their color and name
+ fig_increasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
+ dates=df.index, direction='increasing',
+ name='AAPL',
+ line=Line(color='rgb(150, 200, 250)'))
+
+ # Make decreasing ohlc sticks and customize their color and name
+ fig_decreasing = create_ohlc(df.Open, df.High, df.Low, df.Close,
+ dates=df.index, direction='decreasing',
+ line=Line(color='rgb(128, 128, 128)'))
+
+ # Initialize the figure
+ fig = fig_increasing
+
+ # Add decreasing data with .extend()
+ fig['data'].extend(fig_decreasing['data'])
+
+ py.iplot(fig, filename='finance/aapl-ohlc-colors', validate=False)
+ ```
+
+ Example 4: OHLC chart with datetime objects
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_ohlc
+
+ from datetime import datetime
+
+ # Add data
+ open_data = [33.0, 33.3, 33.5, 33.0, 34.1]
+ high_data = [33.1, 33.3, 33.6, 33.2, 34.8]
+ low_data = [32.7, 32.7, 32.8, 32.6, 32.8]
+ close_data = [33.0, 32.9, 33.3, 33.1, 33.1]
+ dates = [datetime(year=2013, month=10, day=10),
+ datetime(year=2013, month=11, day=10),
+ datetime(year=2013, month=12, day=10),
+ datetime(year=2014, month=1, day=10),
+ datetime(year=2014, month=2, day=10)]
+
+ # Create ohlc
+ fig = create_ohlc(open_data, high_data, low_data, close_data, dates=dates)
+
+ py.iplot(fig, filename='finance/simple-ohlc', validate=False)
+ ```
+ """
+ if dates is not None:
+ utils.validate_equal_length(open, high, low, close, dates)
+ else:
+ utils.validate_equal_length(open, high, low, close)
+ validate_ohlc(open, high, low, close, direction, **kwargs)
+
+ if direction is 'increasing':
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_incr]
+ elif direction is 'decreasing':
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_decr]
+ else:
+ ohlc_incr = make_increasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ ohlc_decr = make_decreasing_ohlc(open, high, low, close, dates,
+ **kwargs)
+ data = [ohlc_incr, ohlc_decr]
+
+ layout = graph_objs.Layout(xaxis=dict(zeroline=False),
+ hovermode='closest')
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _OHLC(object):
+ """
+ Refer to FigureFactory.create_ohlc_increase() for docstring.
+ """
+ def __init__(self, open, high, low, close, dates, **kwargs):
+ self.open = open
+ self.high = high
+ self.low = low
+ self.close = close
+ self.empty = [None] * len(open)
+ self.dates = dates
+
+ self.all_x = []
+ self.all_y = []
+ self.increase_x = []
+ self.increase_y = []
+ self.decrease_x = []
+ self.decrease_y = []
+ self.get_all_xy()
+ self.separate_increase_decrease()
+
+ def get_all_xy(self):
+ """
+ Zip data to create OHLC shape
+
+ OHLC shape: low to high vertical bar with
+ horizontal branches for open and close values.
+ If dates were added, the smallest date difference is calculated and
+ multiplied by .2 to get the length of the open and close branches.
+ If no date data was provided, the x-axis is a list of integers and the
+ length of the open and close branches is .2.
+ """
+ self.all_y = list(zip(self.open, self.open, self.high,
+ self.low, self.close, self.close, self.empty))
+ if self.dates is not None:
+ date_dif = []
+ for i in range(len(self.dates) - 1):
+ date_dif.append(self.dates[i + 1] - self.dates[i])
+ date_dif_min = (min(date_dif)) / 5
+ self.all_x = [[x - date_dif_min, x, x, x, x, x +
+ date_dif_min, None] for x in self.dates]
+ else:
+ self.all_x = [[x - .2, x, x, x, x, x + .2, None]
+ for x in range(len(self.open))]
+
+ def separate_increase_decrease(self):
+ """
+ Separate data into two groups: increase and decrease
+
+ (1) Increase, where close > open and
+ (2) Decrease, where close <= open
+ """
+ for index in range(len(self.open)):
+ if self.close[index] is None:
+ pass
+ elif self.close[index] > self.open[index]:
+ self.increase_x.append(self.all_x[index])
+ self.increase_y.append(self.all_y[index])
+ else:
+ self.decrease_x.append(self.all_x[index])
+ self.decrease_y.append(self.all_y[index])
+
+ def get_increase(self):
+ """
+ Flatten increase data and get increase text
+
+ :rtype (list, list, list): flat_increase_x: x-values for the increasing
+ trace, flat_increase_y: y=values for the increasing trace and
+ text_increase: hovertext for the increasing trace
+ """
+ flat_increase_x = utils.flatten(self.increase_x)
+ flat_increase_y = utils.flatten(self.increase_y)
+ text_increase = (("Open", "Open", "High",
+ "Low", "Close", "Close", '')
+ * (len(self.increase_x)))
+
+ return flat_increase_x, flat_increase_y, text_increase
+
+ def get_decrease(self):
+ """
+ Flatten decrease data and get decrease text
+
+ :rtype (list, list, list): flat_decrease_x: x-values for the decreasing
+ trace, flat_decrease_y: y=values for the decreasing trace and
+ text_decrease: hovertext for the decreasing trace
+ """
+ flat_decrease_x = utils.flatten(self.decrease_x)
+ flat_decrease_y = utils.flatten(self.decrease_y)
+ text_decrease = (("Open", "Open", "High",
+ "Low", "Close", "Close", '')
+ * (len(self.decrease_x)))
+
+ return flat_decrease_x, flat_decrease_y, text_decrease
diff --git a/plotly/figure_factory/figure_factory/_quiver.py b/plotly/figure_factory/figure_factory/_quiver.py
new file mode 100644
index 00000000000..ef9d00e80dd
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_quiver.py
@@ -0,0 +1,283 @@
+from __future__ import absolute_import
+
+import math
+
+from plotly import exceptions
+from plotly.graph_objs import graph_objs
+from plotly.figure_factory import utils
+
+
+def create_quiver(x, y, u, v, scale=.1, arrow_scale=.3,
+ angle=math.pi / 9, scaleratio=None, **kwargs):
+ """
+ Returns data for a quiver plot.
+
+ :param (list|ndarray) x: x coordinates of the arrow locations
+ :param (list|ndarray) y: y coordinates of the arrow locations
+ :param (list|ndarray) u: x components of the arrow vectors
+ :param (list|ndarray) v: y components of the arrow vectors
+ :param (float in [0,1]) scale: scales size of the arrows(ideally to
+ avoid overlap). Default = .1
+ :param (float in [0,1]) arrow_scale: value multiplied to length of barb
+ to get length of arrowhead. Default = .3
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (positive float) scaleratio: the ratio between the scale of the y-axis
+ and the scale of the x-axis (scale_y / scale_x). Default = None, the
+ scale ratio is not fixed.
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of quiver figure.
+
+ Example 1: Trivial Quiver
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+
+ import math
+
+ # 1 Arrow from (0,0) to (1,1)
+ fig = create_quiver(x=[0], y=[0], u=[1], v=[1], scale=1)
+
+ py.plot(fig, filename='quiver')
+ ```
+
+ Example 2: Quiver plot using meshgrid
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+
+ import numpy as np
+ import math
+
+ # Add data
+ x,y = np.meshgrid(np.arange(0, 2, .2), np.arange(0, 2, .2))
+ u = np.cos(x)*y
+ v = np.sin(x)*y
+
+ #Create quiver
+ fig = create_quiver(x, y, u, v)
+
+ # Plot
+ py.plot(fig, filename='quiver')
+ ```
+
+ Example 3: Styling the quiver plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+ import numpy as np
+ import math
+
+ # Add data
+ x, y = np.meshgrid(np.arange(-np.pi, math.pi, .5),
+ np.arange(-math.pi, math.pi, .5))
+ u = np.cos(x)*y
+ v = np.sin(x)*y
+
+ # Create quiver
+ fig = create_quiver(x, y, u, v, scale=.2, arrow_scale=.3, angle=math.pi/6,
+ name='Wind Velocity', line=dict(width=1))
+
+ # Add title to layout
+ fig['layout'].update(title='Quiver Plot')
+
+ # Plot
+ py.plot(fig, filename='quiver')
+ ```
+
+ Example 4: Forcing a fix scale ratio to maintain the arrow length
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_quiver
+
+ import numpy as np
+
+ # Add data
+ x,y = np.meshgrid(np.arange(0.5, 3.5, .5), np.arange(0.5, 4.5, .5))
+ u = x
+ v = y
+ angle = np.arctan(v / u)
+ norm = 0.25
+ u = norm * np.cos(angle)
+ v = norm * np.sin(angle)
+
+ # Create quiver with a fix scale ratio
+ fig = create_quiver(x, y, u, v, scale = 1, scaleratio = 0.5)
+
+ # Plot
+ py.plot(fig, filename='quiver')
+ ```
+ """
+ utils.validate_equal_length(x, y, u, v)
+ utils.validate_positive_scalars(arrow_scale=arrow_scale, scale=scale)
+
+ if scaleratio is None:
+ quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle)
+ else:
+ quiver_obj = _Quiver(x, y, u, v, scale, arrow_scale, angle, scaleratio)
+
+ barb_x, barb_y = quiver_obj.get_barbs()
+ arrow_x, arrow_y = quiver_obj.get_quiver_arrows()
+
+ quiver_plot = graph_objs.Scatter(x=barb_x + arrow_x,
+ y=barb_y + arrow_y,
+ mode='lines', **kwargs)
+
+ data = [quiver_plot]
+
+ if scaleratio is None:
+ layout = graph_objs.Layout(hovermode='closest')
+ else:
+ layout = graph_objs.Layout(
+ hovermode='closest',
+ yaxis=dict(
+ scaleratio = scaleratio,
+ scaleanchor = "x"
+ )
+ )
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+class _Quiver(object):
+ """
+ Refer to FigureFactory.create_quiver() for docstring
+ """
+ def __init__(self, x, y, u, v,
+ scale, arrow_scale, angle, scaleratio=1, **kwargs):
+ try:
+ x = utils.flatten(x)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ y = utils.flatten(y)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ u = utils.flatten(u)
+ except exceptions.PlotlyError:
+ pass
+
+ try:
+ v = utils.flatten(v)
+ except exceptions.PlotlyError:
+ pass
+
+ self.x = x
+ self.y = y
+ self.u = u
+ self.v = v
+ self.scale = scale
+ self.scaleratio = scaleratio
+ self.arrow_scale = arrow_scale
+ self.angle = angle
+ self.end_x = []
+ self.end_y = []
+ self.scale_uv()
+ barb_x, barb_y = self.get_barbs()
+ arrow_x, arrow_y = self.get_quiver_arrows()
+
+ def scale_uv(self):
+ """
+ Scales u and v to avoid overlap of the arrows.
+
+ u and v are added to x and y to get the
+ endpoints of the arrows so a smaller scale value will
+ result in less overlap of arrows.
+ """
+ self.u = [i * self.scale * self.scaleratio for i in self.u]
+ self.v = [i * self.scale for i in self.v]
+
+ def get_barbs(self):
+ """
+ Creates x and y startpoint and endpoint pairs
+
+ After finding the endpoint of each barb this zips startpoint and
+ endpoint pairs to create 2 lists: x_values for barbs and y values
+ for barbs
+
+ :rtype: (list, list) barb_x, barb_y: list of startpoint and endpoint
+ x_value pairs separated by a None to create the barb of the arrow,
+ and list of startpoint and endpoint y_value pairs separated by a
+ None to create the barb of the arrow.
+ """
+ self.end_x = [i + j for i, j in zip(self.x, self.u)]
+ self.end_y = [i + j for i, j in zip(self.y, self.v)]
+ empty = [None] * len(self.x)
+ barb_x = utils.flatten(zip(self.x, self.end_x, empty))
+ barb_y = utils.flatten(zip(self.y, self.end_y, empty))
+ return barb_x, barb_y
+
+ def get_quiver_arrows(self):
+ """
+ Creates lists of x and y values to plot the arrows
+
+ Gets length of each barb then calculates the length of each side of
+ the arrow. Gets angle of barb and applies angle to each side of the
+ arrowhead. Next uses arrow_scale to scale the length of arrowhead and
+ creates x and y values for arrowhead point1 and point2. Finally x and y
+ values for point1, endpoint and point2s for each arrowhead are
+ separated by a None and zipped to create lists of x and y values for
+ the arrows.
+
+ :rtype: (list, list) arrow_x, arrow_y: list of point1, endpoint, point2
+ x_values separated by a None to create the arrowhead and list of
+ point1, endpoint, point2 y_values separated by a None to create
+ the barb of the arrow.
+ """
+ dif_x = [i - j for i, j in zip(self.end_x, self.x)]
+ dif_y = [i - j for i, j in zip(self.end_y, self.y)]
+
+ # Get barb lengths(default arrow length = 30% barb length)
+ barb_len = [None] * len(self.x)
+ for index in range(len(barb_len)):
+ barb_len[index] = math.hypot(dif_x[index] / self.scaleratio, dif_y[index])
+
+ # Make arrow lengths
+ arrow_len = [None] * len(self.x)
+ arrow_len = [i * self.arrow_scale for i in barb_len]
+
+ # Get barb angles
+ barb_ang = [None] * len(self.x)
+ for index in range(len(barb_ang)):
+ barb_ang[index] = math.atan2(dif_y[index], dif_x[index] / self.scaleratio)
+
+ # Set angles to create arrow
+ ang1 = [i + self.angle for i in barb_ang]
+ ang2 = [i - self.angle for i in barb_ang]
+
+ cos_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ cos_ang1[index] = math.cos(ang1[index])
+ seg1_x = [i * j for i, j in zip(arrow_len, cos_ang1)]
+
+ sin_ang1 = [None] * len(ang1)
+ for index in range(len(ang1)):
+ sin_ang1[index] = math.sin(ang1[index])
+ seg1_y = [i * j for i, j in zip(arrow_len, sin_ang1)]
+
+ cos_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ cos_ang2[index] = math.cos(ang2[index])
+ seg2_x = [i * j for i, j in zip(arrow_len, cos_ang2)]
+
+ sin_ang2 = [None] * len(ang2)
+ for index in range(len(ang2)):
+ sin_ang2[index] = math.sin(ang2[index])
+ seg2_y = [i * j for i, j in zip(arrow_len, sin_ang2)]
+
+ # Set coordinates to create arrow
+ for index in range(len(self.end_x)):
+ point1_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg1_x)]
+ point1_y = [i - j for i, j in zip(self.end_y, seg1_y)]
+ point2_x = [i - j * self.scaleratio for i, j in zip(self.end_x, seg2_x)]
+ point2_y = [i - j for i, j in zip(self.end_y, seg2_y)]
+
+ # Combine lists to create arrow
+ empty = [None] * len(self.end_x)
+ arrow_x = utils.flatten(zip(point1_x, self.end_x, point2_x, empty))
+ arrow_y = utils.flatten(zip(point1_y, self.end_y, point2_y, empty))
+ return arrow_x, arrow_y
diff --git a/plotly/figure_factory/figure_factory/_scatterplot.py b/plotly/figure_factory/figure_factory/_scatterplot.py
new file mode 100644
index 00000000000..25ae4f0976f
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_scatterplot.py
@@ -0,0 +1,1104 @@
+from __future__ import absolute_import
+
+import six
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.tools import make_subplots
+
+pd = optional_imports.get_module('pandas')
+
+DIAG_CHOICES = ['scatter', 'histogram', 'box']
+VALID_COLORMAP_TYPES = ['cat', 'seq']
+
+
+def hide_tick_labels_from_box_subplots(fig):
+ """
+ Hides tick labels for box plots in scatterplotmatrix subplots.
+ """
+ boxplot_xaxes = []
+ for trace in fig['data']:
+ if trace['type'] == 'box':
+ # stores the xaxes which correspond to boxplot subplots
+ # since we use xaxis1, xaxis2, etc, in plotly.py
+ boxplot_xaxes.append(
+ 'xaxis{}'.format(trace['xaxis'][1:])
+ )
+ for xaxis in boxplot_xaxes:
+ fig['layout'][xaxis]['showticklabels'] = False
+
+
+def validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs):
+ """
+ Validates basic inputs for FigureFactory.create_scatterplotmatrix()
+
+ :raises: (PlotlyError) If pandas is not imported
+ :raises: (PlotlyError) If pandas dataframe is not inputted
+ :raises: (PlotlyError) If pandas dataframe has <= 1 columns
+ :raises: (PlotlyError) If diagonal plot choice (diag) is not one of
+ the viable options
+ :raises: (PlotlyError) If colormap_type is not a valid choice
+ :raises: (PlotlyError) If kwargs contains 'size', 'color' or
+ 'colorscale'
+ """
+ if not pd:
+ raise ImportError("FigureFactory.scatterplotmatrix requires "
+ "a pandas DataFrame.")
+
+ # Check if pandas dataframe
+ if not isinstance(df, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError("Dataframe not inputed. Please "
+ "use a pandas dataframe to pro"
+ "duce a scatterplot matrix.")
+
+ # Check if dataframe is 1 column or less
+ if len(df.columns) <= 1:
+ raise exceptions.PlotlyError("Dataframe has only one column. To "
+ "use the scatterplot matrix, use at "
+ "least 2 columns.")
+
+ # Check that diag parameter is a valid selection
+ if diag not in DIAG_CHOICES:
+ raise exceptions.PlotlyError("Make sure diag is set to "
+ "one of {}".format(DIAG_CHOICES))
+
+ # Check that colormap_types is a valid selection
+ if colormap_type not in VALID_COLORMAP_TYPES:
+ raise exceptions.PlotlyError("Must choose a valid colormap type. "
+ "Either 'cat' or 'seq' for a cate"
+ "gorical and sequential colormap "
+ "respectively.")
+
+ # Check for not 'size' or 'color' in 'marker' of **kwargs
+ if 'marker' in kwargs:
+ FORBIDDEN_PARAMS = ['size', 'color', 'colorscale']
+ if any(param in kwargs['marker'] for param in FORBIDDEN_PARAMS):
+ raise exceptions.PlotlyError("Your kwargs dictionary cannot "
+ "include the 'size', 'color' or "
+ "'colorscale' key words inside "
+ "the marker dict since 'size' is "
+ "already an argument of the "
+ "scatterplot matrix function and "
+ "both 'color' and 'colorscale "
+ "are set internally.")
+
+
+def scatterplot(dataframe, headers, diag, size, height, width, title,
+ **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix without index
+
+ """
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ # Insert traces into trace_list
+ for listy in dataframe:
+ for listx in dataframe:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ name=None,
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ trace_list.append(trace)
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size),
+ showlegend=False,
+ **kwargs
+ )
+ trace_list.append(trace)
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True
+ )
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ return fig
+
+
+def scatterplot_dict(dataframe, headers, diag, size,
+ height, width, title, index, index_vals,
+ endpts, colormap, colormap_type, **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked.
+ Used if colormap is a dictionary with index values as keys pointing to
+ colors. Forces colormap_type to behave categorically because it would
+ not make sense colors are assigned to each index value and thus
+ implies that a categorical approach should be taken
+
+ """
+
+ theme = colormap
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[name]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[name]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[name]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[name]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[name]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[name]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[name]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(
+ trace_list[trace_index][name],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+
+def scatterplot_theme(dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs):
+ """
+ Refer to FigureFactory.create_scatterplotmatrix() for docstring
+
+ Returns fig for scatterplotmatrix with both index and colormap picked
+
+ """
+
+ # Check if index is made of string values
+ if isinstance(index_vals[0], str):
+ unique_index_vals = []
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals.append(name)
+ n_colors_len = len(unique_index_vals)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == 'seq':
+ foo = colors.color_parser(colormap, colors.unlabel_rgb)
+ foo = utils.n_colors(foo[0], foo[1], n_colors_len)
+ theme = colors.color_parser(foo, colors.label_rgb)
+
+ if colormap_type == 'cat':
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # create a dictionary for index_vals
+ unique_index_vals = {}
+ for name in index_vals:
+ if name not in unique_index_vals:
+ unique_index_vals[name] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for name in sorted(unique_index_vals.keys()):
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if index_vals[j] == name:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=name,
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ unique_index_vals[name] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(unique_index_vals)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for name in sorted(trace_list[trace_index].keys()):
+ fig.append_trace(
+ trace_list[trace_index][name],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ if endpts:
+ intervals = utils.endpts_to_intervals(endpts)
+
+ # Convert colormap to list of n RGB tuples
+ if colormap_type == 'seq':
+ foo = colors.color_parser(colormap, colors.unlabel_rgb)
+ foo = utils.n_colors(foo[0], foo[1], len(intervals))
+ theme = colors.color_parser(foo, colors.label_rgb)
+
+ if colormap_type == 'cat':
+ # leave list of colors the same way
+ theme = colormap
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Work over all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ interval_labels = {}
+ for interval in intervals:
+ interval_labels[str(interval)] = []
+
+ c_indx = 0 # color index
+ # Fill all the rest of the names into the dictionary
+ for interval in intervals:
+ new_listx = []
+ new_listy = []
+ for j in range(len(index_vals)):
+ if interval[0] < index_vals[j] <= interval[1]:
+ new_listx.append(listx[j])
+ new_listy.append(listy[j])
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=True
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ (kwargs['marker']
+ ['color']) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ showlegend=True,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=True,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=new_listx,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=new_listx,
+ name=None,
+ marker=dict(
+ color=theme[c_indx]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ (kwargs['marker']
+ ['color']) = theme[c_indx]
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=new_listx,
+ y=new_listy,
+ mode='markers',
+ name=str(interval),
+ marker=dict(
+ size=size,
+ color=theme[c_indx]),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into dictionary
+ interval_labels[str(interval)] = trace
+ if c_indx >= (len(theme) - 1):
+ c_indx = -1
+ c_indx += 1
+ trace_list.append(interval_labels)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ for interval in intervals:
+ fig.append_trace(
+ trace_list[trace_index][str(interval)],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ theme = colormap
+
+ # add a copy of rgb color to theme if it contains one color
+ if len(theme) <= 1:
+ theme.append(theme[0])
+
+ color = []
+ for incr in range(len(theme)):
+ color.append([1. / (len(theme) - 1) * incr, theme[incr]])
+
+ dim = len(dataframe)
+ fig = make_subplots(rows=dim, cols=dim, print_grid=False)
+ trace_list = []
+ legend_param = 0
+ # Run through all permutations of list pairs
+ for listy in dataframe:
+ for listx in dataframe:
+ # Generate trace with VISIBLE icon
+ if legend_param == 1:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = index_vals
+ kwargs['marker']['colorscale'] = color
+ kwargs['marker']['showscale'] = True
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=True),
+ showlegend=False,
+ **kwargs
+ )
+ # Generate trace with INVISIBLE icon
+ else:
+ if (listx == listy) and (diag == 'histogram'):
+ trace = graph_objs.Histogram(
+ x=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ elif (listx == listy) and (diag == 'box'):
+ trace = graph_objs.Box(
+ y=listx,
+ marker=dict(
+ color=theme[0]),
+ showlegend=False
+ )
+ else:
+ if 'marker' in kwargs:
+ kwargs['marker']['size'] = size
+ kwargs['marker']['color'] = index_vals
+ kwargs['marker']['colorscale'] = color
+ kwargs['marker']['showscale'] = False
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ showlegend=False,
+ **kwargs
+ )
+ else:
+ trace = graph_objs.Scatter(
+ x=listx,
+ y=listy,
+ mode='markers',
+ marker=dict(
+ size=size,
+ color=index_vals,
+ colorscale=color,
+ showscale=False),
+ showlegend=False,
+ **kwargs
+ )
+ # Push the trace into list
+ trace_list.append(trace)
+ legend_param += 1
+
+ trace_index = 0
+ indices = range(1, dim + 1)
+ for y_index in indices:
+ for x_index in indices:
+ fig.append_trace(trace_list[trace_index],
+ y_index,
+ x_index)
+ trace_index += 1
+
+ # Insert headers into the figure
+ for j in range(dim):
+ xaxis_key = 'xaxis{}'.format((dim * dim) - dim + 1 + j)
+ fig['layout'][xaxis_key].update(title=headers[j])
+ for j in range(dim):
+ yaxis_key = 'yaxis{}'.format(1 + (dim * j))
+ fig['layout'][yaxis_key].update(title=headers[j])
+
+ hide_tick_labels_from_box_subplots(fig)
+
+ if diag == 'histogram':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True,
+ barmode='stack')
+ return fig
+
+ elif diag == 'box':
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+ else:
+ fig['layout'].update(
+ height=height, width=width,
+ title=title,
+ showlegend=True)
+ return fig
+
+
+def create_scatterplotmatrix(df, index=None, endpts=None, diag='scatter',
+ height=500, width=500, size=6,
+ title='Scatterplot Matrix', colormap=None,
+ colormap_type='cat', dataframe=None,
+ headers=None, index_vals=None, **kwargs):
+ """
+ Returns data for a scatterplot matrix.
+
+ :param (array) df: array of the data with column headers
+ :param (str) index: name of the index column in data array
+ :param (list|tuple) endpts: takes an increasing sequece of numbers
+ that defines intervals on the real line. They are used to group
+ the entries in an index of numbers into their corresponding
+ interval and therefore can be treated as categorical data
+ :param (str) diag: sets the chart type for the main diagonal plots.
+ The options are 'scatter', 'histogram' and 'box'.
+ :param (int|float) height: sets the height of the chart
+ :param (int|float) width: sets the width of the chart
+ :param (float) size: sets the marker size (in px)
+ :param (str) title: the title label of the scatterplot matrix
+ :param (str|tuple|list|dict) colormap: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colormap is a list, it must contain valid color types as its
+ members.
+ If colormap is a dictionary, all the string entries in
+ the index column must be a key in colormap. In this case, the
+ colormap_type is forced to 'cat' or categorical
+ :param (str) colormap_type: determines how colormap is interpreted.
+ Valid choices are 'seq' (sequential) and 'cat' (categorical). If
+ 'seq' is selected, only the first two colors in colormap will be
+ considered (when colormap is a list) and the index values will be
+ linearly interpolated between those two colors. This option is
+ forced if all index values are numeric.
+ If 'cat' is selected, a color from colormap will be assigned to
+ each category from index, including the intervals if endpts is
+ being used
+ :param (dict) **kwargs: a dictionary of scatterplot arguments
+ The only forbidden parameters are 'size', 'color' and
+ 'colorscale' in 'marker'
+
+ Example 1: Vanilla Scatterplot Matrix
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe
+ df = pd.DataFrame(np.random.randn(10, 2),
+ columns=['Column 1', 'Column 2'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df)
+
+ # Plot
+ py.iplot(fig, filename='Vanilla Scatterplot Matrix')
+ ```
+
+ Example 2: Indexing a Column
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with index
+ df = pd.DataFrame(np.random.randn(10, 2),
+ columns=['A', 'B'])
+
+ # Add another column of strings to the dataframe
+ df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df, index='Fruit', size=10)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix with Index')
+ ```
+
+ Example 3: Styling the Diagonal Subplots
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with index
+ df = pd.DataFrame(np.random.randn(10, 4),
+ columns=['A', 'B', 'C', 'D'])
+
+ # Add another column of strings to the dataframe
+ df['Fruit'] = pd.Series(['apple', 'apple', 'grape', 'apple', 'apple',
+ 'grape', 'pear', 'pear', 'apple', 'pear'])
+
+ # Create scatterplot matrix
+ fig = create_scatterplotmatrix(df, diag='box', index='Fruit', height=1000,
+ width=1000)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Diagonal Styling')
+ ```
+
+ Example 4: Use a Theme to Style the Subplots
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['A', 'B', 'C'])
+
+ # Create scatterplot matrix using a built-in
+ # Plotly palette scale and indexing column 'A'
+ fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ colormap='Blues', height=800, width=800)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Colormap Theme')
+ ```
+
+ Example 5: Example 4 with Interval Factoring
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['A', 'B', 'C'])
+
+ # Create scatterplot matrix using a list of 2 rgb tuples
+ # and endpoints at -1, 0 and 1
+ fig = create_scatterplotmatrix(df, diag='histogram', index='A',
+ colormap=['rgb(140, 255, 50)',
+ 'rgb(170, 60, 115)', '#6c4774',
+ (0.5, 0.1, 0.8)],
+ endpts=[-1, 0, 1], height=800, width=800)
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - Intervals')
+ ```
+
+ Example 6: Using the colormap as a Dictionary
+ ```
+ import plotly.plotly as py
+ from plotly.graph_objs import graph_objs
+ from plotly.figure_factory import create_scatterplotmatrix
+
+ import numpy as np
+ import pandas as pd
+ import random
+
+ # Create dataframe with random data
+ df = pd.DataFrame(np.random.randn(100, 3),
+ columns=['Column A',
+ 'Column B',
+ 'Column C'])
+
+ # Add new color column to dataframe
+ new_column = []
+ strange_colors = ['turquoise', 'limegreen', 'goldenrod']
+
+ for j in range(100):
+ new_column.append(random.choice(strange_colors))
+ df['Colors'] = pd.Series(new_column, index=df.index)
+
+ # Create scatterplot matrix using a dictionary of hex color values
+ # which correspond to actual color names in 'Colors' column
+ fig = create_scatterplotmatrix(
+ df, diag='box', index='Colors',
+ colormap= dict(
+ turquoise = '#00F5FF',
+ limegreen = '#32CD32',
+ goldenrod = '#DAA520'
+ ),
+ colormap_type='cat',
+ height=800, width=800
+ )
+
+ # Plot
+ py.iplot(fig, filename = 'Scatterplot Matrix - colormap dictionary ')
+ ```
+ """
+ # TODO: protected until #282
+ if dataframe is None:
+ dataframe = []
+ if headers is None:
+ headers = []
+ if index_vals is None:
+ index_vals = []
+
+ validate_scatterplotmatrix(df, index, diag, colormap_type, **kwargs)
+
+ # Validate colormap
+ if isinstance(colormap, dict):
+ colormap = utils.validate_colors_dict(colormap, 'rgb')
+ elif isinstance(colormap, six.string_types) and 'rgb' not in colormap and '#' not in colormap:
+ if colormap not in utils.PLOTLY_SCALES.keys():
+ raise exceptions.PlotlyError(
+ "If 'colormap' is a string, it must be the name "
+ "of a Plotly Colorscale. The available colorscale "
+ "names are {}".format(utils.PLOTLY_SCALES.keys())
+ )
+ else:
+ # TODO change below to allow the correct Plotly colorscale
+ colormap = utils.colorscale_to_colors(utils.PLOTLY_SCALES[colormap])
+ # keep only first and last item - fix later
+ colormap = [colormap[0]] + [colormap[-1]]
+ colormap = utils.validate_colors(colormap, 'rgb')
+ else:
+ colormap = utils.validate_colors(colormap, 'rgb')
+
+ if not index:
+ for name in df:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+ # Check for same data-type in df columns
+ utils.validate_dataframe(dataframe)
+ figure = scatterplot(dataframe, headers, diag, size, height, width,
+ title, **kwargs)
+ return figure
+ else:
+ # Validate index selection
+ if index not in df:
+ raise exceptions.PlotlyError("Make sure you set the index "
+ "input variable to one of the "
+ "column names of your "
+ "dataframe.")
+ index_vals = df[index].values.tolist()
+ for name in df:
+ if name != index:
+ headers.append(name)
+ for name in headers:
+ dataframe.append(df[name].values.tolist())
+
+ # check for same data-type in each df column
+ utils.validate_dataframe(dataframe)
+ utils.validate_index(index_vals)
+
+ # check if all colormap keys are in the index
+ # if colormap is a dictionary
+ if isinstance(colormap, dict):
+ for key in colormap:
+ if not all(index in colormap for index in index_vals):
+ raise exceptions.PlotlyError("If colormap is a "
+ "dictionary, all the "
+ "names in the index "
+ "must be keys.")
+ figure = scatterplot_dict(
+ dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs
+ )
+ return figure
+
+ else:
+ figure = scatterplot_theme(
+ dataframe, headers, diag, size, height, width, title,
+ index, index_vals, endpts, colormap, colormap_type,
+ **kwargs
+ )
+ return figure
diff --git a/plotly/figure_factory/figure_factory/_streamline.py b/plotly/figure_factory/figure_factory/_streamline.py
new file mode 100644
index 00000000000..a6773420f48
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_streamline.py
@@ -0,0 +1,415 @@
+from __future__ import absolute_import
+
+import math
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module('numpy')
+
+
+def validate_streamline(x, y):
+ """
+ Streamline-specific validations
+
+ Specifically, this checks that x and y are both evenly spaced,
+ and that the package numpy is available.
+
+ See FigureFactory.create_streamline() for params
+
+ :raises: (ImportError) If numpy is not available.
+ :raises: (PlotlyError) If x is not evenly spaced.
+ :raises: (PlotlyError) If y is not evenly spaced.
+ """
+ if np is False:
+ raise ImportError("FigureFactory.create_streamline requires numpy")
+ for index in range(len(x) - 1):
+ if ((x[index + 1] - x[index]) - (x[1] - x[0])) > .0001:
+ raise exceptions.PlotlyError("x must be a 1 dimensional, "
+ "evenly spaced array")
+ for index in range(len(y) - 1):
+ if ((y[index + 1] - y[index]) -
+ (y[1] - y[0])) > .0001:
+ raise exceptions.PlotlyError("y must be a 1 dimensional, "
+ "evenly spaced array")
+
+
+def create_streamline(x, y, u, v, density=1, angle=math.pi / 9,
+ arrow_scale=.09, **kwargs):
+ """
+ Returns data for a streamline plot.
+
+ :param (list|ndarray) x: 1 dimensional, evenly spaced list or array
+ :param (list|ndarray) y: 1 dimensional, evenly spaced list or array
+ :param (ndarray) u: 2 dimensional array
+ :param (ndarray) v: 2 dimensional array
+ :param (float|int) density: controls the density of streamlines in
+ plot. This is multiplied by 30 to scale similiarly to other
+ available streamline functions such as matplotlib.
+ Default = 1
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :param kwargs: kwargs passed through plotly.graph_objs.Scatter
+ for more information on valid kwargs call
+ help(plotly.graph_objs.Scatter)
+
+ :rtype (dict): returns a representation of streamline figure.
+
+ Example 1: Plot simple streamline and increase arrow size
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_streamline
+
+ import numpy as np
+ import math
+
+ # Add data
+ x = np.linspace(-3, 3, 100)
+ y = np.linspace(-3, 3, 100)
+ Y, X = np.meshgrid(x, y)
+ u = -1 - X**2 + Y
+ v = 1 + X - Y**2
+ u = u.T # Transpose
+ v = v.T # Transpose
+
+ # Create streamline
+ fig = create_streamline(x, y, u, v, arrow_scale=.1)
+
+ # Plot
+ py.plot(fig, filename='streamline')
+ ```
+
+ Example 2: from nbviewer.ipython.org/github/barbagroup/AeroPython
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_streamline
+
+ import numpy as np
+ import math
+
+ # Add data
+ N = 50
+ x_start, x_end = -2.0, 2.0
+ y_start, y_end = -1.0, 1.0
+ x = np.linspace(x_start, x_end, N)
+ y = np.linspace(y_start, y_end, N)
+ X, Y = np.meshgrid(x, y)
+ ss = 5.0
+ x_s, y_s = -1.0, 0.0
+
+ # Compute the velocity field on the mesh grid
+ u_s = ss/(2*np.pi) * (X-x_s)/((X-x_s)**2 + (Y-y_s)**2)
+ v_s = ss/(2*np.pi) * (Y-y_s)/((X-x_s)**2 + (Y-y_s)**2)
+
+ # Create streamline
+ fig = create_streamline(x, y, u_s, v_s, density=2, name='streamline')
+
+ # Add source point
+ point = Scatter(x=[x_s], y=[y_s], mode='markers',
+ marker=Marker(size=14), name='source point')
+
+ # Plot
+ fig['data'].append(point)
+ py.plot(fig, filename='streamline')
+ ```
+ """
+ utils.validate_equal_length(x, y)
+ utils.validate_equal_length(u, v)
+ validate_streamline(x, y)
+ utils.validate_positive_scalars(density=density, arrow_scale=arrow_scale)
+
+ streamline_x, streamline_y = _Streamline(x, y, u, v,
+ density, angle,
+ arrow_scale).sum_streamlines()
+ arrow_x, arrow_y = _Streamline(x, y, u, v,
+ density, angle,
+ arrow_scale).get_streamline_arrows()
+
+ streamline = graph_objs.Scatter(x=streamline_x + arrow_x,
+ y=streamline_y + arrow_y,
+ mode='lines', **kwargs)
+
+ data = [streamline]
+ layout = graph_objs.Layout(hovermode='closest')
+
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Streamline(object):
+ """
+ Refer to FigureFactory.create_streamline() for docstring
+ """
+ def __init__(self, x, y, u, v,
+ density, angle,
+ arrow_scale, **kwargs):
+ self.x = np.array(x)
+ self.y = np.array(y)
+ self.u = np.array(u)
+ self.v = np.array(v)
+ self.angle = angle
+ self.arrow_scale = arrow_scale
+ self.density = int(30 * density) # Scale similarly to other functions
+ self.delta_x = self.x[1] - self.x[0]
+ self.delta_y = self.y[1] - self.y[0]
+ self.val_x = self.x
+ self.val_y = self.y
+
+ # Set up spacing
+ self.blank = np.zeros((self.density, self.density))
+ self.spacing_x = len(self.x) / float(self.density - 1)
+ self.spacing_y = len(self.y) / float(self.density - 1)
+ self.trajectories = []
+
+ # Rescale speed onto axes-coordinates
+ self.u = self.u / (self.x[-1] - self.x[0])
+ self.v = self.v / (self.y[-1] - self.y[0])
+ self.speed = np.sqrt(self.u ** 2 + self.v ** 2)
+
+ # Rescale u and v for integrations.
+ self.u *= len(self.x)
+ self.v *= len(self.y)
+ self.st_x = []
+ self.st_y = []
+ self.get_streamlines()
+ streamline_x, streamline_y = self.sum_streamlines()
+ arrows_x, arrows_y = self.get_streamline_arrows()
+
+ def blank_pos(self, xi, yi):
+ """
+ Set up positions for trajectories to be used with rk4 function.
+ """
+ return (int((xi / self.spacing_x) + 0.5),
+ int((yi / self.spacing_y) + 0.5))
+
+ def value_at(self, a, xi, yi):
+ """
+ Set up for RK4 function, based on Bokeh's streamline code
+ """
+ if isinstance(xi, np.ndarray):
+ self.x = xi.astype(np.int)
+ self.y = yi.astype(np.int)
+ else:
+ self.val_x = np.int(xi)
+ self.val_y = np.int(yi)
+ a00 = a[self.val_y, self.val_x]
+ a01 = a[self.val_y, self.val_x + 1]
+ a10 = a[self.val_y + 1, self.val_x]
+ a11 = a[self.val_y + 1, self.val_x + 1]
+ xt = xi - self.val_x
+ yt = yi - self.val_y
+ a0 = a00 * (1 - xt) + a01 * xt
+ a1 = a10 * (1 - xt) + a11 * xt
+ return a0 * (1 - yt) + a1 * yt
+
+ def rk4_integrate(self, x0, y0):
+ """
+ RK4 forward and back trajectories from the initial conditions.
+
+ Adapted from Bokeh's streamline -uses Runge-Kutta method to fill
+ x and y trajectories then checks length of traj (s in units of axes)
+ """
+ def f(xi, yi):
+ dt_ds = 1. / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return ui * dt_ds, vi * dt_ds
+
+ def g(xi, yi):
+ dt_ds = 1. / self.value_at(self.speed, xi, yi)
+ ui = self.value_at(self.u, xi, yi)
+ vi = self.value_at(self.v, xi, yi)
+ return -ui * dt_ds, -vi * dt_ds
+
+ check = lambda xi, yi: (0 <= xi < len(self.x) - 1 and
+ 0 <= yi < len(self.y) - 1)
+ xb_changes = []
+ yb_changes = []
+
+ def rk4(x0, y0, f):
+ ds = 0.01
+ stotal = 0
+ xi = x0
+ yi = y0
+ xb, yb = self.blank_pos(xi, yi)
+ xf_traj = []
+ yf_traj = []
+ while check(xi, yi):
+ xf_traj.append(xi)
+ yf_traj.append(yi)
+ try:
+ k1x, k1y = f(xi, yi)
+ k2x, k2y = f(xi + .5 * ds * k1x, yi + .5 * ds * k1y)
+ k3x, k3y = f(xi + .5 * ds * k2x, yi + .5 * ds * k2y)
+ k4x, k4y = f(xi + ds * k3x, yi + ds * k3y)
+ except IndexError:
+ break
+ xi += ds * (k1x + 2 * k2x + 2 * k3x + k4x) / 6.
+ yi += ds * (k1y + 2 * k2y + 2 * k3y + k4y) / 6.
+ if not check(xi, yi):
+ break
+ stotal += ds
+ new_xb, new_yb = self.blank_pos(xi, yi)
+ if new_xb != xb or new_yb != yb:
+ if self.blank[new_yb, new_xb] == 0:
+ self.blank[new_yb, new_xb] = 1
+ xb_changes.append(new_xb)
+ yb_changes.append(new_yb)
+ xb = new_xb
+ yb = new_yb
+ else:
+ break
+ if stotal > 2:
+ break
+ return stotal, xf_traj, yf_traj
+
+ sf, xf_traj, yf_traj = rk4(x0, y0, f)
+ sb, xb_traj, yb_traj = rk4(x0, y0, g)
+ stotal = sf + sb
+ x_traj = xb_traj[::-1] + xf_traj[1:]
+ y_traj = yb_traj[::-1] + yf_traj[1:]
+
+ if len(x_traj) < 1:
+ return None
+ if stotal > .2:
+ initxb, inityb = self.blank_pos(x0, y0)
+ self.blank[inityb, initxb] = 1
+ return x_traj, y_traj
+ else:
+ for xb, yb in zip(xb_changes, yb_changes):
+ self.blank[yb, xb] = 0
+ return None
+
+ def traj(self, xb, yb):
+ """
+ Integrate trajectories
+
+ :param (int) xb: results of passing xi through self.blank_pos
+ :param (int) xy: results of passing yi through self.blank_pos
+
+ Calculate each trajectory based on rk4 integrate method.
+ """
+
+ if xb < 0 or xb >= self.density or yb < 0 or yb >= self.density:
+ return
+ if self.blank[yb, xb] == 0:
+ t = self.rk4_integrate(xb * self.spacing_x, yb * self.spacing_y)
+ if t is not None:
+ self.trajectories.append(t)
+
+ def get_streamlines(self):
+ """
+ Get streamlines by building trajectory set.
+ """
+ for indent in range(self.density // 2):
+ for xi in range(self.density - 2 * indent):
+ self.traj(xi + indent, indent)
+ self.traj(xi + indent, self.density - 1 - indent)
+ self.traj(indent, xi + indent)
+ self.traj(self.density - 1 - indent, xi + indent)
+
+ self.st_x = [np.array(t[0]) * self.delta_x + self.x[0] for t in
+ self.trajectories]
+ self.st_y = [np.array(t[1]) * self.delta_y + self.y[0] for t in
+ self.trajectories]
+
+ for index in range(len(self.st_x)):
+ self.st_x[index] = self.st_x[index].tolist()
+ self.st_x[index].append(np.nan)
+
+ for index in range(len(self.st_y)):
+ self.st_y[index] = self.st_y[index].tolist()
+ self.st_y[index].append(np.nan)
+
+ def get_streamline_arrows(self):
+ """
+ Makes an arrow for each streamline.
+
+ Gets angle of streamline at 1/3 mark and creates arrow coordinates
+ based off of user defined angle and arrow_scale.
+
+ :param (array) st_x: x-values for all streamlines
+ :param (array) st_y: y-values for all streamlines
+ :param (angle in radians) angle: angle of arrowhead. Default = pi/9
+ :param (float in [0,1]) arrow_scale: value to scale length of arrowhead
+ Default = .09
+ :rtype (list, list) arrows_x: x-values to create arrowhead and
+ arrows_y: y-values to create arrowhead
+ """
+ arrow_end_x = np.empty((len(self.st_x)))
+ arrow_end_y = np.empty((len(self.st_y)))
+ arrow_start_x = np.empty((len(self.st_x)))
+ arrow_start_y = np.empty((len(self.st_y)))
+ for index in range(len(self.st_x)):
+ arrow_end_x[index] = (self.st_x[index]
+ [int(len(self.st_x[index]) / 3)])
+ arrow_start_x[index] = (self.st_x[index]
+ [(int(len(self.st_x[index]) / 3)) - 1])
+ arrow_end_y[index] = (self.st_y[index]
+ [int(len(self.st_y[index]) / 3)])
+ arrow_start_y[index] = (self.st_y[index]
+ [(int(len(self.st_y[index]) / 3)) - 1])
+
+ dif_x = arrow_end_x - arrow_start_x
+ dif_y = arrow_end_y - arrow_start_y
+
+ orig_err = np.geterr()
+ np.seterr(divide='ignore', invalid='ignore')
+ streamline_ang = np.arctan(dif_y / dif_x)
+ np.seterr(**orig_err)
+
+
+ ang1 = streamline_ang + (self.angle)
+ ang2 = streamline_ang - (self.angle)
+
+ seg1_x = np.cos(ang1) * self.arrow_scale
+ seg1_y = np.sin(ang1) * self.arrow_scale
+ seg2_x = np.cos(ang2) * self.arrow_scale
+ seg2_y = np.sin(ang2) * self.arrow_scale
+
+ point1_x = np.empty((len(dif_x)))
+ point1_y = np.empty((len(dif_y)))
+ point2_x = np.empty((len(dif_x)))
+ point2_y = np.empty((len(dif_y)))
+
+ for index in range(len(dif_x)):
+ if dif_x[index] >= 0:
+ point1_x[index] = arrow_end_x[index] - seg1_x[index]
+ point1_y[index] = arrow_end_y[index] - seg1_y[index]
+ point2_x[index] = arrow_end_x[index] - seg2_x[index]
+ point2_y[index] = arrow_end_y[index] - seg2_y[index]
+ else:
+ point1_x[index] = arrow_end_x[index] + seg1_x[index]
+ point1_y[index] = arrow_end_y[index] + seg1_y[index]
+ point2_x[index] = arrow_end_x[index] + seg2_x[index]
+ point2_y[index] = arrow_end_y[index] + seg2_y[index]
+
+ space = np.empty((len(point1_x)))
+ space[:] = np.nan
+
+ # Combine arrays into matrix
+ arrows_x = np.matrix([point1_x, arrow_end_x, point2_x, space])
+ arrows_x = np.array(arrows_x)
+ arrows_x = arrows_x.flatten('F')
+ arrows_x = arrows_x.tolist()
+
+ # Combine arrays into matrix
+ arrows_y = np.matrix([point1_y, arrow_end_y, point2_y, space])
+ arrows_y = np.array(arrows_y)
+ arrows_y = arrows_y.flatten('F')
+ arrows_y = arrows_y.tolist()
+
+ return arrows_x, arrows_y
+
+ def sum_streamlines(self):
+ """
+ Makes all streamlines readable as a single trace.
+
+ :rtype (list, list): streamline_x: all x values for each streamline
+ combined into single list and streamline_y: all y values for each
+ streamline combined into single list
+ """
+ streamline_x = sum(self.st_x, [])
+ streamline_y = sum(self.st_y, [])
+ return streamline_x, streamline_y
diff --git a/plotly/figure_factory/figure_factory/_table.py b/plotly/figure_factory/figure_factory/_table.py
new file mode 100644
index 00000000000..3bd48db905d
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_table.py
@@ -0,0 +1,233 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.graph_objs import graph_objs
+
+pd = optional_imports.get_module('pandas')
+
+
+def validate_table(table_text, font_colors):
+ """
+ Table-specific validations
+
+ Check that font_colors is supplied correctly (1, 3, or len(text)
+ colors).
+
+ :raises: (PlotlyError) If font_colors is supplied incorretly.
+
+ See FigureFactory.create_table() for params
+ """
+ font_colors_len_options = [1, 3, len(table_text)]
+ if len(font_colors) not in font_colors_len_options:
+ raise exceptions.PlotlyError("Oops, font_colors should be a list "
+ "of length 1, 3 or len(text)")
+
+
+def create_table(table_text, colorscale=None, font_colors=None,
+ index=False, index_title='', annotation_offset=.45,
+ height_constant=30, hoverinfo='none', **kwargs):
+ """
+ BETA function that creates data tables
+
+ :param (pandas.Dataframe | list[list]) text: data for table.
+ :param (str|list[list]) colorscale: Colorscale for table where the
+ color at value 0 is the header color, .5 is the first table color
+ and 1 is the second table color. (Set .5 and 1 to avoid the striped
+ table effect). Default=[[0, '#66b2ff'], [.5, '#d9d9d9'],
+ [1, '#ffffff']]
+ :param (list) font_colors: Color for fonts in table. Can be a single
+ color, three colors, or a color for each row in the table.
+ Default=['#000000'] (black text for the entire table)
+ :param (int) height_constant: Constant multiplied by # of rows to
+ create table height. Default=30.
+ :param (bool) index: Create (header-colored) index column index from
+ Pandas dataframe or list[0] for each list in text. Default=False.
+ :param (string) index_title: Title for index column. Default=''.
+ :param kwargs: kwargs passed through plotly.graph_objs.Heatmap.
+ These kwargs describe other attributes about the annotated Heatmap
+ trace such as the colorscale. For more information on valid kwargs
+ call help(plotly.graph_objs.Heatmap)
+
+ Example 1: Simple Plotly Table
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ text = [['Country', 'Year', 'Population'],
+ ['US', 2000, 282200000],
+ ['Canada', 2000, 27790000],
+ ['US', 2010, 309000000],
+ ['Canada', 2010, 34000000]]
+
+ table = create_table(text)
+ py.iplot(table)
+ ```
+
+ Example 2: Table with Custom Coloring
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ text = [['Country', 'Year', 'Population'],
+ ['US', 2000, 282200000],
+ ['Canada', 2000, 27790000],
+ ['US', 2010, 309000000],
+ ['Canada', 2010, 34000000]]
+
+ table = create_table(text,
+ colorscale=[[0, '#000000'],
+ [.5, '#80beff'],
+ [1, '#cce5ff']],
+ font_colors=['#ffffff', '#000000',
+ '#000000'])
+ py.iplot(table)
+ ```
+ Example 3: Simple Plotly Table with Pandas
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_table
+
+ import pandas as pd
+
+ df = pd.read_csv('http://www.stat.ubc.ca/~jenny/notOcto/STAT545A/examples/gapminder/data/gapminderDataFiveYear.txt', sep='\t')
+ df_p = df[0:25]
+
+ table_simple = create_table(df_p)
+ py.iplot(table_simple)
+ ```
+ """
+
+ # Avoiding mutables in the call signature
+ colorscale = \
+ colorscale if colorscale is not None else [[0, '#00083e'],
+ [.5, '#ededee'],
+ [1, '#ffffff']]
+ font_colors = font_colors if font_colors is not None else ['#ffffff',
+ '#000000',
+ '#000000']
+
+ validate_table(table_text, font_colors)
+ table_matrix = _Table(table_text, colorscale, font_colors, index,
+ index_title, annotation_offset,
+ **kwargs).get_table_matrix()
+ annotations = _Table(table_text, colorscale, font_colors, index,
+ index_title, annotation_offset,
+ **kwargs).make_table_annotations()
+
+ trace = dict(type='heatmap', z=table_matrix, opacity=.75,
+ colorscale=colorscale, showscale=False,
+ hoverinfo=hoverinfo, **kwargs)
+
+ data = [trace]
+ layout = dict(annotations=annotations,
+ height=len(table_matrix) * height_constant + 50,
+ margin=dict(t=0, b=0, r=0, l=0),
+ yaxis=dict(autorange='reversed', zeroline=False,
+ gridwidth=2, ticks='', dtick=1, tick0=.5,
+ showticklabels=False),
+ xaxis=dict(zeroline=False, gridwidth=2, ticks='',
+ dtick=1, tick0=-0.5, showticklabels=False))
+ return graph_objs.Figure(data=data, layout=layout)
+
+
+class _Table(object):
+ """
+ Refer to TraceFactory.create_table() for docstring
+ """
+ def __init__(self, table_text, colorscale, font_colors, index,
+ index_title, annotation_offset, **kwargs):
+ if pd and isinstance(table_text, pd.DataFrame):
+ headers = table_text.columns.tolist()
+ table_text_index = table_text.index.tolist()
+ table_text = table_text.values.tolist()
+ table_text.insert(0, headers)
+ if index:
+ table_text_index.insert(0, index_title)
+ for i in range(len(table_text)):
+ table_text[i].insert(0, table_text_index[i])
+ self.table_text = table_text
+ self.colorscale = colorscale
+ self.font_colors = font_colors
+ self.index = index
+ self.annotation_offset = annotation_offset
+ self.x = range(len(table_text[0]))
+ self.y = range(len(table_text))
+
+ def get_table_matrix(self):
+ """
+ Create z matrix to make heatmap with striped table coloring
+
+ :rtype (list[list]) table_matrix: z matrix to make heatmap with striped
+ table coloring.
+ """
+ header = [0] * len(self.table_text[0])
+ odd_row = [.5] * len(self.table_text[0])
+ even_row = [1] * len(self.table_text[0])
+ table_matrix = [None] * len(self.table_text)
+ table_matrix[0] = header
+ for i in range(1, len(self.table_text), 2):
+ table_matrix[i] = odd_row
+ for i in range(2, len(self.table_text), 2):
+ table_matrix[i] = even_row
+ if self.index:
+ for array in table_matrix:
+ array[0] = 0
+ return table_matrix
+
+ def get_table_font_color(self):
+ """
+ Fill font-color array.
+
+ Table text color can vary by row so this extends a single color or
+ creates an array to set a header color and two alternating colors to
+ create the striped table pattern.
+
+ :rtype (list[list]) all_font_colors: list of font colors for each row
+ in table.
+ """
+ if len(self.font_colors) == 1:
+ all_font_colors = self.font_colors*len(self.table_text)
+ elif len(self.font_colors) == 3:
+ all_font_colors = list(range(len(self.table_text)))
+ all_font_colors[0] = self.font_colors[0]
+ for i in range(1, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[1]
+ for i in range(2, len(self.table_text), 2):
+ all_font_colors[i] = self.font_colors[2]
+ elif len(self.font_colors) == len(self.table_text):
+ all_font_colors = self.font_colors
+ else:
+ all_font_colors = ['#000000']*len(self.table_text)
+ return all_font_colors
+
+ def make_table_annotations(self):
+ """
+ Generate annotations to fill in table text
+
+ :rtype (list) annotations: list of annotations for each cell of the
+ table.
+ """
+ table_matrix = _Table.get_table_matrix(self)
+ all_font_colors = _Table.get_table_font_color(self)
+ annotations = []
+ for n, row in enumerate(self.table_text):
+ for m, val in enumerate(row):
+ # Bold text in header and index
+ format_text = ('' + str(val) + '' if n == 0 or
+ self.index and m < 1 else str(val))
+ # Match font color of index to font color of header
+ font_color = (self.font_colors[0] if self.index and m == 0
+ else all_font_colors[n])
+ annotations.append(
+ graph_objs.layout.Annotation(
+ text=format_text,
+ x=self.x[m] - self.annotation_offset,
+ y=self.y[n],
+ xref='x1',
+ yref='y1',
+ align="left",
+ xanchor="left",
+ font=dict(color=font_color),
+ showarrow=False)
+ )
+ return annotations
diff --git a/plotly/figure_factory/figure_factory/_trisurf.py b/plotly/figure_factory/figure_factory/_trisurf.py
new file mode 100644
index 00000000000..6899a8484b7
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_trisurf.py
@@ -0,0 +1,489 @@
+from __future__ import absolute_import
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+
+np = optional_imports.get_module('numpy')
+
+
+def map_face2color(face, colormap, scale, vmin, vmax):
+ """
+ Normalize facecolor values by vmin/vmax and return rgb-color strings
+
+ This function takes a tuple color along with a colormap and a minimum
+ (vmin) and maximum (vmax) range of possible mean distances for the
+ given parametrized surface. It returns an rgb color based on the mean
+ distance between vmin and vmax
+
+ """
+ if vmin >= vmax:
+ raise exceptions.PlotlyError("Incorrect relation between vmin "
+ "and vmax. The vmin value cannot be "
+ "bigger than or equal to the value "
+ "of vmax.")
+ if len(colormap) == 1:
+ # color each triangle face with the same color in colormap
+ face_color = colormap[0]
+ face_color = utils.convert_to_RGB_255(face_color)
+ face_color = utils.label_rgb(face_color)
+ return face_color
+ if face == vmax:
+ # pick last color in colormap
+ face_color = colormap[-1]
+ face_color = utils.convert_to_RGB_255(face_color)
+ face_color = utils.label_rgb(face_color)
+ return face_color
+ else:
+ if scale is None:
+ # find the normalized distance t of a triangle face between
+ # vmin and vmax where the distance is between 0 and 1
+ t = (face - vmin) / float((vmax - vmin))
+ low_color_index = int(t / (1./(len(colormap) - 1)))
+
+ face_color = utils.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ t * (len(colormap) - 1) - low_color_index
+ )
+
+ face_color = utils.convert_to_RGB_255(face_color)
+ face_color = utils.label_rgb(face_color)
+ else:
+ # find the face color for a non-linearly interpolated scale
+ t = (face - vmin) / float((vmax - vmin))
+
+ low_color_index = 0
+ for k in range(len(scale) - 1):
+ if scale[k] <= t < scale[k+1]:
+ break
+ low_color_index += 1
+
+ low_scale_val = scale[low_color_index]
+ high_scale_val = scale[low_color_index + 1]
+
+ face_color = utils.find_intermediate_color(
+ colormap[low_color_index],
+ colormap[low_color_index + 1],
+ (t - low_scale_val)/(high_scale_val - low_scale_val)
+ )
+
+ face_color = utils.convert_to_RGB_255(face_color)
+ face_color = utils.label_rgb(face_color)
+ return face_color
+
+
+def trisurf(x, y, z, simplices, show_colorbar, edges_color, scale,
+ colormap=None, color_func=None, plot_edges=False, x_edge=None,
+ y_edge=None, z_edge=None, facecolor=None):
+ """
+ Refer to FigureFactory.create_trisurf() for docstring
+ """
+ # numpy import check
+ if not np:
+ raise ImportError("FigureFactory._trisurf() requires "
+ "numpy imported.")
+ points3D = np.vstack((x, y, z)).T
+ simplices = np.atleast_2d(simplices)
+
+ # vertices of the surface triangles
+ tri_vertices = points3D[simplices]
+
+ # Define colors for the triangle faces
+ if color_func is None:
+ # mean values of z-coordinates of triangle vertices
+ mean_dists = tri_vertices[:, :, 2].mean(-1)
+ elif isinstance(color_func, (list, np.ndarray)):
+ # Pre-computed list / array of values to map onto color
+ if len(color_func) != len(simplices):
+ raise ValueError("If color_func is a list/array, it must "
+ "be the same length as simplices.")
+
+ # convert all colors in color_func to rgb
+ for index in range(len(color_func)):
+ if isinstance(color_func[index], str):
+ if '#' in color_func[index]:
+ foo = utils.hex_to_rgb(color_func[index])
+ color_func[index] = utils.label_rgb(foo)
+
+ if isinstance(color_func[index], tuple):
+ foo = utils.convert_to_RGB_255(color_func[index])
+ color_func[index] = utils.label_rgb(foo)
+
+ mean_dists = np.asarray(color_func)
+ else:
+ # apply user inputted function to calculate
+ # custom coloring for triangle vertices
+ mean_dists = []
+ for triangle in tri_vertices:
+ dists = []
+ for vertex in triangle:
+ dist = color_func(vertex[0], vertex[1], vertex[2])
+ dists.append(dist)
+ mean_dists.append(np.mean(dists))
+ mean_dists = np.asarray(mean_dists)
+
+ # Check if facecolors are already strings and can be skipped
+ if isinstance(mean_dists[0], str):
+ facecolor = mean_dists
+ else:
+ min_mean_dists = np.min(mean_dists)
+ max_mean_dists = np.max(mean_dists)
+
+ if facecolor is None:
+ facecolor = []
+ for index in range(len(mean_dists)):
+ color = map_face2color(mean_dists[index], colormap, scale,
+ min_mean_dists, max_mean_dists)
+ facecolor.append(color)
+
+ # Make sure facecolor is a list so output is consistent across Pythons
+ facecolor = np.asarray(facecolor)
+ ii, jj, kk = simplices.T
+
+ triangles = graph_objs.Mesh3d(x=x, y=y, z=z, facecolor=facecolor,
+ i=ii, j=jj, k=kk, name='')
+
+ mean_dists_are_numbers = not isinstance(mean_dists[0], str)
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ # make a colorscale from the colors
+ colorscale = utils.make_colorscale(colormap, scale)
+ colorscale = utils.convert_colorscale_to_rgb(colorscale)
+
+ colorbar = graph_objs.Scatter3d(
+ x=x[:1],
+ y=y[:1],
+ z=z[:1],
+ mode='markers',
+ marker=dict(
+ size=0.1,
+ color=[min_mean_dists, max_mean_dists],
+ colorscale=colorscale,
+ showscale=True),
+ hoverinfo='none',
+ showlegend=False
+ )
+
+ # the triangle sides are not plotted
+ if plot_edges is False:
+ if mean_dists_are_numbers and show_colorbar is True:
+ return [triangles, colorbar]
+ else:
+ return [triangles]
+
+ # define the lists x_edge, y_edge and z_edge, of x, y, resp z
+ # coordinates of edge end points for each triangle
+ # None separates data corresponding to two consecutive triangles
+ is_none = [ii is None for ii in [x_edge, y_edge, z_edge]]
+ if any(is_none):
+ if not all(is_none):
+ raise ValueError("If any (x_edge, y_edge, z_edge) is None, "
+ "all must be None")
+ else:
+ x_edge = []
+ y_edge = []
+ z_edge = []
+
+ # Pull indices we care about, then add a None column to separate tris
+ ixs_triangles = [0, 1, 2, 0]
+ pull_edges = tri_vertices[:, ixs_triangles, :]
+ x_edge_pull = np.hstack([pull_edges[:, :, 0],
+ np.tile(None, [pull_edges.shape[0], 1])])
+ y_edge_pull = np.hstack([pull_edges[:, :, 1],
+ np.tile(None, [pull_edges.shape[0], 1])])
+ z_edge_pull = np.hstack([pull_edges[:, :, 2],
+ np.tile(None, [pull_edges.shape[0], 1])])
+
+ # Now unravel the edges into a 1-d vector for plotting
+ x_edge = np.hstack([x_edge, x_edge_pull.reshape([1, -1])[0]])
+ y_edge = np.hstack([y_edge, y_edge_pull.reshape([1, -1])[0]])
+ z_edge = np.hstack([z_edge, z_edge_pull.reshape([1, -1])[0]])
+
+ if not (len(x_edge) == len(y_edge) == len(z_edge)):
+ raise exceptions.PlotlyError("The lengths of x_edge, y_edge and "
+ "z_edge are not the same.")
+
+ # define the lines for plotting
+ lines = graph_objs.Scatter3d(
+ x=x_edge, y=y_edge, z=z_edge, mode='lines',
+ line=graph_objs.scatter3d.Line(
+ color=edges_color,
+ width=1.5
+ ),
+ showlegend=False
+ )
+
+ if mean_dists_are_numbers and show_colorbar is True:
+ return [triangles, lines, colorbar]
+ else:
+ return [triangles, lines]
+
+
+def create_trisurf(x, y, z, simplices, colormap=None, show_colorbar=True,
+ scale=None, color_func=None, title='Trisurf Plot',
+ plot_edges=True, showbackground=True,
+ backgroundcolor='rgb(230, 230, 230)',
+ gridcolor='rgb(255, 255, 255)',
+ zerolinecolor='rgb(255, 255, 255)',
+ edges_color='rgb(50, 50, 50)',
+ height=800, width=800,
+ aspectratio=None):
+ """
+ Returns figure for a triangulated surface plot
+
+ :param (array) x: data values of x in a 1D array
+ :param (array) y: data values of y in a 1D array
+ :param (array) z: data values of z in a 1D array
+ :param (array) simplices: an array of shape (ntri, 3) where ntri is
+ the number of triangles in the triangularization. Each row of the
+ array contains the indicies of the verticies of each triangle
+ :param (str|tuple|list) colormap: either a plotly scale name, an rgb
+ or hex color, a color tuple or a list of colors. An rgb color is
+ of the form 'rgb(x, y, z)' where x, y, z belong to the interval
+ [0, 255] and a color tuple is a tuple of the form (a, b, c) where
+ a, b and c belong to [0, 1]. If colormap is a list, it must
+ contain the valid color types aforementioned as its members
+ :param (bool) show_colorbar: determines if colorbar is visible
+ :param (list|array) scale: sets the scale values to be used if a non-
+ linearly interpolated colormap is desired. If left as None, a
+ linear interpolation between the colors will be excecuted
+ :param (function|list) color_func: The parameter that determines the
+ coloring of the surface. Takes either a function with 3 arguments
+ x, y, z or a list/array of color values the same length as
+ simplices. If None, coloring will only depend on the z axis
+ :param (str) title: title of the plot
+ :param (bool) plot_edges: determines if the triangles on the trisurf
+ are visible
+ :param (bool) showbackground: makes background in plot visible
+ :param (str) backgroundcolor: color of background. Takes a string of
+ the form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) gridcolor: color of the gridlines besides the axes. Takes
+ a string of the form 'rgb(x,y,z)' x,y,z are between 0 and 255
+ inclusive
+ :param (str) zerolinecolor: color of the axes. Takes a string of the
+ form 'rgb(x,y,z)' x,y,z are between 0 and 255 inclusive
+ :param (str) edges_color: color of the edges, if plot_edges is True
+ :param (int|float) height: the height of the plot (in pixels)
+ :param (int|float) width: the width of the plot (in pixels)
+ :param (dict) aspectratio: a dictionary of the aspect ratio values for
+ the x, y and z axes. 'x', 'y' and 'z' take (int|float) values
+
+ Example 1: Sphere
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 20)
+ v = np.linspace(0, np.pi, 20)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ x = np.sin(v)*np.cos(u)
+ y = np.sin(v)*np.sin(u)
+ z = np.cos(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap="Rainbow",
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-sphere')
+ ```
+
+ Example 2: Torus
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 20)
+ v = np.linspace(0, 2*np.pi, 20)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ x = (3 + (np.cos(v)))*np.cos(u)
+ y = (3 + (np.cos(v)))*np.sin(u)
+ z = np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap="Viridis",
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-torus')
+ ```
+
+ Example 3: Mobius Band
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u = np.linspace(0, 2*np.pi, 24)
+ v = np.linspace(-1, 1, 8)
+ u,v = np.meshgrid(u,v)
+ u = u.flatten()
+ v = v.flatten()
+
+ tp = 1 + 0.5*v*np.cos(u/2.)
+ x = tp*np.cos(u)
+ y = tp*np.sin(u)
+ z = 0.5*v*np.sin(u/2.)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z, colormap=[(0.2, 0.4, 0.6), (1, 1, 1)],
+ simplices=simplices)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-mobius-band')
+ ```
+
+ Example 4: Using a Custom Colormap Function with Light Cone
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u=np.linspace(-np.pi, np.pi, 30)
+ v=np.linspace(-np.pi, np.pi, 30)
+ u,v=np.meshgrid(u,v)
+ u=u.flatten()
+ v=v.flatten()
+
+ x = u
+ y = u*np.cos(v)
+ z = u*np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+ # Define distance function
+ def dist_origin(x, y, z):
+ return np.sqrt((1.0 * x)**2 + (1.0 * y)**2 + (1.0 * z)**2)
+
+ # Create a figure
+ fig1 = create_trisurf(x=x, y=y, z=z,
+ colormap=['#FFFFFF', '#E4FFFE',
+ '#A4F6F9', '#FF99FE',
+ '#BA52ED'],
+ scale=[0, 0.6, 0.71, 0.89, 1],
+ simplices=simplices,
+ color_func=dist_origin)
+ # Plot the data
+ py.iplot(fig1, filename='trisurf-plot-custom-coloring')
+ ```
+
+ Example 5: Enter color_func as a list of colors
+ ```
+ # Necessary Imports for Trisurf
+ import numpy as np
+ from scipy.spatial import Delaunay
+ import random
+
+ import plotly.plotly as py
+ from plotly.figure_factory import create_trisurf
+ from plotly.graph_objs import graph_objs
+
+ # Make data for plot
+ u=np.linspace(-np.pi, np.pi, 30)
+ v=np.linspace(-np.pi, np.pi, 30)
+ u,v=np.meshgrid(u,v)
+ u=u.flatten()
+ v=v.flatten()
+
+ x = u
+ y = u*np.cos(v)
+ z = u*np.sin(v)
+
+ points2D = np.vstack([u,v]).T
+ tri = Delaunay(points2D)
+ simplices = tri.simplices
+
+
+ colors = []
+ color_choices = ['rgb(0, 0, 0)', '#6c4774', '#d6c7dd']
+
+ for index in range(len(simplices)):
+ colors.append(random.choice(color_choices))
+
+ fig = create_trisurf(
+ x, y, z, simplices,
+ color_func=colors,
+ show_colorbar=True,
+ edges_color='rgb(2, 85, 180)',
+ title=' Modern Art'
+ )
+
+ py.iplot(fig, filename="trisurf-plot-modern-art")
+ ```
+ """
+ if aspectratio is None:
+ aspectratio = {'x': 1, 'y': 1, 'z': 1}
+
+ # Validate colormap
+ utils.validate_colors(colormap)
+ colormap, scale = utils.convert_colors_to_same_type(
+ colormap, colortype='tuple',
+ return_default_colors=True, scale=scale
+ )
+
+ data1 = trisurf(x, y, z, simplices, show_colorbar=show_colorbar,
+ color_func=color_func, colormap=colormap, scale=scale,
+ edges_color=edges_color, plot_edges=plot_edges)
+
+ axis = dict(
+ showbackground=showbackground,
+ backgroundcolor=backgroundcolor,
+ gridcolor=gridcolor,
+ zerolinecolor=zerolinecolor,
+ )
+ layout = graph_objs.Layout(
+ title=title,
+ width=width,
+ height=height,
+ scene=graph_objs.layout.Scene(
+ xaxis=graph_objs.layout.scene.XAxis(**axis),
+ yaxis=graph_objs.layout.scene.YAxis(**axis),
+ zaxis=graph_objs.layout.scene.ZAxis(**axis),
+ aspectratio=dict(
+ x=aspectratio['x'],
+ y=aspectratio['y'],
+ z=aspectratio['z']),
+ )
+ )
+
+ return graph_objs.Figure(data=data1, layout=layout)
diff --git a/plotly/figure_factory/figure_factory/_violin.py b/plotly/figure_factory/figure_factory/_violin.py
new file mode 100644
index 00000000000..384f3497a15
--- /dev/null
+++ b/plotly/figure_factory/figure_factory/_violin.py
@@ -0,0 +1,644 @@
+from __future__ import absolute_import
+
+from numbers import Number
+
+from plotly import exceptions, optional_imports
+from plotly.figure_factory import utils
+from plotly.graph_objs import graph_objs
+from plotly.tools import make_subplots
+
+pd = optional_imports.get_module('pandas')
+np = optional_imports.get_module('numpy')
+scipy_stats = optional_imports.get_module('scipy.stats')
+
+
+def calc_stats(data):
+ """
+ Calculate statistics for use in violin plot.
+ """
+ x = np.asarray(data, np.float)
+ vals_min = np.min(x)
+ vals_max = np.max(x)
+ q2 = np.percentile(x, 50, interpolation='linear')
+ q1 = np.percentile(x, 25, interpolation='lower')
+ q3 = np.percentile(x, 75, interpolation='higher')
+ iqr = q3 - q1
+ whisker_dist = 1.5 * iqr
+
+ # in order to prevent drawing whiskers outside the interval
+ # of data one defines the whisker positions as:
+ d1 = np.min(x[x >= (q1 - whisker_dist)])
+ d2 = np.max(x[x <= (q3 + whisker_dist)])
+ return {
+ 'min': vals_min,
+ 'max': vals_max,
+ 'q1': q1,
+ 'q2': q2,
+ 'q3': q3,
+ 'd1': d1,
+ 'd2': d2
+ }
+
+
+def make_half_violin(x, y, fillcolor='#1f77b4', linecolor='rgb(0, 0, 0)'):
+ """
+ Produces a sideways probability distribution fig violin plot.
+ """
+ text = ['(pdf(y), y)=(' + '{:0.2f}'.format(x[i]) +
+ ', ' + '{:0.2f}'.format(y[i]) + ')'
+ for i in range(len(x))]
+
+ return graph_objs.Scatter(
+ x=x,
+ y=y,
+ mode='lines',
+ name='',
+ text=text,
+ fill='tonextx',
+ fillcolor=fillcolor,
+ line=graph_objs.scatter.Line(width=0.5, color=linecolor, shape='spline'),
+ hoverinfo='text',
+ opacity=0.5
+ )
+
+
+def make_violin_rugplot(vals, pdf_max, distance, color='#1f77b4'):
+ """
+ Returns a rugplot fig for a violin plot.
+ """
+ return graph_objs.Scatter(
+ y=vals,
+ x=[-pdf_max-distance]*len(vals),
+ marker=graph_objs.scatter.Marker(
+ color=color,
+ symbol='line-ew-open'
+ ),
+ mode='markers',
+ name='',
+ showlegend=False,
+ hoverinfo='y'
+ )
+
+
+def make_non_outlier_interval(d1, d2):
+ """
+ Returns the scatterplot fig of most of a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[d1, d2],
+ name='',
+ mode='lines',
+ line=graph_objs.scatter.Line(width=1.5,
+ color='rgb(0,0,0)')
+ )
+
+
+def make_quartiles(q1, q3):
+ """
+ Makes the upper and lower quartiles for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0, 0],
+ y=[q1, q3],
+ text=['lower-quartile: ' + '{:0.2f}'.format(q1),
+ 'upper-quartile: ' + '{:0.2f}'.format(q3)],
+ mode='lines',
+ line=graph_objs.scatter.Line(
+ width=4,
+ color='rgb(0,0,0)'
+ ),
+ hoverinfo='text'
+ )
+
+
+def make_median(q2):
+ """
+ Formats the 'median' hovertext for a violin plot.
+ """
+ return graph_objs.Scatter(
+ x=[0],
+ y=[q2],
+ text=['median: ' + '{:0.2f}'.format(q2)],
+ mode='markers',
+ marker=dict(symbol='square',
+ color='rgb(255,255,255)'),
+ hoverinfo='text'
+ )
+
+
+def make_XAxis(xaxis_title, xaxis_range):
+ """
+ Makes the x-axis for a violin plot.
+ """
+ xaxis = graph_objs.layout.XAxis(title=xaxis_title,
+ range=xaxis_range,
+ showgrid=False,
+ zeroline=False,
+ showline=False,
+ mirror=False,
+ ticks='',
+ showticklabels=False)
+ return xaxis
+
+
+def make_YAxis(yaxis_title):
+ """
+ Makes the y-axis for a violin plot.
+ """
+ yaxis = graph_objs.layout.YAxis(title=yaxis_title,
+ showticklabels=True,
+ autorange=True,
+ ticklen=4,
+ showline=True,
+ zeroline=False,
+ showgrid=False,
+ mirror=False)
+ return yaxis
+
+
+def violinplot(vals, fillcolor='#1f77b4', rugplot=True):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+ """
+ vals = np.asarray(vals, np.float)
+ # summary statistics
+ vals_min = calc_stats(vals)['min']
+ vals_max = calc_stats(vals)['max']
+ q1 = calc_stats(vals)['q1']
+ q2 = calc_stats(vals)['q2']
+ q3 = calc_stats(vals)['q3']
+ d1 = calc_stats(vals)['d1']
+ d2 = calc_stats(vals)['d2']
+
+ # kernel density estimation of pdf
+ pdf = scipy_stats.gaussian_kde(vals)
+ # grid over the data interval
+ xx = np.linspace(vals_min, vals_max, 100)
+ # evaluate the pdf at the grid xx
+ yy = pdf(xx)
+ max_pdf = np.max(yy)
+ # distance from the violin plot to rugplot
+ distance = (2.0 * max_pdf)/10 if rugplot else 0
+ # range for x values in the plot
+ plot_xrange = [-max_pdf - distance - 0.1, max_pdf + 0.1]
+ plot_data = [make_half_violin(-yy, xx, fillcolor=fillcolor),
+ make_half_violin(yy, xx, fillcolor=fillcolor),
+ make_non_outlier_interval(d1, d2),
+ make_quartiles(q1, q3),
+ make_median(q2)]
+ if rugplot:
+ plot_data.append(make_violin_rugplot(vals, max_pdf, distance=distance,
+ color=fillcolor))
+ return plot_data, plot_xrange
+
+
+def violin_no_colorscale(data, data_header, group_header, colors,
+ use_colorscale, group_stats, rugplot, sort,
+ height, width, title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ if sort:
+ group_name.sort()
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+ color_index = 0
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+ if color_index >= len(colors):
+ color_index = 0
+ plot_data, plot_xrange = violinplot(vals,
+ fillcolor=colors[color_index],
+ rugplot=rugplot)
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ color_index += 1
+
+ # add violin plot labels
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def violin_colorscale(data, data_header, group_header, colors, use_colorscale,
+ group_stats, rugplot, sort, height, width,
+ title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot with colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+ if sort:
+ group_name.sort()
+
+ # make sure all group names are keys in group_stats
+ for group in group_name:
+ if group not in group_stats:
+ raise exceptions.PlotlyError("All values/groups in the index "
+ "column must be represented "
+ "as a key in group_stats.")
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+
+ # prepare low and high color for colorscale
+ lowcolor = utils.color_parser(colors[0], utils.unlabel_rgb)
+ highcolor = utils.color_parser(colors[1], utils.unlabel_rgb)
+
+ # find min and max values in group_stats
+ group_stats_values = []
+ for key in group_stats:
+ group_stats_values.append(group_stats[key])
+
+ max_value = max(group_stats_values)
+ min_value = min(group_stats_values)
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+
+ # find intermediate color from colorscale
+ intermed = (group_stats[gr] - min_value) / (max_value - min_value)
+ intermed_color = utils.find_intermediate_color(
+ lowcolor, highcolor, intermed
+ )
+
+ plot_data, plot_xrange = violinplot(
+ vals,
+ fillcolor='rgb{}'.format(intermed_color),
+ rugplot=rugplot
+ )
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+ # add colorbar to plot
+ trace_dummy = graph_objs.Scatter(
+ x=[0],
+ y=[0],
+ mode='markers',
+ marker=dict(
+ size=2,
+ cmin=min_value,
+ cmax=max_value,
+ colorscale=[[0, colors[0]],
+ [1, colors[1]]],
+ showscale=True),
+ showlegend=False,
+ )
+ fig.append_trace(trace_dummy, 1, L)
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def violin_dict(data, data_header, group_header, colors, use_colorscale,
+ group_stats, rugplot, sort, height, width, title):
+ """
+ Refer to FigureFactory.create_violin() for docstring.
+
+ Returns fig for violin plot without colorscale.
+
+ """
+
+ # collect all group names
+ group_name = []
+ for name in data[group_header]:
+ if name not in group_name:
+ group_name.append(name)
+
+ if sort:
+ group_name.sort()
+
+ # check if all group names appear in colors dict
+ for group in group_name:
+ if group not in colors:
+ raise exceptions.PlotlyError("If colors is a dictionary, all "
+ "the group names must appear as "
+ "keys in colors.")
+
+ gb = data.groupby([group_header])
+ L = len(group_name)
+
+ fig = make_subplots(rows=1, cols=L,
+ shared_yaxes=True,
+ horizontal_spacing=0.025,
+ print_grid=False)
+
+ for k, gr in enumerate(group_name):
+ vals = np.asarray(gb.get_group(gr)[data_header], np.float)
+ plot_data, plot_xrange = violinplot(vals, fillcolor=colors[gr],
+ rugplot=rugplot)
+ layout = graph_objs.Layout()
+
+ for item in plot_data:
+ fig.append_trace(item, 1, k + 1)
+
+ # add violin plot labels
+ fig['layout'].update(
+ {'xaxis{}'.format(k + 1): make_XAxis(group_name[k], plot_xrange)}
+ )
+
+ # set the sharey axis style
+ fig['layout'].update({'yaxis{}'.format(1): make_YAxis('')})
+ fig['layout'].update(
+ title=title,
+ showlegend=False,
+ hovermode='closest',
+ autosize=False,
+ height=height,
+ width=width
+ )
+
+ return fig
+
+
+def create_violin(data, data_header=None, group_header=None, colors=None,
+ use_colorscale=False, group_stats=None, rugplot=True,
+ sort=False, height=450, width=600,
+ title='Violin and Rug Plot'):
+ """
+ Returns figure for a violin plot
+
+ :param (list|array) data: accepts either a list of numerical values,
+ a list of dictionaries all with identical keys and at least one
+ column of numeric values, or a pandas dataframe with at least one
+ column of numbers.
+ :param (str) data_header: the header of the data column to be used
+ from an inputted pandas dataframe. Not applicable if 'data' is
+ a list of numeric values.
+ :param (str) group_header: applicable if grouping data by a variable.
+ 'group_header' must be set to the name of the grouping variable.
+ :param (str|tuple|list|dict) colors: either a plotly scale name,
+ an rgb or hex color, a color tuple, a list of colors or a
+ dictionary. An rgb color is of the form 'rgb(x, y, z)' where
+ x, y and z belong to the interval [0, 255] and a color tuple is a
+ tuple of the form (a, b, c) where a, b and c belong to [0, 1].
+ If colors is a list, it must contain valid color types as its
+ members.
+ :param (bool) use_colorscale: only applicable if grouping by another
+ variable. Will implement a colorscale based on the first 2 colors
+ of param colors. This means colors must be a list with at least 2
+ colors in it (Plotly colorscales are accepted since they map to a
+ list of two rgb colors). Default = False
+ :param (dict) group_stats: a dictioanry where each key is a unique
+ value from the group_header column in data. Each value must be a
+ number and will be used to color the violin plots if a colorscale
+ is being used.
+ :param (bool) rugplot: determines if a rugplot is draw on violin plot.
+ Default = True
+ :param (bool) sort: determines if violins are sorted
+ alphabetically (True) or by input order (False). Default = False
+ :param (float) height: the height of the violin plot.
+ :param (float) width: the width of the violin plot.
+ :param (str) title: the title of the violin plot.
+
+ Example 1: Single Violin Plot
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ from scipy import stats
+
+ # create list of random values
+ data_list = np.random.randn(100)
+ data_list.tolist()
+
+ # create violin fig
+ fig = create_violin(data_list, colors='#604d9e')
+
+ # plot
+ py.iplot(fig, filename='Violin Plot')
+ ```
+
+ Example 2: Multiple Violin Plots with Qualitative Coloring
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+
+ # create dataframe
+ np.random.seed(619517)
+ Nr=250
+ y = np.random.randn(Nr)
+ gr = np.random.choice(list("ABCDE"), Nr)
+ norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ for i, letter in enumerate("ABCDE"):
+ y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ # create violin fig
+ fig = create_violin(df, data_header='Score', group_header='Group',
+ sort=True, height=600, width=1000)
+
+ # plot
+ py.iplot(fig, filename='Violin Plot with Coloring')
+ ```
+
+ Example 3: Violin Plots with Colorscale
+ ```
+ import plotly.plotly as py
+ from plotly.figure_factory import create_violin
+ from plotly.graph_objs import graph_objs
+
+ import numpy as np
+ import pandas as pd
+ from scipy import stats
+
+ # create dataframe
+ np.random.seed(619517)
+ Nr=250
+ y = np.random.randn(Nr)
+ gr = np.random.choice(list("ABCDE"), Nr)
+ norm_params=[(0, 1.2), (0.7, 1), (-0.5, 1.4), (0.3, 1), (0.8, 0.9)]
+
+ for i, letter in enumerate("ABCDE"):
+ y[gr == letter] *=norm_params[i][1]+ norm_params[i][0]
+ df = pd.DataFrame(dict(Score=y, Group=gr))
+
+ # define header params
+ data_header = 'Score'
+ group_header = 'Group'
+
+ # make groupby object with pandas
+ group_stats = {}
+ groupby_data = df.groupby([group_header])
+
+ for group in "ABCDE":
+ data_from_group = groupby_data.get_group(group)[data_header]
+ # take a stat of the grouped data
+ stat = np.median(data_from_group)
+ # add to dictionary
+ group_stats[group] = stat
+
+ # create violin fig
+ fig = create_violin(df, data_header='Score', group_header='Group',
+ height=600, width=1000, use_colorscale=True,
+ group_stats=group_stats)
+
+ # plot
+ py.iplot(fig, filename='Violin Plot with Colorscale')
+ ```
+ """
+
+ # Validate colors
+ if isinstance(colors, dict):
+ valid_colors = utils.validate_colors_dict(colors, 'rgb')
+ else:
+ valid_colors = utils.validate_colors(colors, 'rgb')
+
+ # validate data and choose plot type
+ if group_header is None:
+ if isinstance(data, list):
+ if len(data) <= 0:
+ raise exceptions.PlotlyError("If data is a list, it must be "
+ "nonempty and contain either "
+ "numbers or dictionaries.")
+
+ if not all(isinstance(element, Number) for element in data):
+ raise exceptions.PlotlyError("If data is a list, it must "
+ "contain only numbers.")
+
+ if pd and isinstance(data, pd.core.frame.DataFrame):
+ if data_header is None:
+ raise exceptions.PlotlyError("data_header must be the "
+ "column name with the "
+ "desired numeric data for "
+ "the violin plot.")
+
+ data = data[data_header].values.tolist()
+
+ # call the plotting functions
+ plot_data, plot_xrange = violinplot(data, fillcolor=valid_colors[0],
+ rugplot=rugplot)
+
+ layout = graph_objs.Layout(
+ title=title,
+ autosize=False,
+ font=graph_objs.layout.Font(size=11),
+ height=height,
+ showlegend=False,
+ width=width,
+ xaxis=make_XAxis('', plot_xrange),
+ yaxis=make_YAxis(''),
+ hovermode='closest'
+ )
+ layout['yaxis'].update(dict(showline=False,
+ showticklabels=False,
+ ticks=''))
+
+ fig = graph_objs.Figure(data=plot_data,
+ layout=layout)
+
+ return fig
+
+ else:
+ if not isinstance(data, pd.core.frame.DataFrame):
+ raise exceptions.PlotlyError("Error. You must use a pandas "
+ "DataFrame if you are using a "
+ "group header.")
+
+ if data_header is None:
+ raise exceptions.PlotlyError("data_header must be the column "
+ "name with the desired numeric "
+ "data for the violin plot.")
+
+ if use_colorscale is False:
+ if isinstance(valid_colors, dict):
+ # validate colors dict choice below
+ fig = violin_dict(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, rugplot, sort,
+ height, width, title
+ )
+ return fig
+ else:
+ fig = violin_no_colorscale(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, rugplot, sort,
+ height, width, title
+ )
+ return fig
+ else:
+ if isinstance(valid_colors, dict):
+ raise exceptions.PlotlyError("The colors param cannot be "
+ "a dictionary if you are "
+ "using a colorscale.")
+
+ if len(valid_colors) < 2:
+ raise exceptions.PlotlyError("colors must be a list with "
+ "at least 2 colors. A "
+ "Plotly scale is allowed.")
+
+ if not isinstance(group_stats, dict):
+ raise exceptions.PlotlyError("Your group_stats param "
+ "must be a dictionary.")
+
+ fig = violin_colorscale(
+ data, data_header, group_header, valid_colors,
+ use_colorscale, group_stats, rugplot, sort, height,
+ width, title
+ )
+ return fig
diff --git a/plotly/figure_factory/utils.py b/plotly/figure_factory/utils.py
index ee19e17cef9..ea27270c46b 100644
--- a/plotly/figure_factory/utils.py
+++ b/plotly/figure_factory/utils.py
@@ -2,6 +2,8 @@
import collections
import decimal
+import six
+from numbers import Number
from plotly import exceptions
@@ -12,26 +14,136 @@
'rgb(227, 119, 194)', 'rgb(127, 127, 127)',
'rgb(188, 189, 34)', 'rgb(23, 190, 207)']
-# TODO: make PLOTLY_SCALES below like version in plotly.colors
-# requires rewritting scatterplot_matrix code
PLOTLY_SCALES = {
- 'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'],
- 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'],
- 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'],
- 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'],
- 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'],
- 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'],
- 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'],
- 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'],
- 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'],
- 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'],
- 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'],
- 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'],
- 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'],
- 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'],
- 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'],
- 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'],
- 'Viridis': ['#440154', '#fde725']
+ 'Greys': [
+ [0, 'rgb(0,0,0)'], [1, 'rgb(255,255,255)']
+ ],
+
+ 'YlGnBu': [
+ [0, 'rgb(8,29,88)'], [0.125, 'rgb(37,52,148)'],
+ [0.25, 'rgb(34,94,168)'], [0.375, 'rgb(29,145,192)'],
+ [0.5, 'rgb(65,182,196)'], [0.625, 'rgb(127,205,187)'],
+ [0.75, 'rgb(199,233,180)'], [0.875, 'rgb(237,248,217)'],
+ [1, 'rgb(255,255,217)']
+ ],
+
+ 'Greens': [
+ [0, 'rgb(0,68,27)'], [0.125, 'rgb(0,109,44)'],
+ [0.25, 'rgb(35,139,69)'], [0.375, 'rgb(65,171,93)'],
+ [0.5, 'rgb(116,196,118)'], [0.625, 'rgb(161,217,155)'],
+ [0.75, 'rgb(199,233,192)'], [0.875, 'rgb(229,245,224)'],
+ [1, 'rgb(247,252,245)']
+ ],
+
+ 'YlOrRd': [
+ [0, 'rgb(128,0,38)'], [0.125, 'rgb(189,0,38)'],
+ [0.25, 'rgb(227,26,28)'], [0.375, 'rgb(252,78,42)'],
+ [0.5, 'rgb(253,141,60)'], [0.625, 'rgb(254,178,76)'],
+ [0.75, 'rgb(254,217,118)'], [0.875, 'rgb(255,237,160)'],
+ [1, 'rgb(255,255,204)']
+ ],
+
+ 'Bluered': [
+ [0, 'rgb(0,0,255)'], [1, 'rgb(255,0,0)']
+ ],
+
+ # modified RdBu based on
+ # www.sandia.gov/~kmorel/documents/ColorMaps/ColorMapsExpanded.pdf
+ 'RdBu': [
+ [0, 'rgb(5,10,172)'], [0.35, 'rgb(106,137,247)'],
+ [0.5, 'rgb(190,190,190)'], [0.6, 'rgb(220,170,132)'],
+ [0.7, 'rgb(230,145,90)'], [1, 'rgb(178,10,28)']
+ ],
+
+ # Scale for non-negative numeric values
+ 'Reds': [
+ [0, 'rgb(220,220,220)'], [0.2, 'rgb(245,195,157)'],
+ [0.4, 'rgb(245,160,105)'], [1, 'rgb(178,10,28)']
+ ],
+
+ # Scale for non-positive numeric values
+ 'Blues': [
+ [0, 'rgb(5,10,172)'], [0.35, 'rgb(40,60,190)'],
+ [0.5, 'rgb(70,100,245)'], [0.6, 'rgb(90,120,245)'],
+ [0.7, 'rgb(106,137,247)'], [1, 'rgb(220,220,220)']
+ ],
+
+ 'Picnic': [
+ [0, 'rgb(0,0,255)'], [0.1, 'rgb(51,153,255)'],
+ [0.2, 'rgb(102,204,255)'], [0.3, 'rgb(153,204,255)'],
+ [0.4, 'rgb(204,204,255)'], [0.5, 'rgb(255,255,255)'],
+ [0.6, 'rgb(255,204,255)'], [0.7, 'rgb(255,153,255)'],
+ [0.8, 'rgb(255,102,204)'], [0.9, 'rgb(255,102,102)'],
+ [1, 'rgb(255,0,0)']
+ ],
+
+ 'Rainbow': [
+ [0, 'rgb(150,0,90)'], [0.125, 'rgb(0,0,200)'],
+ [0.25, 'rgb(0,25,255)'], [0.375, 'rgb(0,152,255)'],
+ [0.5, 'rgb(44,255,150)'], [0.625, 'rgb(151,255,0)'],
+ [0.75, 'rgb(255,234,0)'], [0.875, 'rgb(255,111,0)'],
+ [1, 'rgb(255,0,0)']
+ ],
+
+ 'Portland': [
+ [0, 'rgb(12,51,131)'], [0.25, 'rgb(10,136,186)'],
+ [0.5, 'rgb(242,211,56)'], [0.75, 'rgb(242,143,56)'],
+ [1, 'rgb(217,30,30)']
+ ],
+
+ 'Jet': [
+ [0, 'rgb(0,0,131)'], [0.125, 'rgb(0,60,170)'],
+ [0.375, 'rgb(5,255,255)'], [0.625, 'rgb(255,255,0)'],
+ [0.875, 'rgb(250,0,0)'], [1, 'rgb(128,0,0)']
+ ],
+
+ 'Hot': [
+ [0, 'rgb(0,0,0)'], [0.3, 'rgb(230,0,0)'],
+ [0.6, 'rgb(255,210,0)'], [1, 'rgb(255,255,255)']
+ ],
+
+ 'Blackbody': [
+ [0, 'rgb(0,0,0)'], [0.2, 'rgb(230,0,0)'],
+ [0.4, 'rgb(230,210,0)'], [0.7, 'rgb(255,255,255)'],
+ [1, 'rgb(160,200,255)']
+ ],
+
+ 'Earth': [
+ [0, 'rgb(0,0,130)'], [0.1, 'rgb(0,180,180)'],
+ [0.2, 'rgb(40,210,40)'], [0.4, 'rgb(230,230,50)'],
+ [0.6, 'rgb(120,70,20)'], [1, 'rgb(255,255,255)']
+ ],
+
+ 'Electric': [
+ [0, 'rgb(0,0,0)'], [0.15, 'rgb(30,0,100)'],
+ [0.4, 'rgb(120,0,100)'], [0.6, 'rgb(160,90,0)'],
+ [0.8, 'rgb(230,200,0)'], [1, 'rgb(255,250,220)']
+ ],
+
+ 'Viridis': [
+ [0, '#440154'], [0.06274509803921569, '#48186a'],
+ [0.12549019607843137, '#472d7b'], [0.18823529411764706, '#424086'],
+ [0.25098039215686274, '#3b528b'], [0.3137254901960784, '#33638d'],
+ [0.3764705882352941, '#2c728e'], [0.4392156862745098, '#26828e'],
+ [0.5019607843137255, '#21918c'], [0.5647058823529412, '#1fa088'],
+ [0.6274509803921569, '#28ae80'], [0.6901960784313725, '#3fbc73'],
+ [0.7529411764705882, '#5ec962'], [0.8156862745098039, '#84d44b'],
+ [0.8784313725490196, '#addc30'], [0.9411764705882353, '#d8e219'],
+ [1, '#fde725']
+ ],
+
+ 'Cividis': [
+ [0.000000, 'rgb(0,32,76)'], [0.058824, 'rgb(0,42,102)'],
+ [0.117647, 'rgb(0,52,110)'], [0.176471, 'rgb(39,63,108)'],
+ [0.235294, 'rgb(60,74,107)'], [0.294118, 'rgb(76,85,107)'],
+ [0.352941, 'rgb(91,95,109)'], [0.411765, 'rgb(104,106,112)'],
+ [0.470588, 'rgb(117,117,117)'], [0.529412, 'rgb(131,129,120)'],
+ [0.588235, 'rgb(146,140,120)'], [0.647059, 'rgb(161,152,118)'],
+ [0.705882, 'rgb(176,165,114)'], [0.764706, 'rgb(192,177,109)'],
+ [0.823529, 'rgb(209,191,102)'], [0.882353, 'rgb(225,204,92)'],
+ [0.941176, 'rgb(243,219,79)'], [1.000000, 'rgb(255,233,69)']
+ ]
+
}
@@ -47,7 +159,6 @@ def validate_index(index_vals):
:raises: (PlotlyError) If there are any two items in the list whose
types differ
"""
- from numbers import Number
if isinstance(index_vals[0], Number):
if not all(isinstance(item, Number) for item in index_vals):
raise exceptions.PlotlyError("Error in indexing column. "
@@ -70,7 +181,6 @@ def validate_dataframe(array):
:raises: (PlotlyError) If there are any two items in any list whose
types differ
"""
- from numbers import Number
for vector in array:
if isinstance(vector[0], Number):
if not all(isinstance(item, Number) for item in vector):
@@ -149,30 +259,39 @@ def find_intermediate_color(lowcolor, highcolor, intermed):
lowcolor[2] + intermed * diff_2)
-def n_colors(lowcolor, highcolor, n_colors):
+def n_colors(lowcolor, highcolor, n_colors, colortype='tuple'):
"""
Splits a low and high color into a list of n_colors colors in it
Accepts two color tuples and returns a list of n_colors colors
which form the intermediate colors between lowcolor and highcolor
- from linearly interpolating through RGB space
-
+ from linearly interpolating through RGB space. If colortype is 'rgb'
+ the function will return a list of colors in the same form.
"""
+ if colortype == 'rgb':
+ # convert to tuple
+ lowcolor = unlabel_rgb(lowcolor)
+ highcolor = unlabel_rgb(highcolor)
+
diff_0 = float(highcolor[0] - lowcolor[0])
incr_0 = diff_0/(n_colors - 1)
diff_1 = float(highcolor[1] - lowcolor[1])
incr_1 = diff_1/(n_colors - 1)
diff_2 = float(highcolor[2] - lowcolor[2])
incr_2 = diff_2/(n_colors - 1)
- color_tuples = []
+ list_of_colors = []
for index in range(n_colors):
new_tuple = (lowcolor[0] + (index * incr_0),
lowcolor[1] + (index * incr_1),
lowcolor[2] + (index * incr_2))
- color_tuples.append(new_tuple)
+ list_of_colors.append(new_tuple)
+
+ if colortype == 'rgb':
+ # back to an rgb string
+ list_of_colors = color_parser(list_of_colors, label_rgb)
- return color_tuples
+ return list_of_colors
def label_rgb(colors):
@@ -274,7 +393,6 @@ def color_parser(colors, function):
- rgb string, hex string or tuple
"""
- from numbers import Number
if isinstance(colors, str):
return function(colors)
@@ -295,13 +413,13 @@ def validate_colors(colors, colortype='tuple'):
"""
Validates color(s) and returns a list of color(s) of a specified type
"""
- from numbers import Number
if colors is None:
colors = DEFAULT_PLOTLY_COLORS
if isinstance(colors, str):
if colors in PLOTLY_SCALES:
- colors = PLOTLY_SCALES[colors]
+ colors_list = colorscale_to_colors(PLOTLY_SCALES[colors])
+ colors = [colors_list[0]] + [colors_list[-1]]
elif 'rgb' in colors or '#' in colors:
colors = [colors]
else:
@@ -343,7 +461,7 @@ def validate_colors(colors, colortype='tuple'):
)
colors[j] = each_color
- if colortype == 'rgb':
+ if colortype == 'rgb' and not isinstance(colors, six.string_types):
for j, each_color in enumerate(colors):
rgb_color = color_parser(each_color, convert_to_RGB_255)
colors[j] = color_parser(rgb_color, label_rgb)
@@ -386,6 +504,173 @@ def validate_colors_dict(colors, colortype='tuple'):
return colors
+
+def convert_colors_to_same_type(colors, colortype='rgb', scale=None,
+ return_default_colors=False,
+ num_of_defualt_colors=2):
+ """
+ Converts color(s) to the specified color type
+
+ Takes a single color or an iterable of colors, as well as a list of scale
+ values, and outputs a 2-pair of the list of color(s) converted all to an
+ rgb or tuple color type, aswell as the scale as the second element. If
+ colors is a Plotly Scale name, then 'scale' will be forced to the scale
+ from the respective colorscale and the colors in that colorscale will also
+ be coverted to the selected colortype. If colors is None, then there is an
+ option to return portion of the DEFAULT_PLOTLY_COLORS
+
+ :param (str|tuple|list) colors: either a plotly scale name, an rgb or hex
+ color, a color tuple or a list/tuple of colors
+ :param (list) scale: see docs for validate_scale_values()
+
+ :rtype (tuple) (colors_list, scale) if scale is None in the function call,
+ then scale will remain None in the returned tuple
+ """
+ #if colors_list is None:
+ colors_list = []
+
+ if colors is None and return_default_colors is True:
+ colors_list = DEFAULT_PLOTLY_COLORS[0:num_of_defualt_colors]
+
+ if isinstance(colors, str):
+ if colors in PLOTLY_SCALES:
+ colors_list = colorscale_to_colors(PLOTLY_SCALES[colors])
+ if scale is None:
+ scale = colorscale_to_scale(PLOTLY_SCALES[colors])
+
+ elif 'rgb' in colors or '#' in colors:
+ colors_list = [colors]
+
+ elif isinstance(colors, tuple):
+ if isinstance(colors[0], Number):
+ colors_list = [colors]
+ else:
+ colors_list = list(colors)
+
+ elif isinstance(colors, list):
+ colors_list = colors
+
+ # validate scale
+ if scale is not None:
+ validate_scale_values(scale)
+
+ if len(colors_list) != len(scale):
+ raise exceptions.PlotlyError(
+ 'Make sure that the length of your scale matches the length '
+ 'of your list of colors which is {}.'.format(len(colors_list))
+ )
+
+ # convert all colors to rgb
+ for j, each_color in enumerate(colors_list):
+ if '#' in each_color:
+ each_color = color_parser(
+ each_color, hex_to_rgb
+ )
+ each_color = color_parser(
+ each_color, label_rgb
+ )
+ colors_list[j] = each_color
+
+ elif isinstance(each_color, tuple):
+ each_color = color_parser(
+ each_color, convert_to_RGB_255
+ )
+ each_color = color_parser(
+ each_color, label_rgb
+ )
+ colors_list[j] = each_color
+
+ if colortype == 'rgb':
+ return (colors_list, scale)
+ elif colortype == 'tuple':
+ for j, each_color in enumerate(colors_list):
+ each_color = color_parser(
+ each_color, unlabel_rgb
+ )
+ each_color = color_parser(
+ each_color, unconvert_from_RGB_255
+ )
+ colors_list[j] = each_color
+ return (colors_list, scale)
+ else:
+ raise exceptions.PlotlyError('You must select either rgb or tuple '
+ 'for your colortype variable.')
+
+
+def convert_dict_colors_to_same_type(colors_dict, colortype='rgb'):
+ """
+ Converts a colors in a dictioanry of colors to the specified color type
+
+ :param (dict) colors_dict: a dictioanry whose values are single colors
+ """
+ for key in colors_dict:
+ if '#' in colors_dict[key]:
+ colors_dict[key] = color_parser(
+ colors_dict[key], hex_to_rgb
+ )
+ colors_dict[key] = color_parser(
+ colors_dict[key], label_rgb
+ )
+
+ elif isinstance(colors_dict[key], tuple):
+ colors_dict[key] = color_parser(
+ colors_dict[key], convert_to_RGB_255
+ )
+ colors_dict[key] = color_parser(
+ colors_dict[key], label_rgb
+ )
+
+ if colortype == 'rgb':
+ return colors_dict
+ elif colortype == 'tuple':
+ for key in colors_dict:
+ colors_dict[key] = color_parser(
+ colors_dict[key], unlabel_rgb
+ )
+ colors_dict[key] = color_parser(
+ colors_dict[key], unconvert_from_RGB_255
+ )
+ return colors_dict
+ else:
+ raise exceptions.PlotlyError('You must select either rgb or tuple '
+ 'for your colortype variable.')
+
+
+def make_colorscale(colors, scale=None):
+ """
+ Makes a colorscale from a list of colors and a scale
+
+ Takes a list of colors and scales and constructs a colorscale based
+ on the colors in sequential order. If 'scale' is left empty, a linear-
+ interpolated colorscale will be generated. If 'scale' is a specificed
+ list, it must be the same legnth as colors and must contain all floats
+ For documentation regarding to the form of the output, see
+ https://plot.ly/python/reference/#mesh3d-colorscale
+
+ :param (list) colors: a list of single colors
+ """
+ colorscale = []
+
+ # validate minimum colors length of 2
+ if len(colors) < 2:
+ raise exceptions.PlotlyError('You must input a list of colors that '
+ 'has at least two colors.')
+
+ if scale is None:
+ scale_incr = 1./(len(colors) - 1)
+ return [[i * scale_incr, color] for i, color in enumerate(colors)]
+
+ else:
+ if len(colors) != len(scale):
+ raise exceptions.PlotlyError('The length of colors and scale '
+ 'must be the same.')
+
+ validate_scale_values(scale)
+
+ colorscale = [list(tup) for tup in zip(scale, colors)]
+ return colorscale
+
+
def colorscale_to_colors(colorscale):
"""
Extracts the colors from colorscale as a list
@@ -406,6 +691,23 @@ def colorscale_to_scale(colorscale):
return scale_list
+def convert_colorscale_to_rgb(colorscale):
+ """
+ Converts the colors in a colorscale to rgb colors
+
+ A colorscale is an array of arrays, each with a numeric value as the
+ first item and a color as the second. This function specifically is
+ converting a colorscale with tuple colors (each coordinate between 0
+ and 1) into a colorscale with the colors transformed into rgb colors
+ """
+ for color in colorscale:
+ color[1] = convert_to_RGB_255(color[1])
+
+ for color in colorscale:
+ color[1] = label_rgb(color[1])
+ return colorscale
+
+
def validate_scale_values(scale):
"""
Validates scale values from a colorscale
diff --git a/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py b/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py
new file mode 100644
index 00000000000..7e3399e96e0
--- /dev/null
+++ b/plotly/tests/test_optional/test_figure_factory/test_figure_factory_utils.py
@@ -0,0 +1,244 @@
+from unittest import TestCase
+
+from nose.tools import raises
+import plotly.figure_factory.utils as utils
+import plotly.tools as tls
+from plotly.exceptions import PlotlyError
+
+
+class TestFigureFactoryUtils(TestCase):
+
+ def test_validate_index(self):
+ pattern = (
+ "Error in indexing column. "
+ "Make sure all entries of each "
+ "column are all numbers or "
+ "all strings."
+ )
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index,
+ [13, 'foo'])
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index,
+ ['number', 42])
+
+ def validate_dataframe(self):
+ pattern = (
+ "Error in dataframe. "
+ "Make sure all entries of "
+ "each column are either "
+ "numbers or strings."
+ )
+
+ df = [
+ [8, 'foo'],
+ ['amazing', 64]
+ ]
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, df)
+
+ df = [
+ ['amazing', 64],
+ [8, 'foo']
+ ]
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index, df)
+
+
+ def validate_equal_length(self):
+ pattern = (
+ "Oops! Your data lists or ndarrays should be the same length."
+ )
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_index,
+ (0,0,0), (0,0,0,0,0))
+
+
+ def test_validate_positive_scalars(self):
+ # only one negative number
+ self.assertRaises(ValueError, utils.validate_positive_scalars,
+ number0=1, number1=0.001, number2=-2)
+
+ def test_flatten(self):
+
+ self.assertRaises(
+ PlotlyError, utils.flatten,
+ array=0
+ )
+
+ def test_validate_colors(self):
+
+ # test string input
+ color_string = 'foo'
+
+ pattern = ("If your colors variable is a string, it must be a "
+ "Plotly scale, an rgb color or a hex color.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.validate_colors,
+ color_string)
+
+ # test rgb color
+ color_string2 = 'rgb(265, 0, 0)'
+
+ pattern2 = ("Whoops! The elements in your rgb colors tuples cannot "
+ "exceed 255.0.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern2, utils.validate_colors,
+ color_string2)
+
+ # test dictionary
+ colors_dict = {
+ 'apple': 'rgb(300, 0, 0)',
+ 'pear': (0, 0.5, 1)
+ }
+
+ self.assertRaisesRegexp(PlotlyError, pattern2, utils.validate_colors_dict,
+ colors_dict)
+
+ # test tuple color
+ color_tuple = (1, 1, 2)
+
+ pattern3 = ("Whoops! The elements in your colors tuples cannot "
+ "exceed 1.0.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern3, utils.validate_colors,
+ color_tuple)
+
+ colors_dict2 = {
+ 'apple': 'rgb(255, 100, 50)',
+ 'pear': (0, 0.5, 2)
+ }
+
+ self.assertRaisesRegexp(PlotlyError, pattern3, utils.validate_colors_dict,
+ colors_dict2)
+
+ def test_convert_colors_to_same_type(self):
+
+ # test colortype
+ color_tuple = ['#aaaaaa', '#bbbbbb', '#cccccc']
+ scale = [0, 1]
+
+ self.assertRaises(PlotlyError, utils.convert_colors_to_same_type,
+ color_tuple, scale=scale)
+
+ # test colortype
+ color_tuple = (1, 1, 1)
+ colortype = 2
+
+ pattern2 = ("You must select either rgb or tuple for your colortype "
+ "variable.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern2,
+ utils.convert_colors_to_same_type,
+ color_tuple, colortype)
+
+ def test_convert_dict_colors_to_same_type(self):
+
+ # test colortype
+ color_dict = dict(apple='rgb(1, 1, 1)')
+ colortype = 2
+
+ pattern = ("You must select either rgb or tuple for your colortype "
+ "variable.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.convert_dict_colors_to_same_type,
+ color_dict, colortype)
+
+ def test_validate_scale_values(self):
+
+ # test that scale length is at least 2
+ scale = [0]
+
+ pattern = ("You must input a list of scale values that has at least "
+ "two values.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.validate_scale_values,
+ scale)
+
+ # test if first and last number is 0 and 1 respectively
+ scale = [0, 1.1]
+
+ pattern = ("The first and last number in your scale must be 0.0 and "
+ "1.0 respectively.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.validate_scale_values,
+ scale)
+
+ # test numbers increase
+ scale = [0, 2, 1]
+
+ pattern = ("'scale' must be a list that contains a strictly "
+ "increasing sequence of numbers.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.validate_scale_values,
+ scale)
+
+ def test_make_colorscale(self):
+
+ # test minimum colors length
+ color_list = [(0, 0, 0)]
+
+ pattern = (
+ "You must input a list of colors that has at least two colors."
+ )
+
+ self.assertRaisesRegexp(PlotlyError, pattern, utils.make_colorscale,
+ color_list)
+
+ # test length of colors and scale
+ color_list2 = [(0, 0, 0), (1, 1, 1)]
+ scale = [0]
+
+ pattern2 = ("The length of colors and scale must be the same.")
+
+ self.assertRaisesRegexp(PlotlyError, pattern2, utils.make_colorscale,
+ color_list2, scale)
+
+
+ def test_endpts_to_intervals(self):
+
+ pattern = ("The intervals_endpts argument must "
+ "be a list or tuple of a sequence "
+ "of increasing numbers.")
+
+ endpts = 'foo'
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.endpts_to_intervals, endpts)
+
+ endpts = ['foo']
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.endpts_to_intervals, endpts)
+
+ endpts = [1, 0]
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.endpts_to_intervals, endpts)
+
+ def test_validate_colorscale(self):
+
+ pattern = "A valid colorscale must be a list."
+
+ colorscale = 55
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.validate_colorscale, colorscale)
+
+ pattern2 = "A valid colorscale must be a list of lists."
+
+ colorscale = [[], [], 'foo', []]
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.validate_colorscale, colorscale)
+
+ def test_list_of_options(self):
+
+ pattern = 'Your list or tuple must contain at least 2 items.'
+
+ self.assertRaisesRegexp(PlotlyError, pattern,
+ utils.list_of_options, ['item #1'])
+
+
+
+
+
diff --git a/plotly/tools.py b/plotly/tools.py
index 3f6fd488069..95e87efe4e8 100644
--- a/plotly/tools.py
+++ b/plotly/tools.py
@@ -27,23 +27,6 @@
REQUIRED_GANTT_KEYS = ['Task', 'Start', 'Finish']
-PLOTLY_SCALES = {'Greys': ['rgb(0,0,0)', 'rgb(255,255,255)'],
- 'YlGnBu': ['rgb(8,29,88)', 'rgb(255,255,217)'],
- 'Greens': ['rgb(0,68,27)', 'rgb(247,252,245)'],
- 'YlOrRd': ['rgb(128,0,38)', 'rgb(255,255,204)'],
- 'Bluered': ['rgb(0,0,255)', 'rgb(255,0,0)'],
- 'RdBu': ['rgb(5,10,172)', 'rgb(178,10,28)'],
- 'Reds': ['rgb(220,220,220)', 'rgb(178,10,28)'],
- 'Blues': ['rgb(5,10,172)', 'rgb(220,220,220)'],
- 'Picnic': ['rgb(0,0,255)', 'rgb(255,0,0)'],
- 'Rainbow': ['rgb(150,0,90)', 'rgb(255,0,0)'],
- 'Portland': ['rgb(12,51,131)', 'rgb(217,30,30)'],
- 'Jet': ['rgb(0,0,131)', 'rgb(128,0,0)'],
- 'Hot': ['rgb(0,0,0)', 'rgb(255,255,255)'],
- 'Blackbody': ['rgb(0,0,0)', 'rgb(160,200,255)'],
- 'Earth': ['rgb(0,0,130)', 'rgb(255,255,255)'],
- 'Electric': ['rgb(0,0,0)', 'rgb(255,250,220)'],
- 'Viridis': ['rgb(68,1,84)', 'rgb(253,231,37)']}
# color constants for violin plot
DEFAULT_FILLCOLOR = '#1f77b4'