Skip to content

Commit b8668bb

Browse files
Merge pull request pandas-dev#13 from manahl/hooks_and_backports
Fixed the hooks and backported some changes
2 parents 6368fbc + f72b7e3 commit b8668bb

File tree

7 files changed

+123
-38
lines changed

7 files changed

+123
-38
lines changed

arctic/auth.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from .logging import logger
44

5-
65
def authenticate(db, user, password):
76
"""
87
Return True / False on authentication success.
@@ -20,9 +19,9 @@ def authenticate(db, user, password):
2019

2120
Credential = namedtuple("MongoCredentials", ['database', 'user', 'password'])
2221

23-
2422
def get_auth(host, app_name, database_name):
2523
"""
2624
Authentication hook to allow plugging in custom authentication credential providers
2725
"""
28-
return None
26+
from hooks import _get_auth_hook
27+
return _get_auth_hook(host, app_name, database_name)

arctic/hooks.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
_resolve_mongodb_hook = lambda env: env
44
_log_exception_hook = lambda *args, **kwargs: None
5+
_get_auth_hook = lambda *args, **kwargs: None
56

67

78
def get_mongodb_uri(host):
@@ -16,7 +17,7 @@ def get_mongodb_uri(host):
1617

1718
def register_resolve_mongodb_hook(hook):
1819
global _resolve_mongodb_hook
19-
_mongodb_resolve_hook = hook
20+
_resolve_mongodb_hook = hook
2021

2122

2223
def log_exception(fn_name, exception, retry_count, **kwargs):
@@ -29,3 +30,8 @@ def log_exception(fn_name, exception, retry_count, **kwargs):
2930
def register_log_exception_hook(hook):
3031
global _log_exception_hook
3132
_log_exception_hook = hook
33+
34+
35+
def register_get_auth_hook(hook):
36+
global _get_auth_hook
37+
_get_auth_hook = hook

arctic/store/version_store.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ def _write_handler(self, version, symbol, data, **kwargs):
273273
handler = self._bson_handler
274274
return handler
275275

276-
def read(self, symbol, as_of=None, from_version=None, **kwargs):
276+
def read(self, symbol, as_of=None, from_version=None, allow_secondary=None, **kwargs):
277277
"""
278278
Read data for the named symbol. Returns a VersionedItem object with
279279
a data and metadata element (as passed into write).
@@ -292,9 +292,10 @@ def read(self, symbol, as_of=None, from_version=None, **kwargs):
292292
-------
293293
VersionedItem namedtuple which contains a .data and .metadata element
294294
"""
295+
allow_secondary = self._allow_secondary if allow_secondary is None else allow_secondary
295296
try:
296-
_version = self._read_metadata(symbol, as_of=as_of)
297-
read_preference = ReadPreference.NEAREST if self._allow_secondary else None
297+
read_preference = ReadPreference.NEAREST if allow_secondary else ReadPreference.PRIMARY
298+
_version = self._read_metadata(symbol, as_of=as_of, read_preference=read_preference)
298299
return self._do_read(symbol, _version, from_version, read_preference=read_preference, **kwargs)
299300
except (OperationFailure, AutoReconnect) as e:
300301
# Log the exception so we know how often this is happening

arctic/tickstore/tickstore.py

+25-15
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def read(self, symbol, date_range=None, columns=None, include_images=False, _tar
243243
for b in self._collection.find(query, projection=projection).sort([(START, pymongo.ASCENDING)],):
244244
data = self._read_bucket(b, column_set, column_dtypes,
245245
multiple_symbols or (columns is not None and 'SYMBOL' in columns),
246-
include_images)
246+
include_images, columns)
247247
for k, v in data.iteritems():
248248
try:
249249
rtn[k].append(v)
@@ -325,24 +325,35 @@ def _set_or_promote_dtype(self, column_dtypes, c, dtype):
325325
dtype = np.dtype('f8')
326326
column_dtypes[c] = np.promote_types(column_dtypes.get(c, dtype), dtype)
327327

328-
def _prepend_image(self, document, im):
328+
def _prepend_image(self, document, im, rtn_length, column_dtypes, column_set, columns):
329329
image = im[IMAGE]
330330
first_dt = im['t']
331331
if not first_dt.tzinfo:
332332
first_dt = first_dt.replace(tzinfo=mktz('UTC'))
333333
document[INDEX] = np.insert(document[INDEX], 0, np.uint64(datetime_to_ms(first_dt)))
334-
for field in document:
335-
if field == INDEX or document[field] is None:
334+
for field in image:
335+
if field == INDEX:
336336
continue
337-
if field in image:
338-
val = image[field]
339-
else:
340-
logger.debug("Field %s is missing from image!", field)
341-
val = np.nan
337+
if columns and field not in columns:
338+
continue
339+
if field not in document or document[field] is None:
340+
col_dtype = np.dtype(str if isinstance(image[field], basestring) else 'f8')
341+
document[field] = self._empty(rtn_length, dtype=col_dtype)
342+
column_dtypes[field] = col_dtype
343+
column_set.add(field)
344+
val = image[field]
342345
document[field] = np.insert(document[field], 0, document[field].dtype.type(val))
346+
# Now insert rows for fields in document that are not in the image
347+
for field in set(document).difference(set(image)):
348+
if field == INDEX:
349+
continue
350+
logger.debug("Field %s is missing from image!", field)
351+
if document[field] is not None:
352+
val = np.nan
353+
document[field] = np.insert(document[field], 0, document[field].dtype.type(val))
343354
return document
344355

345-
def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_images):
356+
def _read_bucket(self, doc, column_set, column_dtypes, include_symbol, include_images, columns):
346357
rtn = {}
347358
if doc[VERSION] != 3:
348359
raise ArcticException("Unhandled document version: %s" % doc[VERSION])
@@ -351,8 +362,8 @@ def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_imag
351362
rtn_length = len(rtn[INDEX])
352363
if include_symbol:
353364
rtn['SYMBOL'] = [doc[SYMBOL], ] * rtn_length
354-
columns.update(doc[COLUMNS].keys())
355-
for c in columns:
365+
column_set.update(doc[COLUMNS].keys())
366+
for c in column_set:
356367
try:
357368
coldata = doc[COLUMNS][c]
358369
dtype = np.dtype(coldata[DTYPE])
@@ -366,7 +377,7 @@ def _read_bucket(self, doc, columns, column_dtypes, include_symbol, include_imag
366377
rtn[c] = None
367378

368379
if include_images and doc.get(IMAGE_DOC, {}).get(IMAGE, {}):
369-
rtn = self._prepend_image(rtn, doc[IMAGE_DOC])
380+
rtn = self._prepend_image(rtn, doc[IMAGE_DOC], rtn_length, column_dtypes, column_set, columns)
370381
return rtn
371382

372383
def _empty(self, length, dtype):
@@ -493,8 +504,7 @@ def _to_dt(self, date, default_tz=None):
493504
elif date.tzinfo is None:
494505
if default_tz is None:
495506
raise ValueError("Must specify a TimeZone on incoming data")
496-
# Treat naive datetimes as London
497-
return date.replace(tzinfo=mktz())
507+
return date.replace(tzinfo=default_tz)
498508
return date
499509

500510
def _str_dtype(self, dtype):

tests/integration/tickstore/test_ts_read.py

+31-15
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@
77
import pytest
88
import pytz
99

10-
from arctic import arctic as m
1110
from arctic.date import DateRange, mktz, CLOSED_CLOSED, CLOSED_OPEN, OPEN_CLOSED, OPEN_OPEN
12-
from arctic.exceptions import OverlappingDataException, NoDataFoundException
11+
from arctic.exceptions import NoDataFoundException
1312

1413

1514
def test_read(tickstore_lib):
@@ -356,11 +355,11 @@ def test_read_longs(tickstore_lib):
356355
def test_read_with_image(tickstore_lib):
357356
DUMMY_DATA = [
358357
{'a': 1.,
359-
'index': dt(2013, 6, 1, 12, 00, tzinfo=mktz('Europe/London'))
358+
'index': dt(2013, 1, 1, 11, 00, tzinfo=mktz('Europe/London'))
360359
},
361360
{
362361
'b': 4.,
363-
'index': dt(2013, 6, 1, 13, 00, tzinfo=mktz('Europe/London'))
362+
'index': dt(2013, 1, 1, 12, 00, tzinfo=mktz('Europe/London'))
364363
},
365364
]
366365
# Add an image
@@ -371,21 +370,38 @@ def test_read_with_image(tickstore_lib):
371370
{'a': 37.,
372371
'c': 2.,
373372
},
374-
't': dt(2013, 6, 1, 11, 0)
373+
't': dt(2013, 1, 1, 10, tzinfo=mktz('Europe/London'))
375374
}
376375
}
377376
}
378377
)
379378

380-
tickstore_lib.read('SYM', columns=None)
381-
read = tickstore_lib.read('SYM', columns=None, date_range=DateRange(dt(2013, 6, 1), dt(2013, 6, 2)))
382-
assert read['a'][0] == 1
379+
dr = DateRange(dt(2013, 1, 1), dt(2013, 1, 2))
380+
# tickstore_lib.read('SYM', columns=None)
381+
df = tickstore_lib.read('SYM', columns=None, date_range=dr)
382+
assert df['a'][0] == 1
383383

384384
# Read with the image as well
385-
read = tickstore_lib.read('SYM', columns=None, date_range=DateRange(dt(2013, 6, 1), dt(2013, 6, 2)),
386-
include_images=True)
387-
assert read['a'][0] == 37
388-
assert read['a'][1] == 1
389-
assert np.isnan(read['b'][0])
390-
assert read['b'][2] == 4
391-
assert read.index[0] == dt(2013, 6, 1, 11)
385+
df = tickstore_lib.read('SYM', columns=None, date_range=dr, include_images=True)
386+
assert set(df.columns) == set(('a', 'b', 'c'))
387+
assert_array_equal(df['a'].values, np.array([37, 1, np.nan]))
388+
assert_array_equal(df['b'].values, np.array([np.nan, np.nan, 4]))
389+
assert_array_equal(df['c'].values, np.array([2, np.nan, np.nan]))
390+
assert df.index[0] == dt(2013, 1, 1, 10)
391+
assert df.index[1] == dt(2013, 1, 1, 11)
392+
assert df.index[2] == dt(2013, 1, 1, 12)
393+
394+
df = tickstore_lib.read('SYM', columns=('a', 'b'), date_range=dr, include_images=True)
395+
assert set(df.columns) == set(('a', 'b'))
396+
assert_array_equal(df['a'].values, np.array([37, 1, np.nan]))
397+
assert_array_equal(df['b'].values, np.array([np.nan, np.nan, 4]))
398+
assert df.index[0] == dt(2013, 1, 1, 10)
399+
assert df.index[1] == dt(2013, 1, 1, 11)
400+
assert df.index[2] == dt(2013, 1, 1, 12)
401+
402+
df = tickstore_lib.read('SYM', columns=['c'], date_range=dr, include_images=True)
403+
assert set(df.columns) == set(['c'])
404+
assert_array_equal(df['c'].values, np.array([2, np.nan, np.nan]))
405+
assert df.index[0] == dt(2013, 1, 1, 10)
406+
assert df.index[1] == dt(2013, 1, 1, 11)
407+
assert df.index[2] == dt(2013, 1, 1, 12)

tests/unit/store/test_version_store.py

+28-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import pymongo
8-
from pymongo import ReadPreference
8+
from pymongo import ReadPreference, read_preferences
99

1010
from arctic.date import mktz
1111
from arctic.store import version_store
@@ -44,6 +44,33 @@ def test_list_versions_localTime():
4444
'snapshots': 'snap'}
4545

4646

47+
def test_get_version_allow_secondary_True():
48+
vs = create_autospec(VersionStore, instance=True,
49+
_versions=Mock())
50+
vs._allow_secondary = True
51+
vs._find_snapshots.return_value = 'snap'
52+
vs._versions.find.return_value = [{'_id': bson.ObjectId.from_datetime(dt(2013, 4, 1, 9, 0)),
53+
'symbol': 's', 'version': 10}]
54+
55+
VersionStore.read(vs, "symbol")
56+
assert vs._read_metadata.call_args_list == [call('symbol', as_of=None, read_preference=ReadPreference.NEAREST)]
57+
assert vs._do_read.call_args_list == [call('symbol', vs._read_metadata.return_value, None, read_preference=ReadPreference.NEAREST)]
58+
59+
60+
def test_get_version_allow_secondary_user_override_False():
61+
"""Ensure user can override read preference when calling read"""
62+
vs = create_autospec(VersionStore, instance=True,
63+
_versions=Mock())
64+
vs._allow_secondary = True
65+
vs._find_snapshots.return_value = 'snap'
66+
vs._versions.find.return_value = [{'_id': bson.ObjectId.from_datetime(dt(2013, 4, 1, 9, 0)),
67+
'symbol': 's', 'version': 10}]
68+
69+
VersionStore.read(vs, "symbol", allow_secondary=False)
70+
assert vs._read_metadata.call_args_list == [call('symbol', as_of=None, read_preference=ReadPreference.PRIMARY)]
71+
assert vs._do_read.call_args_list == [call('symbol', vs._read_metadata.return_value, None, read_preference=ReadPreference.PRIMARY)]
72+
73+
4774
def test_read_as_of_LondonTime():
4875
# When we do a read, with naive as_of, that as_of is treated in London Time.
4976
vs = create_autospec(VersionStore, instance=True,

tests/unit/test_hooks.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from mock import sentinel, call, Mock
2+
from arctic.hooks import register_get_auth_hook, register_log_exception_hook, \
3+
register_resolve_mongodb_hook, get_mongodb_uri, log_exception
4+
from arctic.auth import get_auth
5+
6+
7+
def test_log_exception_hook():
8+
logger = Mock()
9+
register_log_exception_hook(logger)
10+
log_exception(sentinel.fn, sentinel.e, sentinel.r)
11+
assert logger.call_args_list == [call(sentinel.fn, sentinel.e, sentinel.r)]
12+
13+
14+
def test_get_mongodb_uri_hook():
15+
resolver = Mock()
16+
resolver.return_value = sentinel.result
17+
register_resolve_mongodb_hook(resolver)
18+
assert get_mongodb_uri(sentinel.host) == sentinel.result
19+
assert resolver.call_args_list == [call(sentinel.host)]
20+
21+
22+
def test_get_auth_hook():
23+
auth_resolver = Mock()
24+
register_get_auth_hook(auth_resolver)
25+
get_auth(sentinel.host, sentinel.app_name, sentinel.database_name)
26+
assert auth_resolver.call_args_list == [call(sentinel.host, sentinel.app_name, sentinel.database_name)]

0 commit comments

Comments
 (0)