Skip to content

Commit 9781463

Browse files
committed
Fixed LOAD DATA LOCAL INFILE commands
* Fixes for both "buffered" & "unbuffered" cursor types * Registered `test_load_local` to run with tests * Refactored `test_load_local` tests to work with the tornado framework More information here: #29
1 parent 45a4eab commit 9781463

File tree

3 files changed

+38
-22
lines changed

3 files changed

+38
-22
lines changed

tornado_mysql/connections.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def is_ascii(data):
105105
def _scramble(password, message):
106106
if not password:
107107
return b'\0'
108-
if DEBUG: print('password=' + password)
108+
if DEBUG: print('password=' + str(password))
109109
stage1 = sha_new(password).digest()
110110
stage2 = sha_new(stage1).digest()
111111
s = sha_new()
@@ -1064,7 +1064,7 @@ def read(self):
10641064
if first_packet.is_ok_packet():
10651065
self._read_ok_packet(first_packet)
10661066
elif first_packet.is_load_local_packet():
1067-
self._read_load_local_packet(first_packet)
1067+
yield self._read_load_local_packet(first_packet)
10681068
else:
10691069
yield self._read_result_packet(first_packet)
10701070
finally:
@@ -1079,6 +1079,10 @@ def init_unbuffered_query(self):
10791079
self._read_ok_packet(first_packet)
10801080
self.unbuffered_active = False
10811081
self.connection = None
1082+
elif first_packet.is_load_local_packet():
1083+
yield self._read_load_local_packet(first_packet)
1084+
self.unbuffered_active = False
1085+
self.connection = None
10821086
else:
10831087
self.field_count = first_packet.read_length_encoded_integer()
10841088
yield self._get_descriptions()
@@ -1097,12 +1101,13 @@ def _read_ok_packet(self, first_packet):
10971101
self.message = ok_packet.message
10981102
self.has_next = ok_packet.has_next
10991103

1104+
@gen.coroutine
11001105
def _read_load_local_packet(self, first_packet):
11011106
load_packet = LoadLocalPacketWrapper(first_packet)
11021107
sender = LoadLocalFile(load_packet.filename, self.connection)
11031108
sender.send_data()
11041109

1105-
ok_packet = self.connection._read_packet()
1110+
ok_packet = yield self.connection._read_packet()
11061111
if not ok_packet.is_ok_packet():
11071112
raise OperationalError(2014, "Commands Out of Sync")
11081113
self._read_ok_packet(ok_packet)
@@ -1219,7 +1224,7 @@ def __init__(self, filename, connection):
12191224

12201225
def send_data(self):
12211226
"""Send data packets from the local file to the server"""
1222-
if not self.connection.socket:
1227+
if not self.connection._stream:
12231228
raise InterfaceError("(0, '')")
12241229

12251230
# sequence id is 2 as we already sent a query packet

tornado_mysql/tests/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from tornado_mysql.tests.test_issues import *
22
from tornado_mysql.tests.test_basic import *
3+
from tornado_mysql.tests.test_load_local import *
34
from tornado_mysql.tests.test_nextset import *
45
from tornado_mysql.tests.test_DictCursor import *
56
from tornado_mysql.tests.test_connection import TestConnection

tornado_mysql/tests/test_load_local.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from pymysql import OperationalError, Warning
2-
from pymysql.tests import base
1+
from tornado.testing import gen_test
2+
from tornado_mysql.err import OperationalError, Warning
3+
from tornado_mysql.tests import base
34

45
import os
56
import warnings
@@ -8,59 +9,68 @@
89

910

1011
class TestLoadLocal(base.PyMySQLTestCase):
12+
@gen_test
1113
def test_no_file(self):
1214
"""Test load local infile when the file does not exist"""
1315
conn = self.connections[0]
1416
c = conn.cursor()
15-
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
17+
with warnings.catch_warnings():
18+
warnings.simplefilter("ignore")
19+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
20+
yield c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
1621
try:
17-
self.assertRaises(
18-
OperationalError,
19-
c.execute,
20-
("LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
21-
"test_load_local fields terminated by ','")
22-
)
22+
with self.assertRaises(OperationalError) as cm:
23+
yield c.execute(
24+
"LOAD DATA LOCAL INFILE 'no_data.txt' INTO TABLE "
25+
"test_load_local fields terminated by ','")
2326
finally:
24-
c.execute("DROP TABLE test_load_local")
27+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
2528
c.close()
26-
29+
@gen_test
2730
def test_load_file(self):
2831
"""Test load local infile with a valid file"""
2932
conn = self.connections[0]
3033
c = conn.cursor()
31-
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
34+
with warnings.catch_warnings():
35+
warnings.simplefilter("ignore")
36+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
37+
yield c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
3238
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
3339
'data',
3440
'load_local_data.txt')
3541
try:
36-
c.execute(
42+
yield c.execute(
3743
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
3844
"test_load_local FIELDS TERMINATED BY ','").format(filename)
3945
)
40-
c.execute("SELECT COUNT(*) FROM test_load_local")
46+
yield c.execute("SELECT COUNT(*) FROM test_load_local")
4147
self.assertEqual(22749, c.fetchone()[0])
4248
finally:
43-
c.execute("DROP TABLE test_load_local")
49+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
4450

51+
@gen_test
4552
def test_load_warnings(self):
4653
"""Test load local infile produces the appropriate warnings"""
4754
conn = self.connections[0]
4855
c = conn.cursor()
49-
c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
56+
with warnings.catch_warnings():
57+
warnings.simplefilter("ignore")
58+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
59+
yield c.execute("CREATE TABLE test_load_local (a INTEGER, b INTEGER)")
5060
filename = os.path.join(os.path.dirname(os.path.realpath(__file__)),
5161
'data',
5262
'load_local_warn_data.txt')
5363
try:
5464
with warnings.catch_warnings(record=True) as w:
5565
warnings.simplefilter('always')
56-
c.execute(
66+
yield c.execute(
5767
("LOAD DATA LOCAL INFILE '{0}' INTO TABLE " +
5868
"test_load_local FIELDS TERMINATED BY ','").format(filename)
5969
)
6070
self.assertEqual(w[0].category, Warning)
6171
self.assertTrue("Incorrect integer value" in str(w[-1].message))
6272
finally:
63-
c.execute("DROP TABLE test_load_local")
73+
yield c.execute("DROP TABLE IF EXISTS test_load_local")
6474

6575

6676
if __name__ == "__main__":

0 commit comments

Comments
 (0)