From 87158d49b24c1b4dc1379d07c04581c906584afa Mon Sep 17 00:00:00 2001 From: Will Ayd Date: Sat, 2 Mar 2024 09:39:04 -0800 Subject: [PATCH] Cython guard against [c|m|re]alloc failures --- pandas/_libs/algos.pyx | 2 ++ pandas/_libs/groupby.pyx | 4 ++++ pandas/_libs/hashing.pyx | 4 ++++ pandas/_libs/hashtable_class_helper.pxi.in | 16 ++++++++++++---- pandas/_libs/sas.pyx | 2 +- pandas/_libs/tslibs/period.pyx | 2 ++ 6 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pandas/_libs/algos.pyx b/pandas/_libs/algos.pyx index e70ac26a2c28e..e2e93c5242b24 100644 --- a/pandas/_libs/algos.pyx +++ b/pandas/_libs/algos.pyx @@ -180,6 +180,8 @@ def is_lexsorted(list_of_arrays: list) -> bool: n = len(list_of_arrays[0]) cdef int64_t **vecs = malloc(nlevels * sizeof(int64_t*)) + if vecs is NULL: + raise MemoryError() for i in range(nlevels): arr = list_of_arrays[i] assert arr.dtype.name == "int64" diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 391bb4a3a3fd3..2ff45038d6a3e 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -81,6 +81,8 @@ cdef float64_t median_linear_mask(float64_t* a, int n, uint8_t* mask) noexcept n return NaN tmp = malloc((n - na_count) * sizeof(float64_t)) + if tmp is NULL: + raise MemoryError() j = 0 for i in range(n): @@ -118,6 +120,8 @@ cdef float64_t median_linear(float64_t* a, int n) noexcept nogil: return NaN tmp = malloc((n - na_count) * sizeof(float64_t)) + if tmp is NULL: + raise MemoryError() j = 0 for i in range(n): diff --git a/pandas/_libs/hashing.pyx b/pandas/_libs/hashing.pyx index be6958e3315e9..8b424e96973d3 100644 --- a/pandas/_libs/hashing.pyx +++ b/pandas/_libs/hashing.pyx @@ -68,7 +68,11 @@ def hash_object_array( # create an array of bytes vecs = malloc(n * sizeof(char *)) + if vecs is NULL: + raise MemoryError() lens = malloc(n * sizeof(uint64_t)) + if lens is NULL: + raise MemoryError() for i in range(n): val = arr[i] diff --git a/pandas/_libs/hashtable_class_helper.pxi.in b/pandas/_libs/hashtable_class_helper.pxi.in index 26dcf0b6c4ce3..3ec8f5dd04993 100644 --- a/pandas/_libs/hashtable_class_helper.pxi.in +++ b/pandas/_libs/hashtable_class_helper.pxi.in @@ -211,7 +211,7 @@ cdef class {{name}}Vector(Vector): def __cinit__(self): self.data = <{{name}}VectorData *>PyMem_Malloc( sizeof({{name}}VectorData)) - if not self.data: + if self.data is NULL: raise MemoryError() self.data.n = 0 self.data.m = _INIT_VEC_CAP @@ -264,12 +264,12 @@ cdef class StringVector(Vector): def __cinit__(self): self.data = PyMem_Malloc(sizeof(StringVectorData)) - if not self.data: + if self.data is NULL: raise MemoryError() self.data.n = 0 self.data.m = _INIT_VEC_CAP self.data.data = malloc(self.data.m * sizeof(char *)) - if not self.data.data: + if self.data.data is NULL: raise MemoryError() cdef resize(self): @@ -282,7 +282,7 @@ cdef class StringVector(Vector): orig_data = self.data.data self.data.data = malloc(self.data.m * sizeof(char *)) - if not self.data.data: + if self.data.data is NULL: raise MemoryError() for i in range(m): self.data.data[i] = orig_data[i] @@ -991,6 +991,8 @@ cdef class StringHashTable(HashTable): const char **vecs vecs = malloc(n * sizeof(char *)) + if vecs is NULL: + raise MemoryError() for i in range(n): val = values[i] v = get_c_string(val) @@ -1021,6 +1023,8 @@ cdef class StringHashTable(HashTable): # these by-definition *must* be strings vecs = malloc(n * sizeof(char *)) + if vecs is NULL: + raise MemoryError() for i in range(n): val = values[i] @@ -1057,6 +1061,8 @@ cdef class StringHashTable(HashTable): # these by-definition *must* be strings vecs = malloc(n * sizeof(char *)) + if vecs is NULL: + raise MemoryError() for i in range(n): val = values[i] @@ -1132,6 +1138,8 @@ cdef class StringHashTable(HashTable): # assign pointers and pre-filter out missing (if ignore_na) vecs = malloc(n * sizeof(char *)) + if vecs is NULL: + raise MemoryError() for i in range(n): val = values[i] diff --git a/pandas/_libs/sas.pyx b/pandas/_libs/sas.pyx index 9e1af2cb9c3e7..209e82c6284f5 100644 --- a/pandas/_libs/sas.pyx +++ b/pandas/_libs/sas.pyx @@ -49,7 +49,7 @@ cdef bytes buf_as_bytes(Buffer buf, size_t offset, size_t length): cdef Buffer buf_new(size_t length) except *: cdef uint8_t *data = calloc(length, sizeof(uint8_t)) - if data == NULL: + if data is NULL: raise MemoryError(f"Failed to allocate {length} bytes") return Buffer(data, length) diff --git a/pandas/_libs/tslibs/period.pyx b/pandas/_libs/tslibs/period.pyx index 3da0fa182faf3..838b5b9f4595f 100644 --- a/pandas/_libs/tslibs/period.pyx +++ b/pandas/_libs/tslibs/period.pyx @@ -679,6 +679,8 @@ cdef char* c_strftime(npy_datetimestruct *dts, char *fmt): c_date.tm_isdst = -1 result = malloc(result_len * sizeof(char)) + if result is NULL: + raise MemoryError() strftime(result, result_len, fmt, &c_date)