Skip to content

TYP: add string annotations in io.pytables #29682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 18, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 46 additions & 30 deletions pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import re
import time
from typing import List, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
import warnings

import numpy as np
Expand Down Expand Up @@ -55,6 +55,10 @@
from pandas.io.common import _stringify_path
from pandas.io.formats.printing import adjoin, pprint_thing

if TYPE_CHECKING:
from tables import File # noqa:F401


# versioning attribute
_version = "0.15.2"

Expand Down Expand Up @@ -465,6 +469,8 @@ class HDFStore:
>>> store.close()
"""

_handle: Optional["File"]

def __init__(
self,
path,
Expand Down Expand Up @@ -535,7 +541,7 @@ def __getattr__(self, name):
)
)

def __contains__(self, key):
def __contains__(self, key: str):
""" check for existence of this key
can match the exact pathname or the pathnm w/o the leading '/'
"""
Expand All @@ -560,7 +566,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
self.close()

def keys(self):
def keys(self) -> List[str]:
"""
Return a list of keys corresponding to objects stored in HDFStore.

Expand Down Expand Up @@ -698,13 +704,13 @@ def flush(self, fsync: bool = False):
except OSError:
pass

def get(self, key):
def get(self, key: str):
"""
Retrieve pandas object stored in file.

Parameters
----------
key : object
key : str

Returns
-------
Expand All @@ -718,7 +724,7 @@ def get(self, key):

def select(
self,
key,
key: str,
where=None,
start=None,
stop=None,
Expand All @@ -733,7 +739,7 @@ def select(

Parameters
----------
key : object
key : str
Object being retrieved from file.
where : list, default None
List of Term (or convertible) objects, optional.
Expand Down Expand Up @@ -784,13 +790,15 @@ def func(_start, _stop, _where):

return it.get_result()

def select_as_coordinates(self, key, where=None, start=None, stop=None, **kwargs):
def select_as_coordinates(
self, key: str, where=None, start=None, stop=None, **kwargs
):
"""
return the selection as an Index

Parameters
----------
key : object
key : str
where : list of Term (or convertible) objects, optional
start : integer (defaults to None), row number to start selection
stop : integer (defaults to None), row number to stop selection
Expand All @@ -800,15 +808,16 @@ def select_as_coordinates(self, key, where=None, start=None, stop=None, **kwargs
where=where, start=start, stop=stop, **kwargs
)

def select_column(self, key, column, **kwargs):
def select_column(self, key: str, column: str, **kwargs):
"""
return a single column from the table. This is generally only useful to
select an indexable

Parameters
----------
key : object
column: the column of interest
key : str
column: str
The column of interest.

Raises
------
Expand Down Expand Up @@ -966,7 +975,7 @@ def put(self, key, value, format=None, append=False, **kwargs):
kwargs = self._validate_format(format, kwargs)
self._write_to_group(key, value, append=append, **kwargs)

def remove(self, key, where=None, start=None, stop=None):
def remove(self, key: str, where=None, start=None, stop=None):
"""
Remove pandas object partially by specifying the where condition

Expand Down Expand Up @@ -1152,16 +1161,17 @@ def append_to_multiple(

self.append(k, val, data_columns=dc, **kwargs)

def create_table_index(self, key, **kwargs):
""" Create a pytables index on the table
def create_table_index(self, key: str, **kwargs):
"""
Create a pytables index on the table.

Parameters
----------
key : object (the node to index)
key : str

Raises
------
raises if the node is not a table

TypeError: raises if the node is not a table
"""

# version requirements
Expand Down Expand Up @@ -1247,17 +1257,19 @@ def walk(self, where="/"):

yield (g._v_pathname.rstrip("/"), groups, leaves)

def get_node(self, key):
def get_node(self, key: str):
""" return the node with the key or None if it does not exist """
self._check_if_open()
if not key.startswith("/"):
key = "/" + key

assert self._handle is not None
try:
if not key.startswith("/"):
key = "/" + key
return self._handle.get_node(self.root, key)
except _table_mod.exceptions.NoSuchNodeError:
except _table_mod.exceptions.NoSuchNodeError: # type: ignore
return None

def get_storer(self, key):
def get_storer(self, key: str):
""" return the storer object for a key, raise if not in the file """
group = self.get_node(key)
if group is None:
Expand Down Expand Up @@ -1481,7 +1493,7 @@ def error(t):

def _write_to_group(
self,
key,
key: str,
value,
format,
index=True,
Expand All @@ -1492,6 +1504,10 @@ def _write_to_group(
):
group = self.get_node(key)

# we make this assertion for mypy; the get_node call will already
# have raised if this is incorrect
assert self._handle is not None

# remove the node if we are not appending
if group is not None and not append:
self._handle.remove_node(group, recursive=True)
Expand Down Expand Up @@ -2691,7 +2707,7 @@ def f(values, freq=None, tz=None):

return klass

def validate_read(self, kwargs):
def validate_read(self, kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
remove table keywords from kwargs and return
raise if any keywords are passed which are not-None
Expand Down Expand Up @@ -2733,7 +2749,7 @@ def get_attrs(self):
def write(self, obj, **kwargs):
self.set_attrs()

def read_array(self, key, start=None, stop=None):
def read_array(self, key: str, start=None, stop=None):
""" read an array for the specified node (off of group """
import tables

Expand Down Expand Up @@ -4008,7 +4024,7 @@ def read_coordinates(self, where=None, start=None, stop=None, **kwargs):

return Index(coords)

def read_column(self, column, where=None, start=None, stop=None):
def read_column(self, column: str, where=None, start=None, stop=None):
"""return a single column from the table, generally only indexables
are interesting
"""
Expand Down Expand Up @@ -4642,8 +4658,8 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
converted,
"datetime64",
_tables().Int64Col(),
freq=getattr(index, "freq", None),
tz=getattr(index, "tz", None),
freq=index.freq,
tz=index.tz,
index_name=index_name,
)
elif isinstance(index, TimedeltaIndex):
Expand All @@ -4652,7 +4668,7 @@ def _convert_index(index, encoding=None, errors="strict", format_type=None):
converted,
"timedelta64",
_tables().Int64Col(),
freq=getattr(index, "freq", None),
freq=index.freq,
index_name=index_name,
)
elif isinstance(index, (Int64Index, PeriodIndex)):
Expand Down