9
9
from contextlib import contextmanager
10
10
from copy import deepcopy , copy
11
11
12
- from plotly .subplots import _set_trace_grid_reference , _get_grid_subplot
12
+ from plotly .subplots import (
13
+ _set_trace_grid_reference ,
14
+ _get_grid_subplot ,
15
+ _get_subplot_ref_for_trace ,
16
+ _validate_v4_subplots )
13
17
from .optional_imports import get_module
14
18
15
19
from _plotly_utils .basevalidators import (
@@ -33,13 +37,25 @@ class BaseFigure(object):
33
37
"""
34
38
_bracket_re = re .compile ('^(.*)\[(\d+)\]$' )
35
39
40
+ _valid_underscore_properties = {
41
+ 'error_x' : 'error-x' ,
42
+ 'error_y' : 'error-y' ,
43
+ 'error_z' : 'error-z' ,
44
+ 'copy_xstyle' : 'copy-xstyle' ,
45
+ 'copy_ystyle' : 'copy-ystyle' ,
46
+ 'copy_zstyle' : 'copy-zstyle' ,
47
+ 'paper_bgcolor' : 'paper-bgcolor' ,
48
+ 'plot_bgcolor' : 'plot-bgcolor'
49
+ }
50
+
36
51
# Constructor
37
52
# -----------
38
53
def __init__ (self ,
39
54
data = None ,
40
55
layout_plotly = None ,
41
56
frames = None ,
42
- skip_invalid = False ):
57
+ skip_invalid = False ,
58
+ ** kwargs ):
43
59
"""
44
60
Construct a BaseFigure object
45
61
@@ -247,6 +263,14 @@ class is a subclass of both BaseFigure and widgets.DOMWidget.
247
263
# ### Check for default template ###
248
264
self ._initialize_layout_template ()
249
265
266
+ # Process kwargs
267
+ # --------------
268
+ for k , v in kwargs .items ():
269
+ if k in self :
270
+ self [k ] = v
271
+ elif not skip_invalid :
272
+ raise TypeError ('invalid Figure property: {}' .format (k ))
273
+
250
274
# Magic Methods
251
275
# -------------
252
276
def __reduce__ (self ):
@@ -356,7 +380,13 @@ def __iter__(self):
356
380
return iter (('data' , 'layout' , 'frames' ))
357
381
358
382
def __contains__ (self , prop ):
359
- return prop in ('data' , 'layout' , 'frames' )
383
+ prop = BaseFigure ._str_to_dict_path (prop )
384
+ if prop [0 ] not in ('data' , 'layout' , 'frames' ):
385
+ return False
386
+ elif len (prop ) == 1 :
387
+ return True
388
+ else :
389
+ return prop [1 :] in self [prop [0 ]]
360
390
361
391
def __eq__ (self , other ):
362
392
if not isinstance (other , BaseFigure ):
@@ -447,17 +477,21 @@ def update(self, dict1=None, **kwargs):
447
477
for d in [dict1 , kwargs ]:
448
478
if d :
449
479
for k , v in d .items ():
450
- if self [k ] == ():
480
+ update_target = self [k ]
481
+ if update_target == ():
451
482
# existing data or frames property is empty
452
483
# In this case we accept the v as is.
453
484
if k == 'data' :
454
485
self .add_traces (v )
455
486
else :
456
487
# Accept v
457
488
self [k ] = v
458
- else :
489
+ elif (isinstance (update_target , BasePlotlyType ) or
490
+ (isinstance (update_target , tuple ) and
491
+ isinstance (update_target [0 ], BasePlotlyType ))):
459
492
BaseFigure ._perform_update (self [k ], v )
460
-
493
+ else :
494
+ self [k ] = v
461
495
return self
462
496
463
497
# Data
@@ -604,6 +638,140 @@ def data(self, new_data):
604
638
for trace_ind , trace in enumerate (self ._data_objs ):
605
639
trace ._trace_ind = trace_ind
606
640
641
+ def select_traces (self , selector = None , row = None , col = None ):
642
+ """
643
+ Select traces from a particular subplot cell and/or traces
644
+ that satisfy custom selection criteria.
645
+
646
+ Parameters
647
+ ----------
648
+ selector: dict or None (default None)
649
+ Dict to use as selection criteria.
650
+ Traces will be selected if they contain properties corresponding
651
+ to all of the dictionary's keys, with values that exactly match
652
+ the supplied values. If None (the default), all traces are
653
+ selected.
654
+ row, col: int or None (default None)
655
+ Subplot row and column index of traces to select.
656
+ To select traces by row and column, the Figure must have been
657
+ created using plotly.subplots.make_subplots. If None
658
+ (the default), all traces are selected.
659
+
660
+ Returns
661
+ -------
662
+ generator
663
+ Generator that iterates through all of the traces that satisfy
664
+ all of the specified selection criteria
665
+ """
666
+ if not selector :
667
+ selector = {}
668
+
669
+ if row is not None and col is not None :
670
+ _validate_v4_subplots ('select_traces' )
671
+ grid_ref = self ._validate_get_grid_ref ()
672
+ grid_subplot_ref = grid_ref [row - 1 ][col - 1 ]
673
+ filter_by_subplot = True
674
+ else :
675
+ filter_by_subplot = False
676
+ grid_subplot_ref = None
677
+
678
+ return self ._perform_select_traces (
679
+ filter_by_subplot , grid_subplot_ref , selector )
680
+
681
+ def _perform_select_traces (
682
+ self , filter_by_subplot , grid_subplot_ref , selector ):
683
+
684
+ def select_eq (obj1 , obj2 ):
685
+ try :
686
+ obj1 = obj1 .to_plotly_json ()
687
+ except Exception :
688
+ pass
689
+ try :
690
+ obj2 = obj2 .to_plotly_json ()
691
+ except Exception :
692
+ pass
693
+
694
+ return BasePlotlyType ._vals_equal (obj1 , obj2 )
695
+
696
+ for trace in self .data :
697
+ # Filter by subplot
698
+ if filter_by_subplot :
699
+ trace_subplot_ref = _get_subplot_ref_for_trace (trace )
700
+ if grid_subplot_ref != trace_subplot_ref :
701
+ continue
702
+
703
+ # Filter by selector
704
+ if not all (
705
+ k in trace and select_eq (trace [k ], selector [k ])
706
+ for k in selector ):
707
+ continue
708
+
709
+ yield trace
710
+
711
+ def for_each_trace (self , fn , selector = None , row = None , col = None ):
712
+ """
713
+ Apply a function to all traces that satisfy the specified selection
714
+ criteria
715
+
716
+ Parameters
717
+ ----------
718
+ fn:
719
+ Function that inputs a single trace object.
720
+ selector: dict or None (default None)
721
+ Dict to use as selection criteria.
722
+ Traces will be selected if they contain properties corresponding
723
+ to all of the dictionary's keys, with values that exactly match
724
+ the supplied values. If None (the default), all traces are
725
+ selected.
726
+ row, col: int or None (default None)
727
+ Subplot row and column index of traces to select.
728
+ To select traces by row and column, the Figure must have been
729
+ created using plotly.subplots.make_subplots. If None
730
+ (the default), all traces are selected.
731
+
732
+ Returns
733
+ -------
734
+ self
735
+ Returns the Figure object that the method was called on
736
+ """
737
+ for trace in self .select_traces (selector = selector , row = row , col = col ):
738
+ fn (trace )
739
+
740
+ return self
741
+
742
+ def update_traces (self , patch , selector = None , row = None , col = None ):
743
+ """
744
+ Perform a property update operation on all traces that satisfy the
745
+ specified selection criteria
746
+
747
+ Parameters
748
+ ----------
749
+ patch: dict
750
+ Dictionary of property updates to be applied to all traces that
751
+ satisfy the selection criteria.
752
+ fn:
753
+ Function that inputs a single trace object.
754
+ selector: dict or None (default None)
755
+ Dict to use as selection criteria.
756
+ Traces will be selected if they contain properties corresponding
757
+ to all of the dictionary's keys, with values that exactly match
758
+ the supplied values. If None (the default), all traces are
759
+ selected.
760
+ row, col: int or None (default None)
761
+ Subplot row and column index of traces to select.
762
+ To select traces by row and column, the Figure must have been
763
+ created using plotly.subplots.make_subplots. If None
764
+ (the default), all traces are selected.
765
+
766
+ Returns
767
+ -------
768
+ self
769
+ Returns the Figure object that the method was called on
770
+ """
771
+ for trace in self .select_traces (selector = selector , row = row , col = col ):
772
+ trace .update (patch )
773
+ return self
774
+
607
775
# Restyle
608
776
# -------
609
777
def plotly_restyle (self , restyle_data , trace_indexes = None , ** kwargs ):
@@ -822,18 +990,20 @@ def _str_to_dict_path(key_path_str):
822
990
"""
823
991
if isinstance (key_path_str , string_types ) and \
824
992
'.' not in key_path_str and \
825
- '[' not in key_path_str :
993
+ '[' not in key_path_str and \
994
+ '_' not in key_path_str :
826
995
# Fast path for common case that avoids regular expressions
827
996
return (key_path_str ,)
828
997
elif isinstance (key_path_str , tuple ):
829
998
# Nothing to do
830
999
return key_path_str
831
1000
else :
832
- # Split string on periods. e.g. 'foo.bar[1]' -> ['foo', 'bar[1]']
1001
+ # Split string on periods.
1002
+ # e.g. 'foo.bar_baz[1]' -> ['foo', 'bar_baz[1]']
833
1003
key_path = key_path_str .split ('.' )
834
1004
835
1005
# Split out bracket indexes.
836
- # e.g. ['foo', 'bar [1]'] -> ['foo', 'bar ', '1']
1006
+ # e.g. ['foo', 'bar_baz [1]'] -> ['foo', 'bar_baz ', '1']
837
1007
key_path2 = []
838
1008
for key in key_path :
839
1009
match = BaseFigure ._bracket_re .match (key )
@@ -842,15 +1012,39 @@ def _str_to_dict_path(key_path_str):
842
1012
else :
843
1013
key_path2 .append (key )
844
1014
1015
+ # Split out underscore
1016
+ # e.g. ['foo', 'bar_baz', '1'] -> ['foo', 'bar', 'baz', '1']
1017
+ key_path3 = []
1018
+ underscore_props = BaseFigure ._valid_underscore_properties
1019
+ for key in key_path2 :
1020
+ if '_' in key [1 :]:
1021
+ # For valid properties that contain underscores (error_x)
1022
+ # replace the underscores with hyphens to protect them
1023
+ # from being split up
1024
+ for under_prop , hyphen_prop in underscore_props .items ():
1025
+ key = key .replace (under_prop , hyphen_prop )
1026
+
1027
+ # Split key on underscores
1028
+ key = key .split ('_' )
1029
+
1030
+ # Replace hyphens with underscores to restore properties
1031
+ # that include underscores
1032
+ for i in range (len (key )):
1033
+ key [i ] = key [i ].replace ('-' , '_' )
1034
+
1035
+ key_path3 .extend (key )
1036
+ else :
1037
+ key_path3 .append (key )
1038
+
845
1039
# Convert elements to ints if possible.
846
1040
# e.g. ['foo', 'bar', '0'] -> ['foo', 'bar', 0]
847
- for i in range (len (key_path2 )):
1041
+ for i in range (len (key_path3 )):
848
1042
try :
849
- key_path2 [i ] = int (key_path2 [i ])
1043
+ key_path3 [i ] = int (key_path3 [i ])
850
1044
except ValueError as _ :
851
1045
pass
852
1046
853
- return tuple (key_path2 )
1047
+ return tuple (key_path3 )
854
1048
855
1049
@staticmethod
856
1050
def _set_in (d , key_path_str , v ):
@@ -1235,13 +1429,8 @@ def append_trace(self, trace, row, col):
1235
1429
self .add_trace (trace = trace , row = row , col = col )
1236
1430
1237
1431
def _set_trace_grid_position (self , trace , row , col ):
1238
- try :
1239
- grid_ref = self ._grid_ref
1240
- except AttributeError :
1241
- raise Exception ("In order to reference traces by row and column, "
1242
- "you must first use "
1243
- "plotly.tools.make_subplots "
1244
- "to create the figure with a subplot grid." )
1432
+ grid_ref = self ._validate_get_grid_ref ()
1433
+
1245
1434
from _plotly_future_ import _future_flags
1246
1435
if 'v4_subplots' in _future_flags :
1247
1436
return _set_trace_grid_reference (
@@ -1277,6 +1466,18 @@ def _set_trace_grid_position(self, trace, row, col):
1277
1466
trace ['xaxis' ] = ref [0 ]
1278
1467
trace ['yaxis' ] = ref [1 ]
1279
1468
1469
+ def _validate_get_grid_ref (self ):
1470
+ try :
1471
+ grid_ref = self ._grid_ref
1472
+ if grid_ref is None :
1473
+ raise AttributeError ('_grid_ref' )
1474
+ except AttributeError :
1475
+ raise Exception ("In order to reference traces by row and column, "
1476
+ "you must first use "
1477
+ "plotly.tools.make_subplots "
1478
+ "to create the figure with a subplot grid." )
1479
+ return grid_ref
1480
+
1280
1481
def get_subplot (self , row , col ):
1281
1482
"""
1282
1483
Return an object representing the subplot at the specified row
@@ -2429,8 +2630,16 @@ def _process_kwargs(self, **kwargs):
2429
2630
"""
2430
2631
Process any extra kwargs that are not predefined as constructor params
2431
2632
"""
2432
- if not self ._skip_invalid :
2433
- self ._raise_on_invalid_property_error (* kwargs .keys ())
2633
+ invalid_kwargs = {}
2634
+ for k , v in kwargs .items ():
2635
+ if k in self :
2636
+ # e.g. underscore kwargs like marker_line_color
2637
+ self [k ] = v
2638
+ else :
2639
+ invalid_kwargs [k ] = v
2640
+
2641
+ if invalid_kwargs and not self ._skip_invalid :
2642
+ self ._raise_on_invalid_property_error (* invalid_kwargs .keys ())
2434
2643
2435
2644
@property
2436
2645
def plotly_name (self ):
@@ -2675,7 +2884,9 @@ def _get_prop_validator(self, prop):
2675
2884
plotly_obj = self [prop_path [:- 1 ]]
2676
2885
prop = prop_path [- 1 ]
2677
2886
else :
2678
- plotly_obj = self
2887
+ prop_path = BaseFigure ._str_to_dict_path (prop )
2888
+ plotly_obj = self [prop_path [:- 1 ]]
2889
+ prop = prop_path [- 1 ]
2679
2890
2680
2891
# Return validator
2681
2892
# ----------------
0 commit comments