3
3
# See the LICENSE file in the project root for more information
4
4
5
5
import pytest
6
- from mock import patch
6
+ from mock import patch , Mock
7
7
8
8
from elasticsearch import helpers , TransportError
9
9
from elasticsearch .helpers import ScanError
@@ -312,20 +312,29 @@ async def scan_fixture(async_client):
312
312
await async_client .clear_scroll (scroll_id = "_all" )
313
313
314
314
315
- class TestScan :
316
- mock_scroll_responses = [
317
- {
318
- "_scroll_id" : "dummy_id" ,
319
- "_shards" : {"successful" : 4 , "total" : 5 , "skipped" : 0 },
320
- "hits" : {"hits" : [{"scroll_data" : 42 }]},
321
- },
322
- {
323
- "_scroll_id" : "dummy_id" ,
324
- "_shards" : {"successful" : 4 , "total" : 5 , "skipped" : 0 },
325
- "hits" : {"hits" : []},
326
- },
327
- ]
315
+ class MockScroll :
316
+ def __init__ (self ):
317
+ self .i = 0
318
+ self .values = [
319
+ {
320
+ "_scroll_id" : "dummy_id" ,
321
+ "_shards" : {"successful" : 4 , "total" : 5 , "skipped" : 0 },
322
+ "hits" : {"hits" : [{"scroll_data" : 42 }]},
323
+ },
324
+ {
325
+ "_scroll_id" : "dummy_id" ,
326
+ "_shards" : {"successful" : 4 , "total" : 5 , "skipped" : 0 },
327
+ "hits" : {"hits" : []},
328
+ },
329
+ ]
330
+
331
+ async def scroll (self , * args , ** kwargs ):
332
+ val = self .values [self .i ]
333
+ self .i += 1
334
+ return val
328
335
336
+
337
+ class TestScan :
329
338
async def test_order_can_be_preserved (self , async_client , scan_fixture ):
330
339
bulk = []
331
340
for x in range (100 ):
@@ -373,7 +382,7 @@ async def test_scroll_error(self, async_client, scan_fixture):
373
382
await async_client .bulk (bulk , refresh = True )
374
383
375
384
with patch .object (async_client , "scroll" ) as scroll_mock :
376
- scroll_mock .side_effect = self . mock_scroll_responses
385
+ scroll_mock .side_effect = MockScroll (). scroll
377
386
data = [
378
387
doc
379
388
async for doc in (
@@ -389,7 +398,7 @@ async def test_scroll_error(self, async_client, scan_fixture):
389
398
assert len (data ) == 3
390
399
assert data [- 1 ] == {"scroll_data" : 42 }
391
400
392
- scroll_mock .side_effect = self . mock_scroll_responses
401
+ scroll_mock .side_effect = MockScroll (). scroll
393
402
with pytest .raises (ScanError ):
394
403
data = [
395
404
doc
@@ -406,54 +415,62 @@ async def test_scroll_error(self, async_client, scan_fixture):
406
415
assert len (data ) == 3
407
416
assert data [- 1 ] == {"scroll_data" : 42 }
408
417
409
- async def test_initial_search_error (self , async_client , scan_fixture ):
410
- with patch .object (self , "client" ) as client_mock :
411
- client_mock .search .return_value = {
418
+ async def test_initial_search_error (self ):
419
+ client_mock = Mock ()
420
+
421
+ async def search_mock (* _ , ** __ ):
422
+ return {
412
423
"_scroll_id" : "dummy_id" ,
413
424
"_shards" : {"successful" : 4 , "total" : 5 , "skipped" : 0 },
414
425
"hits" : {"hits" : [{"search_data" : 1 }]},
415
426
}
416
- client_mock .scroll .side_effect = self .mock_scroll_responses
417
427
428
+ async def clear_scroll (* _ , ** __ ):
429
+ return {}
430
+
431
+ client_mock .search = search_mock
432
+ client_mock .scroll = MockScroll ().scroll
433
+ client_mock .clear_scroll = clear_scroll
434
+
435
+ data = [
436
+ doc
437
+ async for doc in (
438
+ helpers .async_scan (
439
+ client_mock , index = "test_index" , size = 2 , raise_on_error = False
440
+ )
441
+ )
442
+ ]
443
+ assert data == [{"search_data" : 1 }, {"scroll_data" : 42 }]
444
+
445
+ client_mock .scroll = MockScroll ().scroll
446
+ with pytest .raises (ScanError ):
418
447
data = [
419
448
doc
420
449
async for doc in (
421
450
helpers .async_scan (
422
- async_client , index = "test_index" , size = 2 , raise_on_error = False
451
+ client_mock , index = "test_index" , size = 2 , raise_on_error = True ,
423
452
)
424
453
)
425
454
]
426
- assert data == [{"search_data" : 1 }, {"scroll_data" : 42 }]
455
+ assert data == [{"search_data" : 1 }]
456
+ scroll_mock .assert_not_called ()
427
457
428
- client_mock .scroll .side_effect = self .mock_scroll_responses
429
- with pytest .raises (ScanError ):
430
- data = [
431
- doc
432
- async for doc in (
433
- helpers .async_scan (
434
- async_client ,
435
- index = "test_index" ,
436
- size = 2 ,
437
- raise_on_error = True ,
438
- )
439
- )
440
- ]
441
- assert data == [{"search_data" : 1 }]
442
- client_mock .scroll .assert_not_called ()
458
+ async def test_no_scroll_id_fast_route (self ):
459
+ client_mock = Mock ()
443
460
444
- async def test_no_scroll_id_fast_route ( self , async_client , scan_fixture ):
445
- with patch . object ( self , "client" ) as client_mock :
446
- client_mock . search . return_value = { "no" : "_scroll_id" }
447
- data = [
448
- doc
449
- async for doc in (helpers .async_scan (async_client , index = "test_index" ))
450
- ]
461
+ async def search_mock ( * args , ** kwargs ):
462
+ return { "no" : "_scroll_id" }
463
+
464
+ client_mock . search = search_mock
465
+ data = [
466
+ doc async for doc in (helpers .async_scan (client_mock , index = "test_index" ))
467
+ ]
451
468
452
- assert data == []
453
- client_mock .scroll .assert_not_called ()
454
- client_mock .clear_scroll .assert_not_called ()
469
+ assert data == []
470
+ client_mock .scroll .assert_not_called ()
471
+ client_mock .clear_scroll .assert_not_called ()
455
472
456
- @patch ("elasticsearch.helpers.actions.logger" )
473
+ @patch ("elasticsearch._async. helpers.actions.logger" )
457
474
async def test_logger (self , logger_mock , async_client , scan_fixture ):
458
475
bulk = []
459
476
for x in range (4 ):
@@ -462,7 +479,7 @@ async def test_logger(self, logger_mock, async_client, scan_fixture):
462
479
await async_client .bulk (bulk , refresh = True )
463
480
464
481
with patch .object (async_client , "scroll" ) as scroll_mock :
465
- scroll_mock .side_effect = self . mock_scroll_responses
482
+ scroll_mock .side_effect = MockScroll (). scroll
466
483
_ = [
467
484
doc
468
485
async for doc in (
@@ -477,7 +494,7 @@ async def test_logger(self, logger_mock, async_client, scan_fixture):
477
494
]
478
495
logger_mock .warning .assert_called ()
479
496
480
- scroll_mock .side_effect = self . mock_scroll_responses
497
+ scroll_mock .side_effect = MockScroll (). scroll
481
498
try :
482
499
_ = [
483
500
doc
0 commit comments