@@ -2561,7 +2561,7 @@ def __init__(
2561
2561
)
2562
2562
2563
2563
self .class_strs_map = class_strs_map
2564
- self ._class_map = None
2564
+ self ._class_map = {}
2565
2565
self .set_uid = set_uid
2566
2566
2567
2567
def description (self ):
@@ -2597,21 +2597,17 @@ def description(self):
2597
2597
2598
2598
return desc
2599
2599
2600
- @property
2601
- def class_map (self ):
2602
- if self ._class_map is None :
2603
-
2604
- # Initialize class map
2605
- self ._class_map = {}
2606
-
2607
- # Import trace classes
2600
+ def get_trace_class (self , trace_name ):
2601
+ # Import trace classes
2602
+ if trace_name not in self ._class_map :
2608
2603
trace_module = import_module ("plotly.graph_objs" )
2609
- for k , class_str in self .class_strs_map . items ():
2610
- self ._class_map [k ] = getattr (trace_module , class_str )
2604
+ trace_class_name = self .class_strs_map [ trace_name ]
2605
+ self ._class_map [trace_name ] = getattr (trace_module , trace_class_name )
2611
2606
2612
- return self ._class_map
2607
+ return self ._class_map [ trace_name ]
2613
2608
2614
2609
def validate_coerce (self , v , skip_invalid = False ):
2610
+ from plotly .basedatatypes import BaseTraceType
2615
2611
2616
2612
# Import Histogram2dcontour, this is the deprecated name of the
2617
2613
# Histogram2dContour trace.
@@ -2623,13 +2619,11 @@ def validate_coerce(self, v, skip_invalid=False):
2623
2619
if not isinstance (v , (list , tuple )):
2624
2620
v = [v ]
2625
2621
2626
- trace_classes = tuple (self .class_map .values ())
2627
-
2628
2622
res = []
2629
2623
invalid_els = []
2630
2624
for v_el in v :
2631
2625
2632
- if isinstance (v_el , trace_classes ):
2626
+ if isinstance (v_el , BaseTraceType ):
2633
2627
# Clone input traces
2634
2628
v_el = v_el .to_plotly_json ()
2635
2629
@@ -2643,25 +2637,25 @@ def validate_coerce(self, v, skip_invalid=False):
2643
2637
else :
2644
2638
trace_type = "scatter"
2645
2639
2646
- if trace_type not in self .class_map :
2640
+ if trace_type not in self .class_strs_map :
2647
2641
if skip_invalid :
2648
2642
# Treat as scatter trace
2649
- trace = self .class_map [ "scatter" ] (
2643
+ trace = self .get_trace_class ( "scatter" ) (
2650
2644
skip_invalid = skip_invalid , ** v_copy
2651
2645
)
2652
2646
res .append (trace )
2653
2647
else :
2654
2648
res .append (None )
2655
2649
invalid_els .append (v_el )
2656
2650
else :
2657
- trace = self .class_map [ trace_type ] (
2651
+ trace = self .get_trace_class ( trace_type ) (
2658
2652
skip_invalid = skip_invalid , ** v_copy
2659
2653
)
2660
2654
res .append (trace )
2661
2655
else :
2662
2656
if skip_invalid :
2663
2657
# Add empty scatter trace
2664
- trace = self .class_map [ "scatter" ] ()
2658
+ trace = self .get_trace_class ( "scatter" ) ()
2665
2659
res .append (trace )
2666
2660
else :
2667
2661
res .append (None )
0 commit comments