Skip to content

Commit ccbbcda

Browse files
committed
added unit tests for dbapi module
1 parent d98b0e9 commit ccbbcda

File tree

7 files changed

+167
-25
lines changed

7 files changed

+167
-25
lines changed

tarantool/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -75,4 +75,4 @@ def connectmesh(addrs=({'host': 'localhost', 'port': 3301},), user=None,
7575

7676
__all__ = ['connect', 'Connection', 'connectmesh', 'MeshConnection', 'Schema',
7777
'Error', 'DatabaseError', 'NetworkError', 'NetworkWarning',
78-
'SchemaError']
78+
'SchemaError', 'dbapi']

tarantool/connection.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ITERATOR_ALL
5252
)
5353
from tarantool.error import (
54+
Error,
5455
NetworkError,
5556
DatabaseError,
5657
InterfaceError,
@@ -85,7 +86,7 @@ class Connection(object):
8586
(insert/delete/update/select).
8687
'''
8788
# DBAPI Extension: supply exceptions as attributes on the connection
88-
Error = tarantool.error
89+
Error = Error
8990
DatabaseError = DatabaseError
9091
InterfaceError = InterfaceError
9192
SchemaError = SchemaError
@@ -108,6 +109,7 @@ def __init__(self, host, port,
108109
connect_now=True,
109110
encoding=ENCODING_DEFAULT,
110111
call_16=False,
112+
use_list=True,
111113
connection_timeout=CONNECTION_TIMEOUT):
112114
'''
113115
Initialize a connection to the server.
@@ -140,6 +142,7 @@ def __init__(self, host, port,
140142
self._socket = None
141143
self.connected = False
142144
self.error = True
145+
self.use_list = use_list
143146
self.encoding = encoding
144147
self.call_16 = call_16
145148
self.connection_timeout = connection_timeout
@@ -277,7 +280,7 @@ def _send_request_wo_reconnect(self, request):
277280
while True:
278281
try:
279282
self._socket.sendall(bytes(request))
280-
response = Response(self, self._read_response())
283+
response = Response(self, self._read_response(), self.use_list)
281284
break
282285
except SchemaReloadException as e:
283286
self.update_schema(e.schema_version)
@@ -461,7 +464,7 @@ def _join_v16(self, server_uuid):
461464
self._socket.sendall(bytes(request))
462465

463466
while True:
464-
resp = Response(self, self._read_response())
467+
resp = Response(self, self._read_response(), self.use_list)
465468
yield resp
466469
if resp.code == REQUEST_TYPE_OK or resp.code >= REQUEST_TYPE_ERROR:
467470
return
@@ -475,7 +478,7 @@ class JoinState:
475478
self._socket.sendall(bytes(request))
476479
state = JoinState.Handshake
477480
while True:
478-
resp = Response(self, self._read_response())
481+
resp = Response(self, self._read_response(), self.use_list)
479482
yield resp
480483
if resp.code >= REQUEST_TYPE_ERROR:
481484
return
@@ -504,7 +507,7 @@ def subscribe(self, cluster_uuid, server_uuid, vclock=None):
504507
request = RequestSubscribe(self, cluster_uuid, server_uuid, vclock)
505508
self._socket.sendall(bytes(request))
506509
while True:
507-
resp = Response(self, self._read_response())
510+
resp = Response(self, self._read_response(), self.use_list)
508511
yield resp
509512
if resp.code >= REQUEST_TYPE_ERROR:
510513
return

tarantool/dbapi.py

+50-12
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,21 @@
66
from tarantool.error import InterfaceError
77
from .connection import Connection as BaseConnection
88

9-
update_insert_pattern = re.compile(r'^UPDATE|INSERT')
9+
update_insert_pattern = re.compile(r'^UPDATE|^INSERT', re.IGNORECASE)
1010

1111

1212
class Cursor:
1313
_lastrowid = 0
1414
_rowcount = 0
1515
description = None
1616
position = 0
17-
arraysize = 200
17+
arraysize = 1
1818
autocommit = True
19+
closed = False
1920

2021
def __init__(self, connection):
2122
self._c = connection
22-
self.rows = []
23+
self.rows = None
2324

2425
def callproc(self, procname, *params): # TODO
2526
"""
@@ -48,6 +49,8 @@ def _convert_param(p):
4849
return "NULL"
4950
if isinstance(p, bool):
5051
return str(p)
52+
if isinstance(p, str):
53+
return "'%s'" % p.replace("'", "''")
5154
return "'%s'" % p
5255

5356
@staticmethod
@@ -85,29 +88,37 @@ def execute(self, query, params=None):
8588
8689
Return values are not defined.
8790
"""
91+
if self.closed:
92+
raise self._c.ProgrammingError
8893
if params:
8994
query = query % tuple(
9095
self._convert_param(param) for param in params)
9196

9297
response = self._c.execute(query)
9398

9499
self.rows = tuple(response.body.values())[1] if len(
95-
response.body) > 1 else []
100+
response.body) > 1 else None
96101

97-
if update_insert_pattern.match(query.upper()):
102+
if update_insert_pattern.match(query):
98103
try:
99104
self._rowcount = response.rowcount
100105
except InterfaceError:
101-
self._rowcount = 1
106+
self._rowcount = -1
102107
else:
103-
self._rowcount = 1
108+
self._rowcount = -1
104109

105110
if query.upper().startswith('INSERT'):
106111
self._lastrowid = self._extract_last_row_id(response.body)
107112
return response
108113

109-
def executemany(self, query, params):
110-
return self.execute(query, params)
114+
def executemany(self, query, param_sets):
115+
rowcounts = []
116+
for params in param_sets:
117+
self.execute(query, params)
118+
rowcounts.append(self.rowcount)
119+
120+
self._rowcount = -1 if -1 in rowcounts else sum(rowcounts)
121+
return self
111122

112123
@property
113124
def lastrowid(self):
@@ -150,10 +161,11 @@ def fetchone(self):
150161
An Error (or subclass) exception is raised if the previous call to
151162
.execute*() did not produce any result set or no call was issued yet.
152163
"""
153-
164+
if self.rows is None:
165+
raise self._c.Error
154166
return self.fetchmany(1)[0] if len(self.rows) else None
155167

156-
def fetchmany(self, size):
168+
def fetchmany(self, size=None):
157169
"""
158170
Fetch the next set of rows of a query result, returning a sequence of
159171
sequences (e.g. a list of tuples). An empty sequence is returned when
@@ -174,6 +186,11 @@ def fetchmany(self, size):
174186
.arraysize attribute. If the size parameter is used, then it is best
175187
for it to retain the same value from one .fetchmany() call to the next.
176188
"""
189+
size = size or self.arraysize
190+
191+
if self.rows is None:
192+
raise self._c.ProgrammingError
193+
177194
if len(self.rows) < size:
178195
items = self.rows
179196
self.rows = []
@@ -190,6 +207,9 @@ def fetchall(self):
190207
An Error (or subclass) exception is raised if the previous call to
191208
.execute*() did not produce any result set or no call was issued yet.
192209
"""
210+
if self.rows is None:
211+
raise self._c.ProgrammingError
212+
193213
items = self.rows[:]
194214
self.rows = []
195215
return items
@@ -216,10 +236,17 @@ def setoutputsize(self, size, column=None):
216236

217237
class Connection(BaseConnection):
218238
_cursor = None
239+
paramstyle = 'format'
240+
apilevel = "2.0"
241+
threadsafety = 0
219242

220243
server_version = 2
221244

222-
def commit(self): # TODO
245+
def connect(self):
246+
super().connect()
247+
return self
248+
249+
def commit(self):
223250
"""
224251
Commit any pending transaction to the database.
225252
@@ -230,6 +257,8 @@ def commit(self): # TODO
230257
Database modules that do not support transactions should implement
231258
this method with void functionality.
232259
"""
260+
if self._socket is None:
261+
raise self.ProgrammingError
233262

234263
def rollback(self):
235264
"""
@@ -238,6 +267,13 @@ def rollback(self):
238267
Closing a connection without committing the changes first will cause
239268
an implicit rollback to be performed.
240269
"""
270+
if self._socket is None:
271+
raise self.ProgrammingError
272+
273+
def execute(self, query, params=None):
274+
if self._socket is None:
275+
raise self.ProgrammingError
276+
return super().execute(query, params)
241277

242278
def close(self):
243279
"""
@@ -252,6 +288,8 @@ def close(self):
252288
if self._socket:
253289
self._socket.close()
254290
self._socket = None
291+
else:
292+
raise self.ProgrammingError
255293

256294
def _set_cursor(self):
257295
self._cursor = Cursor(self)

test.sh

+2-6
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,11 @@
33
set -exu # Strict shell (w/o -o pipefail)
44

55
# Install tarantool.
6-
curl http://download.tarantool.org/tarantool/2x/gpgkey | sudo apt-key add -
7-
release=`lsb_release -c -s`
8-
echo "deb http://download.tarantool.org/tarantool/2x/ubuntu/ ${release} main" | sudo tee /etc/apt/sources.list.d/tarantool_2x.list
9-
sudo apt-get update > /dev/null
10-
sudo apt-get -q -y install tarantool
6+
curl -L https://tarantool.io/installer.sh | VER=2.4 bash
117

128
# Install testing dependencies.
139
pip install -r requirements.txt
14-
pip install pyyaml
10+
pip install pyyaml dbapi-compliance==1.15.0
1511

1612
# Run tests.
1713
python setup.py test

unit/suites/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
from .test_protocol import TestSuite_Protocol
1010
from .test_reconnect import TestSuite_Reconnect
1111
from .test_mesh import TestSuite_Mesh
12+
from .test_dbapi import TestSuite_DBAPI
1213

1314
test_cases = (TestSuite_Schema, TestSuite_Request, TestSuite_Protocol,
14-
TestSuite_Reconnect, TestSuite_Mesh)
15+
TestSuite_Reconnect, TestSuite_Mesh, TestSuite_DBAPI)
1516

1617
def load_tests(loader, tests, pattern):
1718
suite = unittest.TestSuite()

unit/suites/box_dbapi.lua

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env tarantool
2+
os = require('os')
3+
4+
require('console').listen(os.getenv("ADMIN_PORT"))
5+
box.cfg {
6+
listen = os.getenv("PRIMARY_PORT"),
7+
memtx_memory = 0.1 * 1024 ^ 3, -- 0.1 GiB
8+
pid_file = "box.pid",
9+
}
10+
box.schema.user.grant('guest', 'create,read,write', 'universe')

unit/suites/test_dbapi.py

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from __future__ import print_function
4+
5+
import sys
6+
import unittest
7+
8+
from dbapi20 import DatabaseAPI20Test
9+
10+
from tarantool.dbapi import Connection
11+
from .lib.tarantool_server import TarantoolServer
12+
13+
14+
class TestSuite_DBAPI(DatabaseAPI20Test):
15+
table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
16+
17+
ddl1 = 'create table %sbooze (name varchar(20) primary key)' % table_prefix
18+
ddl2 = 'create table %sbarflys (name varchar(20) primary key, ' \
19+
'drink varchar(30))' % table_prefix
20+
21+
@classmethod
22+
def setUpClass(cls):
23+
print(' DBAPI '.center(70, '='), file=sys.stderr)
24+
print('-' * 70, file=sys.stderr)
25+
26+
def setUp(self):
27+
self.srv = TarantoolServer()
28+
self.srv.script = 'unit/suites/box_dbapi.lua'
29+
self.srv.start()
30+
self.driver = Connection(self.srv.host, self.srv.args['primary'])
31+
# prevent a remote tarantool from clean our session
32+
if self.srv.is_started():
33+
self.srv.touch_lock()
34+
35+
def tearDown(self):
36+
# self.driver.close()
37+
self.srv.stop()
38+
self.srv.clean()
39+
40+
@unittest.skip('Not implemented')
41+
def test_Binary(self):
42+
pass
43+
44+
@unittest.skip('Not implemented')
45+
def test_STRING(self):
46+
pass
47+
48+
@unittest.skip('Not implemented')
49+
def test_BINARY(self):
50+
pass
51+
52+
@unittest.skip('Not implemented')
53+
def test_NUMBER(self):
54+
pass
55+
56+
@unittest.skip('Not implemented')
57+
def test_DATETIME(self):
58+
pass
59+
60+
@unittest.skip('Not implemented')
61+
def test_ROWID(self):
62+
pass
63+
64+
@unittest.skip('Not implemented')
65+
def test_Date(self):
66+
pass
67+
68+
@unittest.skip('Not implemented')
69+
def test_Time(self):
70+
pass
71+
72+
@unittest.skip('Not implemented')
73+
def test_Timestamp(self):
74+
pass
75+
76+
@unittest.skip('Not implemented as optional.')
77+
def test_nextset(self):
78+
pass
79+
80+
@unittest.skip('To do')
81+
def test_callproc(self):
82+
pass
83+
84+
@unittest.skip('To do')
85+
def test_setoutputsize(self):
86+
pass
87+
88+
@unittest.skip('To do')
89+
def test_description(self):
90+
pass
91+
92+
@unittest.skip('To do')
93+
def test_close(self):
94+
pass

0 commit comments

Comments
 (0)