|
| 1 | +from contextlib import contextmanager |
| 2 | +import tracemalloc |
| 3 | + |
1 | 4 | import numpy as np
|
2 | 5 | import pytest
|
3 | 6 |
|
|
6 | 9 | import pandas._testing as tm
|
7 | 10 |
|
8 | 11 |
|
@contextmanager
def activated_tracemalloc():
    """
    Context manager that makes sure tracemalloc is tracing inside the block.

    Tracing is started on entry only if it is not already active, and it is
    stopped on exit only if this context manager started it. A tracing
    session that was enabled before entering the block is left running.
    """
    # tracemalloc.stop() also clears every trace collected so far, so a
    # session that somebody else started must never be stopped from here.
    started_here = not tracemalloc.is_tracing()
    if started_here:
        tracemalloc.start()
    try:
        yield
    finally:
        if started_here:
            tracemalloc.stop()
| 19 | + |
| 20 | + |
def get_allocated_khash_memory():
    """
    Return the summed size of all traced allocations that belong to the
    tracemalloc domain reported by ``ht.get_hashtable_trace_domain()``.
    """
    full_snapshot = tracemalloc.take_snapshot()
    # inclusive filter: keep only the traces recorded in the hashtable domain
    khash_only = tracemalloc.DomainFilter(True, ht.get_hashtable_trace_domain())
    khash_traces = full_snapshot.filter_traces((khash_only,)).traces
    return sum(trace.size for trace in khash_traces)
| 27 | + |
| 28 | + |
9 | 29 | @pytest.mark.parametrize(
|
10 | 30 | "table_type, dtype",
|
11 | 31 | [
|
| 32 | + (ht.PyObjectHashTable, np.object_), |
12 | 33 | (ht.Int64HashTable, np.int64),
|
13 | 34 | (ht.UInt64HashTable, np.uint64),
|
14 | 35 | (ht.Float64HashTable, np.float64),
|
@@ -53,13 +74,15 @@ def test_get_set_contains_len(self, table_type, dtype):
|
53 | 74 | assert str(index + 2) in str(excinfo.value)
|
54 | 75 |
|
55 | 76 | def test_map(self, table_type, dtype):
|
56 |
| - N = 77 |
57 |
| - table = table_type() |
58 |
| - keys = np.arange(N).astype(dtype) |
59 |
| - vals = np.arange(N).astype(np.int64) + N |
60 |
| - table.map(keys, vals) |
61 |
| - for i in range(N): |
62 |
| - assert table.get_item(keys[i]) == i + N |
| 77 | + # PyObjectHashTable has no map-method |
| 78 | + if table_type != ht.PyObjectHashTable: |
| 79 | + N = 77 |
| 80 | + table = table_type() |
| 81 | + keys = np.arange(N).astype(dtype) |
| 82 | + vals = np.arange(N).astype(np.int64) + N |
| 83 | + table.map(keys, vals) |
| 84 | + for i in range(N): |
| 85 | + assert table.get_item(keys[i]) == i + N |
63 | 86 |
|
64 | 87 | def test_map_locations(self, table_type, dtype):
|
65 | 88 | N = 8
|
@@ -101,6 +124,53 @@ def test_unique(self, table_type, dtype):
|
101 | 124 | unique = table.unique(keys)
|
102 | 125 | tm.assert_numpy_array_equal(unique, expected)
|
103 | 126 |
|
def test_tracemalloc_works(self, table_type, dtype):
    """
    Check that the khash memory seen by tracemalloc equals ``table.sizeof()``
    for a filled table and is back to zero once the table is deleted.
    """
    # the 8-bit dtypes get a smaller key count than all other dtypes
    n_keys = 256 if dtype in (np.int8, np.uint8) else 30000
    keys = np.arange(n_keys).astype(dtype)
    with activated_tracemalloc():
        hashtable = table_type()
        hashtable.map_locations(keys)
        traced_size = get_allocated_khash_memory()
        reported_size = hashtable.sizeof()
        assert traced_size == reported_size
        # dropping the only reference must release all traced khash memory
        del hashtable
        assert get_allocated_khash_memory() == 0
| 141 | + |
def test_tracemalloc_for_empty(self, table_type, dtype):
    """
    Check that the khash memory seen by tracemalloc equals ``table.sizeof()``
    for a freshly created table and is back to zero once it is deleted.
    """
    with activated_tracemalloc():
        empty_table = table_type()
        traced_size = get_allocated_khash_memory()
        assert traced_size == empty_table.sizeof()
        # dropping the only reference must release all traced khash memory
        del empty_table
        assert get_allocated_khash_memory() == 0
| 150 | + |
| 151 | + |
def test_tracemalloc_works_for_StringHashTable():
    """
    Check that the khash memory seen by tracemalloc equals ``table.sizeof()``
    for a filled StringHashTable and is back to zero once it is deleted.
    """
    N = 1000
    # np.str_ gives the same unicode dtype as the old np.compat.unicode alias
    # (a Python 2 compatibility name for str) without relying on np.compat.
    keys = np.arange(N).astype(np.str_).astype(np.object_)
    with activated_tracemalloc():
        table = ht.StringHashTable()
        table.map_locations(keys)
        used = get_allocated_khash_memory()
        my_size = table.sizeof()
        assert used == my_size
        # dropping the only reference must release all traced khash memory
        del table
        assert get_allocated_khash_memory() == 0
| 163 | + |
| 164 | + |
def test_tracemalloc_for_empty_StringHashTable():
    """
    Check that the khash memory seen by tracemalloc equals ``table.sizeof()``
    for a freshly created StringHashTable and is back to zero once deleted.
    """
    with activated_tracemalloc():
        empty_table = ht.StringHashTable()
        traced_size = get_allocated_khash_memory()
        assert traced_size == empty_table.sizeof()
        # dropping the only reference must release all traced khash memory
        del empty_table
        assert get_allocated_khash_memory() == 0
| 173 | + |
104 | 174 |
|
105 | 175 | @pytest.mark.parametrize(
|
106 | 176 | "table_type, dtype",
|
@@ -157,6 +227,7 @@ def get_ht_function(fun_name, type_suffix):
|
157 | 227 | @pytest.mark.parametrize(
|
158 | 228 | "dtype, type_suffix",
|
159 | 229 | [
|
| 230 | + (np.object_, "object"), |
160 | 231 | (np.int64, "int64"),
|
161 | 232 | (np.uint64, "uint64"),
|
162 | 233 | (np.float64, "float64"),
|
|
0 commit comments