[CLN] More cython cleanups, with bonus type annotations #22283


Merged · 14 commits · Sep 8, 2018
13 changes: 7 additions & 6 deletions pandas/_libs/algos_common_helper.pxi.in
@@ -45,7 +45,7 @@ def get_dispatch(dtypes):

@cython.wraparound(False)
@cython.boundscheck(False)
-cpdef map_indices_{{name}}(ndarray[{{c_type}}] index):
+def map_indices_{{name}}(ndarray[{{c_type}}] index):
"""
Produce a dict mapping the values of the input array to their respective
locations.
@@ -55,8 +55,9 @@ cpdef map_indices_{{name}}(ndarray[{{c_type}}] index):

Better to do this with Cython because of the enormous speed boost.
"""
-    cdef Py_ssize_t i, length
-    cdef dict result = {}
+    cdef:
+        Py_ssize_t i, length
+        dict result = {}

length = len(index)

@@ -541,7 +542,7 @@ def put2d_{{name}}_{{dest_type}}(ndarray[{{c_type}}, ndim=2, cast=True] values,
cdef int PLATFORM_INT = (<ndarray> np.arange(0, dtype=np.intp)).descr.type_num


-cpdef ensure_platform_int(object arr):
+def ensure_platform_int(object arr):
# GH3033, GH1392
# platform int is the size of the int pointer, e.g. np.intp
if util.is_array(arr):
@@ -553,7 +554,7 @@ cpdef ensure_platform_int(object arr):
return np.array(arr, dtype=np.intp)


-cpdef ensure_object(object arr):
+def ensure_object(object arr):
if util.is_array(arr):
if (<ndarray> arr).descr.type_num == NPY_OBJECT:
return arr
@@ -586,7 +587,7 @@ def get_dispatch(dtypes):

{{for name, c_type, dtype in get_dispatch(dtypes)}}

-cpdef ensure_{{name}}(object arr, copy=True):
+def ensure_{{name}}(object arr, copy=True):
Contributor: are these for sure not called in cython code?

Member Author: non-cimported only

if util.is_array(arr):
if (<ndarray> arr).descr.type_num == NPY_{{c_type}}:
return arr
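For readers skimming these hunks, the distinction they turn on, as a minimal sketch (hypothetical functions, not pandas code): `cpdef` generates both a C entry point and a Python wrapper, so a function that is never cimported elsewhere, as the thread above confirms for these, only needs a plain `def`.

```cython
cdef int _add_c(int a, int b):
    # C-only entry point: callable from other Cython code, invisible to Python
    return a + b

cpdef int add_both(int a, int b):
    # two entry points: a fast C call for cimporting modules,
    # plus a generated Python wrapper for regular `import` users
    return a + b

def add_py(a: int, b: int) -> int:
    # Python-level entry point only: sufficient once grep shows no C callers
    return a + b
```

Dropping the unused C entry point trims the generated module without changing behavior for Python callers.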
1 change: 1 addition & 0 deletions pandas/_libs/groupby.pyx
@@ -67,6 +67,7 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
return result


+# TODO: Is this redundant with algos.kth_smallest?
cdef inline float64_t kth_smallest_c(float64_t* a,
Py_ssize_t k,
Py_ssize_t n) nogil:
1 change: 1 addition & 0 deletions pandas/_libs/hashing.pyx
@@ -132,6 +132,7 @@ cdef inline void _sipround(uint64_t* v0, uint64_t* v1,
v2[0] = _rotl(v2[0], 32)


+# TODO: This appears unused; remove?

Contributor: might be unused - it was a part of the hashing at one point iirc

Member Author: OK, will remove in next pass.

cpdef uint64_t siphash(bytes data, bytes key) except? 0:
if len(key) != 16:
raise ValueError("key should be a 16-byte bytestring, "
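The `siphash` signature above also shows the `except? 0` convention, sketched here with a hypothetical function: a C-typed return value cannot carry a Python exception, so one value is reserved as a maybe-error sentinel.

```cython
from libc.stdint cimport uint64_t

cpdef uint64_t checked_hash(bytes data) except? 0:
    # `except? 0` means a return of 0 makes the caller also check for a
    # pending exception, so 0 can still be a legitimate hash value
    if len(data) == 0:
        raise ValueError("data must be non-empty")
    return <uint64_t>len(data)  # placeholder result for illustration
```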
2 changes: 1 addition & 1 deletion pandas/_libs/index.pyx
@@ -49,7 +49,7 @@ cpdef get_value_at(ndarray arr, object loc, object tz=None):
return util.get_value_at(arr, loc)


-cpdef object get_value_box(ndarray arr, object loc):
+def get_value_box(arr: ndarray, loc: object) -> object:
cdef:
Py_ssize_t i, sz

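The new `get_value_box` signature shows the PR's "bonus type annotations" pattern: a plain `def` with Python-style annotations instead of C argument declarations. Assuming Cython's `annotation_typing` directive (on by default since Cython 0.27), the annotations are still compiled into typed variables; a minimal sketch with a hypothetical function:

```cython
cimport numpy as cnp
from numpy cimport ndarray

cnp.import_array()

def get_first_value(arr: ndarray, loc: object) -> object:
    # comparable in effect to the old `cpdef object f(ndarray arr, object loc)`
    # spelling, but only a Python entry point is generated
    return arr[loc]
```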
4 changes: 2 additions & 2 deletions pandas/_libs/internals.pyx
@@ -184,7 +184,7 @@ cdef class BlockPlacement:
return self._as_slice


-cpdef slice_canonize(slice s):
+cdef slice_canonize(slice s):
"""
Convert slice to canonical bounded form.
"""
@@ -255,7 +255,7 @@ cpdef Py_ssize_t slice_len(
return length


-cpdef slice_get_indices_ex(slice slc, Py_ssize_t objlen=PY_SSIZE_T_MAX):
+cdef slice_get_indices_ex(slice slc, Py_ssize_t objlen=PY_SSIZE_T_MAX):
"""
Get (start, stop, step, length) tuple for a slice.

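These two hunks move in the opposite direction from the `lib.pyx` changes: helpers called only from inside the module become `cdef`, dropping the Python wrapper entirely. A sketch of the pattern (hypothetical helper, not the real `slice_canonize`):

```cython
cdef slice _canonize(slice s):
    # C-level only: no Python wrapper is generated, so this cannot be
    # reached, or accidentally relied upon, from Python code
    if s.step is None:
        return slice(s.start, s.stop, 1)
    return s

def normalized_indices(s: slice, length: int):
    # the module's public surface stays a plain `def`
    return _canonize(s).indices(length)
```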
5 changes: 3 additions & 2 deletions pandas/_libs/interval.pyx
@@ -362,8 +362,8 @@ cdef class Interval(IntervalMixin):

@cython.wraparound(False)
@cython.boundscheck(False)
-cpdef intervals_to_interval_bounds(ndarray intervals,
-                                   bint validate_closed=True):
+def intervals_to_interval_bounds(ndarray intervals,
+                                 bint validate_closed=True):
"""
Parameters
----------
@@ -415,4 +415,5 @@ cpdef intervals_to_interval_bounds(ndarray intervals,

return left, right, closed


include "intervaltree.pxi"
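One detail worth noting in the `intervals_to_interval_bounds` hunk above: converting `cpdef` to `def` does not give up C-typed parameters. The `bint validate_closed=True` argument survives, and Cython coerces the Python argument at the call boundary. A self-contained sketch (hypothetical function):

```cython
cimport numpy as cnp
from numpy cimport ndarray

cnp.import_array()

def first_or_none(ndarray values, bint allow_empty=True):
    # `allow_empty` behaves as a C boolean inside the body, even though
    # callers pass an ordinary Python bool
    if len(values) == 0:
        if allow_empty:
            return None
        raise ValueError("values is empty")
    return values[0]
```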
38 changes: 18 additions & 20 deletions pandas/_libs/lib.pyx
@@ -107,7 +107,7 @@ def memory_usage_of_objects(object[:] arr):
# ----------------------------------------------------------------------


-cpdef bint is_scalar(object val):
+def is_scalar(val: object) -> bint:
Contributor: this not ever called in cython?

Member Author: Yep, easy to confirm via grep

"""
Return True if given value is scalar.

@@ -137,7 +137,7 @@ cpdef bint is_scalar(object val):
or util.is_period_object(val)
or is_decimal(val)
or is_interval(val)
-            or is_offset(val))
+            or util.is_offset_object(val))


def item_from_zerodim(object val):
@@ -457,7 +457,7 @@ def maybe_booleans_to_slice(ndarray[uint8_t] mask):

@cython.wraparound(False)
@cython.boundscheck(False)
-cpdef bint array_equivalent_object(object[:] left, object[:] right):
+def array_equivalent_object(left: object[:], right: object[:]) -> bint:
""" perform an element by element comparion on 1-d object arrays
taking into account nan positions """
cdef:
@@ -499,7 +499,7 @@ def astype_intsafe(ndarray[object] arr, new_dtype):
return result


-cpdef ndarray[object] astype_unicode(ndarray arr):
+def astype_unicode(arr: ndarray) -> ndarray[object]:
cdef:
Py_ssize_t i, n = arr.size
ndarray[object] result = np.empty(n, dtype=object)
@@ -512,7 +512,7 @@ cpdef ndarray[object] astype_unicode(ndarray arr):
return result


-cpdef ndarray[object] astype_str(ndarray arr):
+def astype_str(arr: ndarray) -> ndarray[object]:
cdef:
Py_ssize_t i, n = arr.size
ndarray[object] result = np.empty(n, dtype=object)
@@ -797,19 +797,19 @@ def indices_fast(object index, ndarray[int64_t] labels, list keys,

# core.common import for fast inference checks

-cpdef bint is_float(object obj):
+def is_float(obj: object) -> bint:
return util.is_float_object(obj)


-cpdef bint is_integer(object obj):
+def is_integer(obj: object) -> bint:
return util.is_integer_object(obj)


-cpdef bint is_bool(object obj):
+def is_bool(obj: object) -> bint:
return util.is_bool_object(obj)


-cpdef bint is_complex(object obj):
+def is_complex(obj: object) -> bint:
return util.is_complex_object(obj)


@@ -821,15 +821,11 @@ cpdef bint is_interval(object obj):
return getattr(obj, '_typ', '_typ') == 'interval'


-cpdef bint is_period(object val):
+def is_period(val: object) -> bint:
""" Return a boolean if this is a Period object """
return util.is_period_object(val)


-cdef inline bint is_offset(object val):
-    return getattr(val, '_typ', '_typ') == 'dateoffset'


_TYPE_MAP = {
'categorical': 'categorical',
'category': 'categorical',
@@ -1231,7 +1227,7 @@ def infer_dtype(object value, bint skipna=False):
if is_bytes_array(values, skipna=skipna):
return 'bytes'

-    elif is_period(val):
+    elif util.is_period_object(val):
if is_period_array(values):
return 'period'

@@ -1249,7 +1245,7 @@ def infer_dtype(object value, bint skipna=False):
return 'mixed'


-cpdef object infer_datetimelike_array(object arr):
+def infer_datetimelike_array(arr: object) -> object:
"""
infer if we have a datetime or timedelta array
- date: we have *only* date and maybe strings, nulls
@@ -1586,7 +1582,7 @@ cpdef bint is_datetime64_array(ndarray values):
return validator.validate(values)


-cpdef bint is_datetime_with_singletz_array(ndarray values):
+def is_datetime_with_singletz_array(values: ndarray) -> bint:
"""
Check values have the same tzinfo attribute.
Doesn't check values are datetime-like types.
@@ -1622,7 +1618,8 @@ cdef class TimedeltaValidator(TemporalValidator):
return is_null_timedelta64(value)


-cpdef bint is_timedelta_array(ndarray values):
+# TODO: Not used outside of tests; remove?
+def is_timedelta_array(values: ndarray) -> bint:
cdef:
TimedeltaValidator validator = TimedeltaValidator(len(values),
skipna=True)
@@ -1634,7 +1631,8 @@ cdef class Timedelta64Validator(TimedeltaValidator):
return util.is_timedelta64_object(value)


-cpdef bint is_timedelta64_array(ndarray values):
+# TODO: Not used outside of tests; remove?
+def is_timedelta64_array(values: ndarray) -> bint:
cdef:
Timedelta64Validator validator = Timedelta64Validator(len(values),
skipna=True)
@@ -1678,7 +1676,7 @@ cpdef bint is_time_array(ndarray values, bint skipna=False):

cdef class PeriodValidator(TemporalValidator):
cdef inline bint is_value_typed(self, object value) except -1:
-        return is_period(value)
+        return util.is_period_object(value)

cdef inline bint is_valid_null(self, object value) except -1:
return is_null_period(value)
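The other recurring change in `lib.pyx` replaces per-module duck-typed helpers, like the deleted `is_offset`, with predicates from the shared `util` module, so each check has a single C-level definition. A sketch contrasting the two styles (illustrative only, not the actual `util` implementation):

```cython
# old style: attribute sniffing, redefined in whichever module needed it
cdef inline bint _is_offset_ducktyped(object val):
    return getattr(val, '_typ', '_typ') == 'dateoffset'

# new style: one cimportable predicate shared by every consumer, e.g.
#     from util cimport is_offset_object
# after which call sites read util.is_offset_object(val)
```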
41 changes: 22 additions & 19 deletions pandas/_libs/parsers.pyx
@@ -29,7 +29,7 @@ cdef extern from "Python.h":

import numpy as np
cimport numpy as cnp
-from numpy cimport ndarray, uint8_t, uint64_t, int64_t
+from numpy cimport ndarray, uint8_t, uint64_t, int64_t, float64_t
cnp.import_array()

from util cimport UINT64_MAX, INT64_MAX, INT64_MIN
@@ -694,7 +694,7 @@ cdef class TextReader:
if ptr == NULL:
if not os.path.exists(source):
raise compat.FileNotFoundError(
-    'File %s does not exist' % source)
+    'File {source} does not exist'.format(source=source))
raise IOError('Initializing from file failed')

self.parser.source = ptr
@@ -772,9 +772,10 @@

if name == '':
if self.has_mi_columns:
-name = 'Unnamed: %d_level_%d' % (i, level)
+name = ('Unnamed: {i}_level_{lvl}'
+        .format(i=i, lvl=level))
else:
-name = 'Unnamed: %d' % i
+name = 'Unnamed: {i}'.format(i=i)
unnamed_count += 1

count = counts.get(name, 0)
@@ -849,8 +850,8 @@
# 'data has %d fields'
# % (passed_count, field_count))

-if self.has_usecols and self.allow_leading_cols and \
-        not callable(self.usecols):
+if (self.has_usecols and self.allow_leading_cols and
+        not callable(self.usecols)):
nuse = len(self.usecols)
if nuse == passed_count:
self.leading_cols = 0
@@ -1027,17 +1028,19 @@

if self.table_width - self.leading_cols > num_cols:
raise ParserError(
-    "Too many columns specified: expected %s and found %s" %
-    (self.table_width - self.leading_cols, num_cols))
+    "Too many columns specified: expected {expected} and "
+    "found {found}"
+    .format(expected=self.table_width - self.leading_cols,
+            found=num_cols))

results = {}
nused = 0
for i in range(self.table_width):
if i < self.leading_cols:
# Pass through leading columns always
name = i
-elif self.usecols and not callable(self.usecols) and \
-        nused == len(self.usecols):
+elif (self.usecols and not callable(self.usecols) and
+        nused == len(self.usecols)):
# Once we've gathered all requested columns, stop. GH5766
break
else:
@@ -1103,7 +1106,7 @@
col_res = _maybe_upcast(col_res)

if col_res is None:
-raise ParserError('Unable to parse column %d' % i)
+raise ParserError('Unable to parse column {i}'.format(i=i))

results[i] = col_res

@@ -1222,8 +1225,8 @@
elif dtype.kind == 'U':
width = dtype.itemsize
if width > 0:
-raise TypeError("the dtype %s is not "
-                "supported for parsing" % dtype)
+raise TypeError("the dtype {dtype} is not "
+                "supported for parsing".format(dtype=dtype))

# unicode variable width
return self._string_convert(i, start, end, na_filter,
@@ -1241,12 +1244,12 @@
return self._string_convert(i, start, end, na_filter,
na_hashset)
elif is_datetime64_dtype(dtype):
-raise TypeError("the dtype %s is not supported "
-                "for parsing, pass this column "
-                "using parse_dates instead" % dtype)
+raise TypeError("the dtype {dtype} is not supported "
+                "for parsing, pass this column "
+                "using parse_dates instead".format(dtype=dtype))
else:
-raise TypeError("the dtype %s is not "
-                "supported for parsing" % dtype)
+raise TypeError("the dtype {dtype} is not "
+                "supported for parsing".format(dtype=dtype))

cdef _string_convert(self, Py_ssize_t i, int64_t start, int64_t end,
bint na_filter, kh_str_t *na_hashset):
@@ -2058,7 +2061,7 @@ cdef kh_float64_t* kset_float64_from_list(values) except NULL:
khiter_t k
kh_float64_t *table
int ret = 0
-cnp.float64_t val
+float64_t val
object value

table = kh_init_float64()
@@ -2101,7 +2104,7 @@ cdef raise_parser_error(object base, parser_t *parser):
Py_XDECREF(type)
raise old_exc

-message = '%s. C error: ' % base
+message = '{base}. C error: '.format(base=base)
if parser.error_msg != NULL:
if PY3:
message += parser.error_msg.decode('utf-8')
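The string-formatting hunks above all apply one mechanical rule: positional `%` interpolation becomes named `str.format` fields, the style pandas was standardizing on at the time, since named fields keep long wrapped messages readable. A runnable sketch:

```cython
# both lines produce: 'Unable to parse column 3'
i = 3
old_msg = 'Unable to parse column %d' % i
new_msg = 'Unable to parse column {i}'.format(i=i)
assert old_msg == new_msg
```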
2 changes: 1 addition & 1 deletion pandas/_libs/tslib.pyx
@@ -301,7 +301,7 @@ def format_array_from_datetime(ndarray[int64_t] values, object tz=None,
return result


-cpdef array_with_unit_to_datetime(ndarray values, unit, errors='coerce'):
+def array_with_unit_to_datetime(ndarray values, unit, errors='coerce'):
"""
convert the ndarray according to the unit
if errors: