diff --git a/packages/python/plotly/_plotly_utils/colors/__init__.py b/packages/python/plotly/_plotly_utils/colors/__init__.py index 515e7c99f6a..00856b782f6 100644 --- a/packages/python/plotly/_plotly_utils/colors/__init__.py +++ b/packages/python/plotly/_plotly_utils/colors/__init__.py @@ -833,3 +833,48 @@ def get_colorscale(name): if should_reverse: return colorscale[::-1] return colorscale + + +def sample_colorscale(colorscale, samplepoints, low=0.0, high=1.0, colortype="rgb"): + """ + Samples a colorscale at specific points. + + Interpolates between colors in a colorscale to find the specific colors + corresponding to the specified sample values. The colorscale can be specified + as a list of `[scale, color]` pairs, as a list of colors, or as a named + plotly colorscale. The samplepoints can be specefied as an iterable of specific + points in the range [0.0, 1.0], or as an integer number of points which will + be spaced equally between the low value (default 0.0) and the high value + (default 1.0). The output is a list of colors, formatted according to the + specified colortype. + """ + from bisect import bisect_left + + try: + validate_colorscale(colorscale) + except exceptions.PlotlyError: + if isinstance(colorscale, str): + colorscale = get_colorscale(colorscale) + else: + colorscale = make_colorscale(colorscale) + + scale = colorscale_to_scale(colorscale) + validate_scale_values(scale) + colors = colorscale_to_colors(colorscale) + colors = validate_colors(colors, colortype="tuple") + + if isinstance(samplepoints, int): + samplepoints = [ + low + idx / (samplepoints - 1) * (high - low) for idx in range(samplepoints) + ] + elif isinstance(samplepoints, float): + samplepoints = [samplepoints] + + sampled_colors = [] + for point in samplepoints: + high = bisect_left(scale, point) + low = high - 1 + interpolant = (point - scale[low]) / (scale[high] - scale[low]) + sampled_color = find_intermediate_color(colors[low], colors[high], interpolant) + sampled_colors.append(sampled_color) + return validate_colors(sampled_colors, colortype=colortype) diff --git a/packages/python/plotly/plotly/colors/__init__.py b/packages/python/plotly/plotly/colors/__init__.py index b08d52c87e6..16fa195b4b3 100644 --- a/packages/python/plotly/plotly/colors/__init__.py +++ b/packages/python/plotly/plotly/colors/__init__.py @@ -36,6 +36,7 @@ "label_rgb", "make_colorscale", "n_colors", + "sample_colorscale", "unconvert_from_RGB_255", "unlabel_rgb", "validate_colors", diff --git a/packages/python/plotly/plotly/tests/test_core/test_colors/test_colors.py b/packages/python/plotly/plotly/tests/test_core/test_colors/test_colors.py index 0d393b33b90..230deedec2a 100644 --- a/packages/python/plotly/plotly/tests/test_core/test_colors/test_colors.py +++ b/packages/python/plotly/plotly/tests/test_core/test_colors/test_colors.py @@ -149,7 +149,6 @@ def test_get_colorscale(self): # test for non-existing colorscale pattern = r"Colorscale \S+ is not a built-in scale." name = "foo" - self.assertRaisesRegex(PlotlyError, pattern, colors.get_colorscale, name) # test non-capitalised access @@ -164,3 +163,34 @@ def test_get_colorscale(self): self.assertEqual( colors.diverging.Portland_r, colors.get_colorscale("portland_r") ) + + def test_sample_colorscale(self): + + # test that sampling a colorscale at the defined points returns the same + defined_colors = colors.sequential.Inferno + sampled_colors = colors.sample_colorscale( + defined_colors, len(defined_colors), colortype="rgb" + ) + defined_colors_rgb = colors.convert_colors_to_same_type( + defined_colors, colortype="rgb" + )[0] + self.assertEqual(sampled_colors, defined_colors_rgb) + + # test sampling an easy colorscale that goes [red, green, blue] + defined_colors = ["rgb(255,0,0)", "rgb(0,255,0)", "rgb(0,0,255)"] + samplepoints = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0] + expected_output = [ + (1.0, 0.0, 0.0), + (0.75, 0.25, 0.0), + (0.5, 0.5, 0.0), + (0.25, 0.75, 0.0), + (0.0, 1.0, 0.0), + (0.0, 0.75, 0.25), + (0.0, 0.5, 0.5), + (0.0, 0.25, 0.75), + (0.0, 0.0, 1.0), + ] + output = colors.sample_colorscale( + defined_colors, samplepoints, colortype="tuple" + ) + self.assertEqual(expected_output, output)