Skip to content

Import and initialization optimizations #2368

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
The diff you're trying to view is too large. We only load the first 3000 changed files.
52 changes: 28 additions & 24 deletions packages/python/plotly/_plotly_utils/basevalidators.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ def to_scalar_or_list(v):
# Python native scalar type ('float' in the example above).
# We explicitly check if is has the 'item' method, which conventionally
# converts these types to native scalars.
np = get_module("numpy")
pd = get_module("pandas")
np = get_module("numpy", should_load=False)
pd = get_module("pandas", should_load=False)
if np and np.isscalar(v) and hasattr(v, "item"):
return v.item()
if isinstance(v, (list, tuple)):
Expand Down Expand Up @@ -74,7 +74,9 @@ def copy_to_readonly_numpy_array(v, kind=None, force_numeric=False):
Numpy array with the 'WRITEABLE' flag set to False
"""
np = get_module("numpy")
pd = get_module("pandas")

# Don't force pandas to be loaded, we only want to know if it's already loaded
pd = get_module("pandas", should_load=False)
assert np is not None

# ### Process kind ###
Expand Down Expand Up @@ -166,8 +168,8 @@ def is_homogeneous_array(v):
"""
Return whether a value is considered to be a homogeneous array
"""
np = get_module("numpy")
pd = get_module("pandas")
np = get_module("numpy", should_load=False)
pd = get_module("pandas", should_load=False)
if (
np
and isinstance(v, np.ndarray)
Expand Down Expand Up @@ -2455,6 +2457,10 @@ def validate_coerce(self, v, skip_invalid=False):
v._plotly_name = self.plotly_name
return v

def present(self, v):
# Return compound object as-is
return v


class TitleValidator(CompoundValidator):
"""
Expand Down Expand Up @@ -2549,6 +2555,10 @@ def validate_coerce(self, v, skip_invalid=False):

return v

def present(self, v):
# Return compound object as tuple
return tuple(v)


class BaseDataValidator(BaseValidator):
def __init__(
Expand All @@ -2559,7 +2569,7 @@ def __init__(
)

self.class_strs_map = class_strs_map
self._class_map = None
self._class_map = {}
self.set_uid = set_uid

def description(self):
Expand Down Expand Up @@ -2595,21 +2605,17 @@ def description(self):

return desc

@property
def class_map(self):
if self._class_map is None:

# Initialize class map
self._class_map = {}

# Import trace classes
def get_trace_class(self, trace_name):
# Import trace classes
if trace_name not in self._class_map:
trace_module = import_module("plotly.graph_objs")
for k, class_str in self.class_strs_map.items():
self._class_map[k] = getattr(trace_module, class_str)
trace_class_name = self.class_strs_map[trace_name]
self._class_map[trace_name] = getattr(trace_module, trace_class_name)

return self._class_map
return self._class_map[trace_name]

def validate_coerce(self, v, skip_invalid=False):
from plotly.basedatatypes import BaseTraceType

# Import Histogram2dcontour, this is the deprecated name of the
# Histogram2dContour trace.
Expand All @@ -2621,13 +2627,11 @@ def validate_coerce(self, v, skip_invalid=False):
if not isinstance(v, (list, tuple)):
v = [v]

trace_classes = tuple(self.class_map.values())

res = []
invalid_els = []
for v_el in v:

if isinstance(v_el, trace_classes):
if isinstance(v_el, BaseTraceType):
# Clone input traces
v_el = v_el.to_plotly_json()

Expand All @@ -2641,25 +2645,25 @@ def validate_coerce(self, v, skip_invalid=False):
else:
trace_type = "scatter"

if trace_type not in self.class_map:
if trace_type not in self.class_strs_map:
if skip_invalid:
# Treat as scatter trace
trace = self.class_map["scatter"](
trace = self.get_trace_class("scatter")(
skip_invalid=skip_invalid, **v_copy
)
res.append(trace)
else:
res.append(None)
invalid_els.append(v_el)
else:
trace = self.class_map[trace_type](
trace = self.get_trace_class(trace_type)(
skip_invalid=skip_invalid, **v_copy
)
res.append(trace)
else:
if skip_invalid:
# Add empty scatter trace
trace = self.class_map["scatter"]()
trace = self.get_trace_class("scatter")()
res.append(trace)
else:
res.append(None)
Expand Down
50 changes: 50 additions & 0 deletions packages/python/plotly/_plotly_utils/importers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import importlib


def relative_import(parent_name, rel_modules=(), rel_classes=()):
"""
Helper function to import submodules lazily in Python 3.7+

Parameters
----------
rel_modules: list of str
list of submodules to import, of the form .submodule
rel_classes: list of str
list of submodule classes/variables to import, of the form ._submodule.Foo

Returns
-------
tuple
Tuple that should be assigned to __all__, __getattr__ in the caller
"""
module_names = {rel_module.split(".")[-1]: rel_module for rel_module in rel_modules}
class_names = {rel_path.split(".")[-1]: rel_path for rel_path in rel_classes}

def __getattr__(import_name):
# In Python 3.7+, lazy import submodules

# Check for submodule
if import_name in module_names:
rel_import = module_names[import_name]
return importlib.import_module(rel_import, parent_name)

# Check for submodule class
if import_name in class_names:
rel_path_parts = class_names[import_name].split(".")
rel_module = ".".join(rel_path_parts[:-1])
class_name = import_name
class_module = importlib.import_module(rel_module, parent_name)
return getattr(class_module, class_name)

raise AttributeError(
"module {__name__!r} has no attribute {name!r}".format(
name=import_name, __name__=parent_name
)
)

__all__ = list(module_names) + list(class_names)

def __dir__():
return __all__

return __all__, __getattr__, __dir__
4 changes: 3 additions & 1 deletion packages/python/plotly/_plotly_utils/optional_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
_not_importable = set()


def get_module(name):
def get_module(name, should_load=True):
"""
Return module or None. Absolute import is required.

Expand All @@ -23,6 +23,8 @@ def get_module(name):
"""
if name in sys.modules:
return sys.modules[name]
if not should_load:
return None
if name not in _not_importable:
try:
return import_module(name)
Expand Down
4 changes: 2 additions & 2 deletions packages/python/plotly/_plotly_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def encode_as_sage(obj):
@staticmethod
def encode_as_pandas(obj):
"""Attempt to convert pandas.NaT"""
pandas = get_module("pandas")
pandas = get_module("pandas", should_load=False)
if not pandas:
raise NotEncodable

Expand All @@ -159,7 +159,7 @@ def encode_as_pandas(obj):
@staticmethod
def encode_as_numpy(obj):
"""Attempt to convert numpy.ma.core.masked"""
numpy = get_module("numpy")
numpy = get_module("numpy", should_load=False)
if not numpy:
raise NotEncodable

Expand Down
Loading