diff --git a/tarantool/schema.py b/tarantool/schema.py index c06f3f5f..8d0ec790 100644 --- a/tarantool/schema.py +++ b/tarantool/schema.py @@ -10,32 +10,83 @@ integer_types, ) from tarantool.error import ( + Error, SchemaError, DatabaseError ) import tarantool.const as const +class RecursionError(Error): + """Report the situation when max recursion depth is reached. + + This is internal error for caller + and it should be re-raised properly be the caller. + """ + + +def to_unicode(s): + if isinstance(s, bytes): + return s.decode(encoding='utf-8') + return s + + +def to_unicode_recursive(x, max_depth): + """Same as to_unicode(), but traverses over dictionaries, + lists and tuples recursivery. + + x: value to convert + + max_depth: 1 accepts a scalar, 2 accepts a list of scalars, + etc. + """ + if max_depth <= 0: + raise RecursionError('Max recursion depth is reached') + + if isinstance(x, dict): + res = dict() + for key, val in x.items(): + key = to_unicode_recursive(key, max_depth - 1) + val = to_unicode_recursive(val, max_depth - 1) + res[key] = val + return res + + if isinstance(x, list) or isinstance(x, tuple): + res = [] + for val in x: + val = to_unicode_recursive(val, max_depth - 1) + res.append(val) + if isinstance(x, tuple): + return tuple(res) + return res + + return to_unicode(x) + + class SchemaIndex(object): def __init__(self, index_row, space): self.iid = index_row[1] self.name = index_row[2] - if isinstance(self.name, bytes): - self.name = self.name.decode() + self.name = to_unicode(index_row[2]) self.index = index_row[3] self.unique = index_row[4] self.parts = [] - if isinstance(index_row[5], (list, tuple)): - for val in index_row[5]: + try: + parts_raw = to_unicode_recursive(index_row[5], 3) + except RecursionError as e: + errmsg = 'Unexpected index parts structure: ' + str(e) + raise SchemaError(errmsg) + if isinstance(parts_raw, (list, tuple)): + for val in parts_raw: if isinstance(val, dict): self.parts.append((val['field'], val['type'])) else: self.parts.append((val[0], val[1])) else: - for i in range(index_row[5]): + for i in range(parts_raw): self.parts.append(( - index_row[5 + 1 + i * 2], - index_row[5 + 2 + i * 2] + to_unicode(index_row[5 + 1 + i * 2]), + to_unicode(index_row[5 + 2 + i * 2]) )) self.space = space self.space.indexes[self.iid] = self @@ -52,16 +103,19 @@ class SchemaSpace(object): def __init__(self, space_row, schema): self.sid = space_row[0] self.arity = space_row[1] - self.name = space_row[2] - if isinstance(self.name, bytes): - self.name = self.name.decode() + self.name = to_unicode(space_row[2]) self.indexes = {} self.schema = schema self.schema[self.sid] = self if self.name: self.schema[self.name] = self self.format = dict() - for part_id, part in enumerate(space_row[6]): + try: + format_raw = to_unicode_recursive(space_row[6], 3) + except RecursionError as e: + errmsg = 'Unexpected space format structure: ' + str(e) + raise SchemaError(errmsg) + for part_id, part in enumerate(format_raw): part['id'] = part_id self.format[part['name']] = part self.format[part_id ] = part @@ -78,6 +132,8 @@ def __init__(self, con): self.con = con def get_space(self, space): + space = to_unicode(space) + try: return self.schema[space] except KeyError: @@ -135,6 +191,9 @@ def fetch_space_all(self): SchemaSpace(row, self.schema) def get_index(self, space, index): + space = to_unicode(space) + index = to_unicode(index) + _space = self.get_space(space) try: return _space.indexes[index] @@ -203,6 +262,9 @@ def fetch_index_from(self, space, index): return index_row def get_field(self, space, field): + space = to_unicode(space) + field = to_unicode(field) + _space = self.get_space(space) try: return _space.format[field] diff --git a/unit/setup_command.py b/unit/setup_command.py index dbb6624b..65e6f780 100755 --- a/unit/setup_command.py +++ b/unit/setup_command.py @@ -23,7 +23,7 @@ def run(self): Find all tests in test/tarantool/ and run them ''' - tests = unittest.defaultTestLoader.discover('unit') + tests = unittest.defaultTestLoader.discover('unit', pattern='suites') test_runner = unittest.TextTestRunner(verbosity=2) result = test_runner.run(tests) if not result.wasSuccessful(): diff --git a/unit/suites/__init__.py b/unit/suites/__init__.py index ead75297..7e9d12e3 100644 --- a/unit/suites/__init__.py +++ b/unit/suites/__init__.py @@ -4,14 +4,17 @@ __tmp = os.getcwd() os.chdir(os.path.abspath(os.path.dirname(__file__))) -from .test_schema import TestSuite_Schema +from .test_schema import TestSuite_Schema_UnicodeConnection +from .test_schema import TestSuite_Schema_BinaryConnection from .test_dml import TestSuite_Request from .test_protocol import TestSuite_Protocol from .test_reconnect import TestSuite_Reconnect from .test_mesh import TestSuite_Mesh -test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol, - TestSuite_Reconnect, TestSuite_Mesh) +test_cases = (TestSuite_Schema_UnicodeConnection, + TestSuite_Schema_BinaryConnection, + TestSuite_Request, TestSuite_Protocol, TestSuite_Reconnect, + TestSuite_Mesh) def load_tests(loader, tests, pattern): suite = unittest.TestSuite() diff --git a/unit/suites/test_schema.py b/unit/suites/test_schema.py index cb772b2d..37850e06 100644 --- a/unit/suites/test_schema.py +++ b/unit/suites/test_schema.py @@ -7,22 +7,100 @@ import tarantool from .lib.tarantool_server import TarantoolServer -class TestSuite_Schema(unittest.TestCase): + +# FIXME: I'm quite sure that there is a simpler way to count +# a method calls, but I failed to find any. It seems, I should +# look at unittest.mock more thoroughly. +class MethodCallCounter: + def __init__(self, obj, method_name): + self._call_count = 0 + self._bind(obj, method_name) + + def _bind(self, obj, method_name): + self._obj = obj + self._method_name = method_name + self._saved_method = getattr(obj, method_name) + def wrapper(_, *args, **kwargs): + self._call_count += 1 + return self._saved_method(*args, **kwargs) + bound_wrapper = wrapper.__get__(obj.__class__, obj) + setattr(obj, method_name, bound_wrapper) + + def unbind(self): + if self._saved_method is not None: + setattr(self._obj, self._method_name, self._saved_method) + + def call_count(self): + return self._call_count + + +class TestSuite_Schema_Abstract(unittest.TestCase): + # Define 'encoding' field in a concrete class. + @classmethod def setUpClass(self): - print(' SCHEMA '.center(70, '='), file=sys.stderr) + params = 'connection.encoding: {}'.format(repr(self.encoding)) + print(' SCHEMA ({}) '.format(params).center(70, '='), file=sys.stderr) print('-' * 70, file=sys.stderr) self.srv = TarantoolServer() self.srv.script = 'unit/suites/box.lua' self.srv.start() - self.con = tarantool.Connection(self.srv.host, self.srv.args['primary']) + self.con = tarantool.Connection(self.srv.host, self.srv.args['primary'], + encoding=self.encoding) self.sch = self.con.schema + # The relevant test cases mainly target Python 2, where + # a user may want to pass a string literal as a space or + # an index name and don't bother whether all symbols in it + # are ASCII. + self.unicode_space_name_literal = '∞' + self.unicode_index_name_literal = '→' + + self.unicode_space_name_u = u'∞' + self.unicode_index_name_u = u'→' + self.unicode_space_id, self.unicode_index_id = self.srv.admin(""" + do + local space = box.schema.create_space('\\xe2\\x88\\x9e') + local index = space:create_index('\\xe2\\x86\\x92') + return space.id, index.id + end + """) + def setUp(self): # prevent a remote tarantool from clean our session if self.srv.is_started(): self.srv.touch_lock() + # Count calls of fetch methods. See . + self.fetch_space_counter = MethodCallCounter(self.sch, 'fetch_space') + self.fetch_index_counter = MethodCallCounter(self.sch, 'fetch_index') + + def tearDown(self): + self.fetch_space_counter.unbind() + self.fetch_index_counter.unbind() + + @property + def fetch_count(self): + """Amount of fetch_{space,index}() calls. + + It is initialized to zero before each test case. + """ + res = 0 + res += self.fetch_space_counter.call_count() + res += self.fetch_index_counter.call_count() + return res + + def verify_unicode_space(self, space): + self.assertEqual(space.sid, self.unicode_space_id) + self.assertEqual(space.name, self.unicode_space_name_u) + self.assertEqual(space.arity, 1) + + def verify_unicode_index(self, index): + self.assertEqual(index.space.name, self.unicode_space_name_u) + self.assertEqual(index.iid, self.unicode_index_id) + self.assertEqual(index.name, self.unicode_index_name_u) + self.assertEqual(len(index.parts), 1) + def test_00_authenticate(self): self.assertIsNone(self.srv.admin("box.schema.user.create('test', { password = 'test' })")) self.assertIsNone(self.srv.admin("box.schema.user.grant('test', 'read,write', 'space', '_space')")) @@ -72,6 +150,9 @@ def test_03_01_space_name__(self): self.assertEqual(space.name, '_index') self.assertEqual(space.arity, 1) + space = self.sch.get_space(self.unicode_space_name_literal) + self.verify_unicode_space(space) + def test_03_02_space_number(self): self.con.flush_schema() space = self.sch.get_space(272) @@ -87,6 +168,9 @@ def test_03_02_space_number(self): self.assertEqual(space.name, '_index') self.assertEqual(space.arity, 1) + space = self.sch.get_space(self.unicode_space_id) + self.verify_unicode_space(space) + def test_04_space_cached(self): space = self.sch.get_space('_schema') self.assertEqual(space.sid, 272) @@ -101,6 +185,15 @@ def test_04_space_cached(self): self.assertEqual(space.name, '_index') self.assertEqual(space.arity, 1) + # Verify that no schema fetches occurs. + self.assertEqual(self.fetch_count, 0) + + space = self.sch.get_space(self.unicode_space_name_literal) + self.verify_unicode_space(space) + + # Verify that no schema fetches occurs. + self.assertEqual(self.fetch_count, 0) + def test_05_01_index_name___name__(self): self.con.flush_schema() index = self.sch.get_index('_index', 'primary') @@ -124,6 +217,10 @@ def test_05_01_index_name___name__(self): self.assertEqual(index.name, 'name') self.assertEqual(len(index.parts), 1) + index = self.sch.get_index(self.unicode_space_name_literal, + self.unicode_index_name_literal) + self.verify_unicode_index(index) + def test_05_02_index_name___number(self): self.con.flush_schema() index = self.sch.get_index('_index', 0) @@ -147,6 +244,10 @@ def test_05_02_index_name___number(self): self.assertEqual(index.name, 'name') self.assertEqual(len(index.parts), 1) + index = self.sch.get_index(self.unicode_space_name_literal, + self.unicode_index_id) + self.verify_unicode_index(index) + def test_05_03_index_number_name__(self): self.con.flush_schema() index = self.sch.get_index(288, 'primary') @@ -170,6 +271,10 @@ def test_05_03_index_number_name__(self): self.assertEqual(index.name, 'name') self.assertEqual(len(index.parts), 1) + index = self.sch.get_index(self.unicode_space_id, + self.unicode_index_name_literal) + self.verify_unicode_index(index) + def test_05_04_index_number_number(self): self.con.flush_schema() index = self.sch.get_index(288, 0) @@ -193,6 +298,10 @@ def test_05_04_index_number_number(self): self.assertEqual(index.name, 'name') self.assertEqual(len(index.parts), 1) + index = self.sch.get_index(self.unicode_space_id, + self.unicode_index_id) + self.verify_unicode_index(index) + def test_06_index_cached(self): index = self.sch.get_index('_index', 'primary') self.assertEqual(index.space.name, '_index') @@ -215,6 +324,22 @@ def test_06_index_cached(self): self.assertEqual(index.name, 'name') self.assertEqual(len(index.parts), 1) + # Verify that no schema fetches occurs. + self.assertEqual(self.fetch_count, 0) + + cases = ( + (self.unicode_space_name_literal, self.unicode_index_name_literal), + (self.unicode_space_name_literal, self.unicode_index_id), + (self.unicode_space_id, self.unicode_index_name_literal), + (self.unicode_space_id, self.unicode_index_id), + ) + for s, i in cases: + index = self.sch.get_index(s, i) + self.verify_unicode_index(index) + + # Verify that no schema fetches occurs. + self.assertEqual(self.fetch_count, 0) + def test_07_schema_version_update(self): _space_len = len(self.con.select('_space')) self.srv.admin("box.schema.create_space('ttt22')") @@ -225,3 +350,11 @@ def tearDownClass(self): self.con.close() self.srv.stop() self.srv.clean() + + +class TestSuite_Schema_UnicodeConnection(TestSuite_Schema_Abstract): + encoding = 'utf-8' + + +class TestSuite_Schema_BinaryConnection(TestSuite_Schema_Abstract): + encoding = None