diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index f41c767d0b13a..193b8f5053d65 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -9,7 +9,7 @@ import os import re import time -from typing import List, Optional, Type, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import warnings import numpy as np @@ -55,6 +55,10 @@ from pandas.io.common import _stringify_path from pandas.io.formats.printing import adjoin, pprint_thing +if TYPE_CHECKING: + from tables import File # noqa:F401 + + # versioning attribute _version = "0.15.2" @@ -465,6 +469,8 @@ class HDFStore: >>> store.close() """ + _handle: Optional["File"] + def __init__( self, path, @@ -535,7 +541,7 @@ def __getattr__(self, name): ) ) - def __contains__(self, key): + def __contains__(self, key: str): """ check for existence of this key can match the exact pathname or the pathnm w/o the leading '/' """ @@ -560,7 +566,7 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def keys(self): + def keys(self) -> List[str]: """ Return a list of keys corresponding to objects stored in HDFStore. @@ -698,13 +704,13 @@ def flush(self, fsync: bool = False): except OSError: pass - def get(self, key): + def get(self, key: str): """ Retrieve pandas object stored in file. Parameters ---------- - key : object + key : str Returns ------- @@ -718,7 +724,7 @@ def get(self, key): def select( self, - key, + key: str, where=None, start=None, stop=None, @@ -733,7 +739,7 @@ def select( Parameters ---------- - key : object + key : str Object being retrieved from file. where : list, default None List of Term (or convertible) objects, optional. @@ -784,13 +790,15 @@ def func(_start, _stop, _where): return it.get_result() - def select_as_coordinates(self, key, where=None, start=None, stop=None, **kwargs): + def select_as_coordinates( + self, key: str, where=None, start=None, stop=None, **kwargs + ): """ return the selection as an Index Parameters ---------- - key : object + key : str where : list of Term (or convertible) objects, optional start : integer (defaults to None), row number to start selection stop : integer (defaults to None), row number to stop selection @@ -800,15 +808,16 @@ def select_as_coordinates(self, key, where=None, start=None, stop=None, **kwargs where=where, start=start, stop=stop, **kwargs ) - def select_column(self, key, column, **kwargs): + def select_column(self, key: str, column: str, **kwargs): """ return a single column from the table. This is generally only useful to select an indexable Parameters ---------- - key : object - column: the column of interest + key : str + column: str + The column of interest. Raises ------ @@ -966,7 +975,7 @@ def put(self, key, value, format=None, append=False, **kwargs): kwargs = self._validate_format(format, kwargs) self._write_to_group(key, value, append=append, **kwargs) - def remove(self, key, where=None, start=None, stop=None): + def remove(self, key: str, where=None, start=None, stop=None): """ Remove pandas object partially by specifying the where condition @@ -1152,16 +1161,17 @@ def append_to_multiple( self.append(k, val, data_columns=dc, **kwargs) - def create_table_index(self, key, **kwargs): - """ Create a pytables index on the table + def create_table_index(self, key: str, **kwargs): + """ + Create a pytables index on the table. + Parameters ---------- - key : object (the node to index) + key : str Raises ------ - raises if the node is not a table - + TypeError: raises if the node is not a table """ # version requirements @@ -1247,17 +1257,19 @@ def walk(self, where="/"): yield (g._v_pathname.rstrip("/"), groups, leaves) - def get_node(self, key): + def get_node(self, key: str): """ return the node with the key or None if it does not exist """ self._check_if_open() + if not key.startswith("/"): + key = "/" + key + + assert self._handle is not None try: - if not key.startswith("/"): - key = "/" + key return self._handle.get_node(self.root, key) - except _table_mod.exceptions.NoSuchNodeError: + except _table_mod.exceptions.NoSuchNodeError: # type: ignore return None - def get_storer(self, key): + def get_storer(self, key: str): """ return the storer object for a key, raise if not in the file """ group = self.get_node(key) if group is None: @@ -1481,7 +1493,7 @@ def error(t): def _write_to_group( self, - key, + key: str, value, format, index=True, @@ -1492,6 +1504,10 @@ def _write_to_group( ): group = self.get_node(key) + # we make this assertion for mypy; the get_node call will already + # have raised if this is incorrect + assert self._handle is not None + # remove the node if we are not appending if group is not None and not append: self._handle.remove_node(group, recursive=True) @@ -2691,7 +2707,7 @@ def f(values, freq=None, tz=None): return klass - def validate_read(self, kwargs): + def validate_read(self, kwargs: Dict[str, Any]) -> Dict[str, Any]: """ remove table keywords from kwargs and return raise if any keywords are passed which are not-None @@ -2733,7 +2749,7 @@ def get_attrs(self): def write(self, obj, **kwargs): self.set_attrs() - def read_array(self, key, start=None, stop=None): + def read_array(self, key: str, start=None, stop=None): """ read an array for the specified node (off of group """ import tables @@ -4008,7 +4024,7 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs): return Index(coords) - def read_column(self, column, where=None, start=None, stop=None): + def read_column(self, column: str, where=None, start=None, stop=None): """return a single column from the table, generally only indexables are interesting """ @@ -4642,8 +4658,8 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): converted, "datetime64", _tables().Int64Col(), - freq=getattr(index, "freq", None), - tz=getattr(index, "tz", None), + freq=index.freq, + tz=index.tz, index_name=index_name, ) elif isinstance(index, TimedeltaIndex): @@ -4652,7 +4668,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None): converted, "timedelta64", _tables().Int64Col(), - freq=getattr(index, "freq", None), + freq=index.freq, index_name=index_name, ) elif isinstance(index, (Int64Index, PeriodIndex)):