Skip to content

Commit eb23f47

Browse files
authored
Merge pull request pandas-dev#608 from manahl/avoid_mongodb_in_memory_sort
Ensure Arctic performs well with MongoDB 3.6+
2 parents 3ed3c7f + 7db7278 commit eb23f47

File tree

4 files changed

+31
-25
lines changed

4 files changed

+31
-25
lines changed

arctic/store/_ndarray_store.py

+4-6
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
import hashlib
22
import logging
33
import os
4+
from operator import itemgetter
45

56

67
from bson.binary import Binary
@@ -269,7 +270,7 @@ def _do_read(self, collection, version, symbol, index_range=None):
269270

270271
data = bytearray()
271272
i = -1
272-
for i, x in enumerate(collection.find(spec, sort=[('segment', pymongo.ASCENDING)],)):
273+
for i, x in enumerate(sorted(collection.find(spec), key=itemgetter('segment'))):
273274
data.extend(decompress(x['data']) if x['compressed'] else x['data'])
274275

275276
# Check that the correct number of segments has been returned
@@ -409,11 +410,8 @@ def _concat_and_rewrite(self, collection, version, symbol, item, previous_versio
409410
read_index_range = [0, None]
410411
# The unchanged segments are the compressed ones (apart from the last compressed)
411412
unchanged_segment_ids = []
412-
for segment in collection.find(spec, projection={'_id':1,
413-
'segment':1,
414-
'compressed': 1
415-
},
416-
sort=[('segment', pymongo.ASCENDING)]):
413+
for segment in sorted(collection.find(spec, projection={'_id': 1, 'segment': 1, 'compressed': 1}),
414+
key=itemgetter('segment')):
417415
# We want to stop iterating when we find the first uncompressed chunks
418416
if not segment['compressed']:
419417
# We include the last compressed chunk in the recompression

arctic/store/_pickle_store.py

+7-6
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import logging
33
from bson.binary import Binary
44
from bson.errors import InvalidDocument
5+
from operator import itemgetter
56
from six.moves import cPickle, xrange
67
import io
78
from .._compression import decompress, compress_array
@@ -37,14 +38,14 @@ def read(self, mongoose_lib, version, symbol, **kwargs):
3738
if blob is not None:
3839
if blob == _MAGIC_CHUNKEDV2:
3940
collection = mongoose_lib.get_top_level_collection()
40-
data = b''.join(decompress(x['data']) for x in collection.find({'symbol': symbol,
41-
'parent': version_base_or_id(version)},
42-
sort=[('segment', pymongo.ASCENDING)]))
41+
data = b''.join(decompress(x['data']) for x in sorted(
42+
collection.find({'symbol': symbol, 'parent': version_base_or_id(version)}),
43+
key=itemgetter('segment')))
4344
elif blob == _MAGIC_CHUNKED:
4445
collection = mongoose_lib.get_top_level_collection()
45-
data = b''.join(x['data'] for x in collection.find({'symbol': symbol,
46-
'parent': version_base_or_id(version)},
47-
sort=[('segment', pymongo.ASCENDING)]))
46+
data = b''.join(x['data'] for x in sorted(
47+
collection.find({'symbol': symbol, 'parent': version_base_or_id(version)}),
48+
key=itemgetter('segment')))
4849
data = decompress(data)
4950
else:
5051
if blob[:len(_MAGIC_CHUNKED)] == _MAGIC_CHUNKED:

tests/unit/store/test_ndarray_store.py

+14-9
Original file line number | Diff line number | Diff line change
@@ -85,8 +85,8 @@ def test_concat_and_rewrite_checks_chunk_count():
8585
symbol = sentinel.symbol
8686
item = sentinel.item
8787

88-
collection.find.return_value = [{'compressed': True},
89-
{'compressed': False}]
88+
collection.find.return_value = [{'compressed': True, 'segment': 1},
89+
{'compressed': False, 'segment': 2}]
9090
with pytest.raises(DataIntegrityException) as e:
9191
NdarrayStore._concat_and_rewrite(self, collection, version, symbol, item, previous_version)
9292
assert str(e.value) == 'Symbol: sentinel.symbol:sentinel.version expected 1 segments but found 0'
@@ -108,9 +108,11 @@ def test_concat_and_rewrite_checks_written():
108108

109109
collection.find.return_value = [{'_id': sentinel.id,
110110
'segment': 47, 'compressed': True},
111-
{'compressed': True},
111+
{'_id': sentinel.id_2, 'segment': 48, 'compressed': True},
112112
# 3 appended items
113-
{'compressed': False}, {'compressed': False}, {'compressed': False}]
113+
{'_id': sentinel.id_3, 'segment': 49, 'compressed': False},
114+
{'_id': sentinel.id_4, 'segment': 50, 'compressed': False},
115+
{'_id': sentinel.id_5, 'segment': 51, 'compressed': False}]
114116
collection.update_many.return_value = create_autospec(UpdateResult, matched_count=1)
115117
NdarrayStore._concat_and_rewrite(self, collection, version, symbol, item, previous_version)
116118
assert self.check_written.call_count == 1
@@ -131,8 +133,11 @@ def test_concat_and_rewrite_checks_different_id():
131133
item = []
132134

133135
collection.find.side_effect = [
134-
[{'_id': sentinel.id, 'segment' : 47, 'compressed': True}, {'compressed': True},
135-
{'compressed': False}, {'compressed': False}, {'compressed': False}], # 3 appended items
136+
[{'_id': sentinel.id, 'segment' : 47, 'compressed': True},
137+
{'_id': sentinel.id_3, 'segment': 48, 'compressed': True},
138+
{'_id': sentinel.id_4, 'segment': 49, 'compressed': False},
139+
{'_id': sentinel.id_5, 'segment': 50, 'compressed': False},
140+
{'_id': sentinel.id_6, 'segment': 51, 'compressed': False}], # 3 appended items
136141
[{'_id': sentinel.id_2}] # the returned id is different after the update_many
137142
]
138143

@@ -163,9 +168,9 @@ def test_concat_and_rewrite_checks_fewer_updated():
163168
[{'_id': sentinel.id_1, 'segment': 47, 'compressed': True},
164169
{'_id': sentinel.id_2, 'segment': 48, 'compressed': True},
165170
{'_id': sentinel.id_3, 'segment': 49, 'compressed': True},
166-
{'compressed': False},
167-
{'compressed': False},
168-
{'compressed': False}], # 3 appended items
171+
{'_id': sentinel.id_4, 'segment': 50, 'compressed': False},
172+
{'_id': sentinel.id_5, 'segment': 51, 'compressed': False},
173+
{'_id': sentinel.id_6, 'segment': 52, 'compressed': False}], # 3 appended items
169174
[{'_id': sentinel.id_1}] # the returned id is different after the update_many
170175
]
171176

tests/unit/store/test_pickle_store.py

+6-4
Original file line number | Diff line number | Diff line change
@@ -58,12 +58,13 @@ def test_read_object_2():
5858
coll = Mock()
5959
arctic_lib = Mock()
6060
coll.find.return_value = [{'data': Binary(compressHC(cPickle.dumps(object))),
61-
'symbol': 'sentinel.symbol'}
61+
'symbol': 'sentinel.symbol',
62+
'segment': 1}
6263
]
6364
arctic_lib.get_top_level_collection.return_value = coll
6465

6566
assert PickleStore.read(self, arctic_lib, version, sentinel.symbol) == object
66-
assert coll.find.call_args_list == [call({'symbol': sentinel.symbol, 'parent': sentinel._id}, sort=[('segment', 1)])]
67+
assert coll.find.call_args_list == [call({'symbol': sentinel.symbol, 'parent': sentinel._id})]
6768

6869

6970
def test_read_with_base_version_id():
@@ -74,12 +75,13 @@ def test_read_with_base_version_id():
7475
coll = Mock()
7576
arctic_lib = Mock()
7677
coll.find.return_value = [{'data': Binary(compressHC(cPickle.dumps(object))),
77-
'symbol': 'sentinel.symbol'}
78+
'symbol': 'sentinel.symbol',
79+
'segment': 1}
7880
]
7981
arctic_lib.get_top_level_collection.return_value = coll
8082

8183
assert PickleStore.read(self, arctic_lib, version, sentinel.symbol) == object
82-
assert coll.find.call_args_list == [call({'symbol': sentinel.symbol, 'parent': sentinel.base_version_id}, sort=[('segment', 1)])]
84+
assert coll.find.call_args_list == [call({'symbol': sentinel.symbol, 'parent': sentinel.base_version_id})]
8385

8486

8587
@pytest.mark.xfail(sys.version_info >= (3,),

0 commit comments

Comments (0)