-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: merge_asof() has type specializations and can take multiple 'by' parameters (#13936) #14783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
75157fc
f01142c
46cc309
c33c4cb
5eeb7d9
fafbb02
2bce3cc
0ad1687
89256f0
77eb47b
1f208a8
ffcf0c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
# cython: boundscheck=False, wraparound=False | ||
""" | ||
Template for each `dtype` helper function for hashtable | ||
|
||
|
@@ -10,18 +11,25 @@ WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in | |
|
||
{{py: | ||
|
||
# table_type, by_dtype | ||
by_dtypes = [('PyObjectHashTable', 'object'), ('Int64HashTable', 'int64_t')] | ||
# by_dtype, table_type, init_table, s1, s2, s3, g1, g2 | ||
by_dtypes = [('int64_t', 'Int64HashTable', 'Int64HashTable(right_size)', | ||
'.set_item(', ', ', ')', | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are you changing this for the object table? It makes the code much more complicated. We don't use Python objects anywhere in Cython (instead we use the PyObjectHashTable). There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
I stumbled on this when I tried the I did end-up adding the |
||
'.get_item(', ')'), | ||
('object', 'dict', '{}', | ||
'[', '] = ', '', | ||
'[', ']')] | ||
|
||
# on_dtype | ||
on_dtypes = ['int64_t', 'double'] | ||
on_dtypes = ['uint8_t', 'uint16_t', 'uint32_t', 'uint64_t', | ||
'int8_t', 'int16_t', 'int32_t', 'int64_t', | ||
'float', 'double'] | ||
|
||
}} | ||
|
||
|
||
from hashtable cimport * | ||
|
||
{{for table_type, by_dtype in by_dtypes}} | ||
{{for by_dtype, table_type, init_table, s1, s2, s3, g1, g2 in by_dtypes}} | ||
{{for on_dtype in on_dtypes}} | ||
|
||
|
||
|
@@ -51,7 +59,7 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values, | |
left_indexer = np.empty(left_size, dtype=np.int64) | ||
right_indexer = np.empty(left_size, dtype=np.int64) | ||
|
||
hash_table = {{table_type}}(right_size) | ||
hash_table = {{init_table}} | ||
|
||
right_pos = 0 | ||
for left_pos in range(left_size): | ||
|
@@ -63,18 +71,18 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values, | |
if allow_exact_matches: | ||
while right_pos < right_size and\ | ||
right_values[right_pos] <= left_values[left_pos]: | ||
hash_table.set_item(right_by_values[right_pos], right_pos) | ||
hash_table{{s1}}right_by_values[right_pos]{{s2}}right_pos{{s3}} | ||
right_pos += 1 | ||
else: | ||
while right_pos < right_size and\ | ||
right_values[right_pos] < left_values[left_pos]: | ||
hash_table.set_item(right_by_values[right_pos], right_pos) | ||
hash_table{{s1}}right_by_values[right_pos]{{s2}}right_pos{{s3}} | ||
right_pos += 1 | ||
right_pos -= 1 | ||
|
||
# save positions as the desired index | ||
by_value = left_by_values[left_pos] | ||
found_right_pos = hash_table.get_item(by_value)\ | ||
found_right_pos = hash_table{{g1}}by_value{{g2}}\ | ||
if by_value in hash_table else -1 | ||
left_indexer[left_pos] = left_pos | ||
right_indexer[left_pos] = found_right_pos | ||
|
@@ -98,7 +106,9 @@ def asof_join_{{on_dtype}}_by_{{by_dtype}}(ndarray[{{on_dtype}}] left_values, | |
{{py: | ||
|
||
# on_dtype | ||
dtypes = ['int64_t', 'double'] | ||
dtypes = ['uint8_t', 'uint16_t', 'uint32_t', 'uint64_t', | ||
'int8_t', 'int16_t', 'int32_t', 'int64_t', | ||
'float', 'double'] | ||
|
||
}} | ||
|
||
|
@@ -158,3 +168,21 @@ def asof_join_{{on_dtype}}(ndarray[{{on_dtype}}] left_values, | |
|
||
{{endfor}} | ||
|
||
|
||
#---------------------------------------------------------------------- | ||
# stringify | ||
#---------------------------------------------------------------------- | ||
|
||
def stringify(ndarray[object, ndim=2] xt): | ||
cdef: | ||
Py_ssize_t n | ||
ndarray[object] result | ||
|
||
n = len(xt) | ||
result = np.empty(n, dtype=np.object) | ||
|
||
for i in range(n): | ||
result[i] = xt[i].tostring() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need this again? (I see you are using it), but what is the input that you are giving it? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe a doc-string would help There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added a couple comments to address this. When the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh if all u need is hashing we just added this: https://github.com/pandas-dev/pandas/blob/master/pandas/tools/hashing.py There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh if all u need is hashing we just added this: https://github.com/pandas-dev/pandas/blob/master/pandas/tools/hashing.py |
||
|
||
return result | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,7 +28,8 @@ | |
is_list_like, | ||
_ensure_int64, | ||
_ensure_float64, | ||
_ensure_object) | ||
_ensure_object, | ||
_get_dtype) | ||
from pandas.types.missing import na_value_for_dtype | ||
|
||
from pandas.core.generic import NDFrame | ||
|
@@ -270,8 +271,8 @@ def merge_asof(left, right, on=None, | |
DataFrame whose 'on' key is less than or equal to the left's key. Both | ||
DataFrames must be sorted by the key. | ||
|
||
Optionally perform group-wise merge. This searches for the nearest match | ||
on the 'on' key within the same group according to 'by'. | ||
Optionally match on equivalent keys with 'by' before searching for nearest | ||
match with 'on'. | ||
|
||
.. versionadded:: 0.19.0 | ||
|
||
|
@@ -288,9 +289,8 @@ def merge_asof(left, right, on=None, | |
Field name to join on in left DataFrame. | ||
right_on : label | ||
Field name to join on in right DataFrame. | ||
by : column name | ||
Group both the left and right DataFrames by the group column; perform | ||
the merge operation on these pieces and recombine. | ||
by : column name or list of column names | ||
Match on these columns before performing merge operation. | ||
suffixes : 2-length sequence (tuple, list, ...) | ||
Suffix to apply to overlapping column names in the left and right | ||
side, respectively | ||
|
@@ -926,27 +926,44 @@ def get_result(self): | |
return result | ||
|
||
|
||
_asof_functions = { | ||
'int64_t': _join.asof_join_int64_t, | ||
'double': _join.asof_join_double, | ||
} | ||
def _asof_function(on_type): | ||
return getattr(_join, 'asof_join_%s' % on_type, None) | ||
|
||
|
||
def _asof_by_function(on_type, by_type): | ||
return getattr(_join, 'asof_join_%s_by_%s' % (on_type, by_type), None) | ||
|
||
_asof_by_functions = { | ||
('int64_t', 'int64_t'): _join.asof_join_int64_t_by_int64_t, | ||
('double', 'int64_t'): _join.asof_join_double_by_int64_t, | ||
('int64_t', 'object'): _join.asof_join_int64_t_by_object, | ||
('double', 'object'): _join.asof_join_double_by_object, | ||
} | ||
|
||
_type_casters = { | ||
'int64_t': _ensure_int64, | ||
'double': _ensure_float64, | ||
'object': _ensure_object, | ||
} | ||
|
||
_cyton_types = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. _cython_types ? |
||
'uint8': 'uint8_t', | ||
'uint32': 'uint32_t', | ||
'uint16': 'uint16_t', | ||
'uint64': 'uint64_t', | ||
'int8': 'int8_t', | ||
'int32': 'int32_t', | ||
'int16': 'int16_t', | ||
'int64': 'int64_t', | ||
'float16': 'float', | ||
'float32': 'float', | ||
'float64': 'double', | ||
} | ||
|
||
|
||
def _get_cython_type(dtype): | ||
""" Given a dtype, return 'int64_t', 'double', or 'object' """ | ||
""" Given a dtype, return a C name like 'int64_t' or 'double' """ | ||
type_name = _get_dtype(dtype).name | ||
ctype = _cyton_types.get(type_name, 'object') | ||
return ctype | ||
|
||
|
||
def _get_cython_type_upcast(dtype): | ||
""" Upcast a dtype to 'int64_t', 'double', or 'object' """ | ||
if is_integer_dtype(dtype): | ||
return 'int64_t' | ||
elif is_float_dtype(dtype): | ||
|
@@ -990,9 +1007,6 @@ def _validate_specification(self): | |
if not is_list_like(self.by): | ||
self.by = [self.by] | ||
|
||
if len(self.by) != 1: | ||
raise MergeError("can only asof by a single key") | ||
|
||
self.left_on = self.by + list(self.left_on) | ||
self.right_on = self.by + list(self.right_on) | ||
|
||
|
@@ -1046,6 +1060,11 @@ def _get_merge_keys(self): | |
def _get_join_indexers(self): | ||
""" return the join indexers """ | ||
|
||
def flip_stringify(xs): | ||
""" flip an array of arrays and string-ify contents """ | ||
xt = np.transpose(xs) | ||
return _join.stringify(_ensure_object(xt)) | ||
|
||
# values to compare | ||
left_values = self.left_join_keys[-1] | ||
right_values = self.right_join_keys[-1] | ||
|
@@ -1067,22 +1086,23 @@ def _get_join_indexers(self): | |
|
||
# a "by" parameter requires special handling | ||
if self.by is not None: | ||
left_by_values = self.left_join_keys[0] | ||
right_by_values = self.right_join_keys[0] | ||
|
||
# choose appropriate function by type | ||
on_type = _get_cython_type(left_values.dtype) | ||
by_type = _get_cython_type(left_by_values.dtype) | ||
if len(self.left_join_keys) > 2: | ||
# get string representation of values if more than one | ||
left_by_values = flip_stringify(self.left_join_keys[0:-1]) | ||
right_by_values = flip_stringify(self.right_join_keys[0:-1]) | ||
else: | ||
left_by_values = self.left_join_keys[0] | ||
right_by_values = self.right_join_keys[0] | ||
|
||
on_type_caster = _type_casters[on_type] | ||
# upcast 'by' parameter because HashTable is limited | ||
by_type = _get_cython_type_upcast(left_by_values.dtype) | ||
by_type_caster = _type_casters[by_type] | ||
func = _asof_by_functions[(on_type, by_type)] | ||
|
||
left_values = on_type_caster(left_values) | ||
right_values = on_type_caster(right_values) | ||
left_by_values = by_type_caster(left_by_values) | ||
right_by_values = by_type_caster(right_by_values) | ||
|
||
# choose appropriate function by type | ||
on_type = _get_cython_type(left_values.dtype) | ||
func = _asof_by_function(on_type, by_type) | ||
return func(left_values, | ||
right_values, | ||
left_by_values, | ||
|
@@ -1092,12 +1112,7 @@ def _get_join_indexers(self): | |
else: | ||
# choose appropriate function by type | ||
on_type = _get_cython_type(left_values.dtype) | ||
type_caster = _type_casters[on_type] | ||
func = _asof_functions[on_type] | ||
|
||
left_values = type_caster(left_values) | ||
right_values = type_caster(right_values) | ||
|
||
func = _asof_function(on_type) | ||
return func(left_values, | ||
right_values, | ||
self.allow_exact_matches, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
specialized dtypes, and elaborate on what this does (e.g. perf)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's to improve performance rather than cast all integer types to int64 (see my benchmarks pasted below). I can add a description to the whatsnew line.