|
3 | 3 | # distutils: language = c++
|
4 | 4 | # emd.pyx
|
5 | 5 |
|
| 6 | +from pkg_resources import parse_version |
| 7 | + |
6 | 8 | from libcpp.pair cimport pair
|
7 | 9 | from libcpp.vector cimport vector
|
8 | 10 | import cython
|
@@ -139,6 +141,16 @@ def euclidean_pairwise_distance_matrix(x):
|
139 | 141 | return distance_matrix.reshape(len(x), len(x))
|
140 | 142 |
|
141 | 143 |
|
| 144 | +# Use `np.histogram_bin_edges` if available (since NumPy version 1.15.0) |
| 145 | +if parse_version(np.__version__) >= parse_version('1.15.0'): |
| 146 | + get_bins = np.histogram_bin_edges |
| 147 | +else: |
| 148 | + def get_bins(a, bins=10, **kwargs): |
| 149 | + if isinstance(bins, str): |
| 150 | + hist, bins = np.histogram(a, bins=bins, **kwargs) |
| 151 | + return bins |
| 152 | + |
| 153 | + |
142 | 154 | def emd_samples(first_array,
|
143 | 155 | second_array,
|
144 | 156 | extra_mass_penalty=DEFAULT_EXTRA_MASS_PENALTY,
|
@@ -196,14 +208,10 @@ def emd_samples(first_array,
|
196 | 208 | if range is None:
|
197 | 209 | range = (min(np.min(first_array), np.min(second_array)),
|
198 | 210 | max(np.max(first_array), np.max(second_array)))
|
199 |
| - # Use automatic binning from `np.histogram()` |
200 |
| - # TODO: Use `np.histogram_bin_edges()` when it's available; |
201 |
| - # see https://github.com/numpy/numpy/issues/10183 |
202 |
| - if isinstance(bins, str): |
203 |
| - hist, _ = np.histogram(np.concatenate([first_array, second_array]), |
204 |
| - range=range, |
205 |
| - bins=bins) |
206 |
| - bins = len(hist) |
| 211 | + # Get bin edges using both arrays |
| 212 | + bins = get_bins(np.concatenate([first_array, second_array]), |
| 213 | + range=range, |
| 214 | + bins=bins) |
207 | 215 | # Compute histograms
|
208 | 216 | first_histogram, bin_edges = np.histogram(first_array,
|
209 | 217 | range=range,
|
|
0 commit comments