Skip to content

Stop using surrogate escape. #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions MySQLdb/_mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -915,14 +915,15 @@ _mysql.string_literal(obj) cannot handle character sets.";
static PyObject *
_mysql_string_literal(
_mysql_ConnectionObject *self,
PyObject *args)
PyObject *o)
{
PyObject *str, *s, *o, *d;
PyObject *str, *s;
char *in, *out;
int len, size;

if (self && PyModule_Check((PyObject*)self))
self = NULL;
if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL;

if (PyBytes_Check(o)) {
s = o;
Py_INCREF(s);
Expand Down Expand Up @@ -965,33 +966,25 @@ static PyObject *_mysql_NULL;

static PyObject *
_escape_item(
PyObject *self,
PyObject *item,
PyObject *d)
{
PyObject *quoted=NULL, *itemtype, *itemconv;
if (!(itemtype = PyObject_Type(item)))
goto error;
if (!(itemtype = PyObject_Type(item))) {
return NULL;
}
itemconv = PyObject_GetItem(d, itemtype);
Py_DECREF(itemtype);
if (!itemconv) {
PyErr_Clear();
itemconv = PyObject_GetItem(d,
#ifdef IS_PY3K
(PyObject *) &PyUnicode_Type);
#else
(PyObject *) &PyString_Type);
#endif
}
if (!itemconv) {
PyErr_SetString(PyExc_TypeError,
"no default type converter defined");
goto error;
return _mysql_string_literal((_mysql_ConnectionObject*)self, item);
}
Py_INCREF(d);
quoted = PyObject_CallFunction(itemconv, "OO", item, d);
Py_DECREF(d);
Py_DECREF(itemconv);
error:

return quoted;
}

Expand All @@ -1013,14 +1006,14 @@ _mysql_escape(
"argument 2 must be a mapping");
return NULL;
}
return _escape_item(o, d);
return _escape_item(self, o, d);
} else {
if (!self) {
PyErr_SetString(PyExc_TypeError,
"argument 2 must be a mapping");
return NULL;
}
return _escape_item(o,
return _escape_item(self, o,
((_mysql_ConnectionObject *) self)->converter);
}
}
Expand Down Expand Up @@ -2264,7 +2257,7 @@ static PyMethodDef _mysql_ConnectionObject_methods[] = {
{
"string_literal",
(PyCFunction)_mysql_string_literal,
METH_VARARGS,
METH_O,
_mysql_string_literal__doc__},
{
"thread_id",
Expand Down Expand Up @@ -2587,7 +2580,7 @@ _mysql_methods[] = {
{
"string_literal",
(PyCFunction)_mysql_string_literal,
METH_VARARGS,
METH_O,
_mysql_string_literal__doc__
},
{
Expand Down
48 changes: 10 additions & 38 deletions MySQLdb/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,6 @@
)


if not PY2:
if sys.version_info[:2] < (3, 6):
# See http://bugs.python.org/issue24870
_surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)]

def _fast_surrogateescape(s):
return s.decode('latin1').translate(_surrogateescape_table)
else:
def _fast_surrogateescape(s):
return s.decode('ascii', 'surrogateescape')


re_numeric_part = re.compile(r"^(\d+)")

def numeric_part(s):
Expand Down Expand Up @@ -183,21 +171,8 @@ class object, used to create cursors (keyword only)
self.encoding = 'ascii' # overridden in set_character_set()
db = proxy(self)

# Note: string_literal() is called for bytes object on Python 3 (via bytes_literal)
def string_literal(obj, dummy=None):
return db.string_literal(obj)

if PY2:
# unicode_literal is called for only unicode object.
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(db.encoding))
else:
# unicode_literal() is called for arbitrary object.
def unicode_literal(u, dummy=None):
return db.string_literal(str(u).encode(db.encoding))

def bytes_literal(obj, dummy=None):
return b'_binary' + db.string_literal(obj)
def unicode_literal(u, dummy=None):
return db.string_literal(u.encode(db.encoding))

def string_decoder(s):
return s.decode(db.encoding)
Expand All @@ -214,7 +189,6 @@ def string_decoder(s):
FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB):
self.converter[t].append((None, string_decoder))

self.encoders[bytes] = string_literal
self.encoders[unicode] = unicode_literal
self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS
if self._transactional:
Expand Down Expand Up @@ -250,7 +224,7 @@ def _bytes_literal(self, bs):
return x

def _tuple_literal(self, t):
return "(%s)" % (','.join(map(self.literal, t)))
return b"(%s)" % (b','.join(map(self.literal, t)))

def literal(self, o):
"""If o is a single object, returns an SQL literal as a string.
Expand All @@ -260,29 +234,27 @@ def literal(self, o):
Non-standard. For internal use; do not use this in your
applications.
"""
if isinstance(o, bytearray):
if isinstance(o, unicode):
s = self.string_literal(o.encode(self.encoding))
elif isinstance(o, bytearray):
s = self._bytes_literal(o)
elif not PY2 and isinstance(o, bytes):
s = self._bytes_literal(o)
elif isinstance(o, (tuple, list)):
s = self._tuple_literal(o)
else:
s = self.escape(o, self.encoders)
# Python 3(~3.4) doesn't support % operation for bytes object.
# We should decode it before using %.
# Decoding with ascii and surrogateescape allows convert arbitrary
# bytes to unicode and back again.
# See http://python.org/dev/peps/pep-0383/
if not PY2 and isinstance(s, (bytes, bytearray)):
return _fast_surrogateescape(s)
if isinstance(s, unicode):
s = s.encode(self.encoding)
assert isinstance(s, bytes)
return s

def begin(self):
"""Explicitly begin a connection.

This method is not used when autocommit=False (default).
"""
self.query("BEGIN")
self.query(b"BEGIN")

if not hasattr(_mysql.connection, 'warning_count'):

Expand Down
4 changes: 2 additions & 2 deletions MySQLdb/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def Str2Set(s):

def Set2Str(s, d):
# Only support ascii string. Not tested.
return string_literal(','.join(s), d)
return string_literal(','.join(s))

def Thing2Str(s, d):
"""Convert something into a string via str()."""
Expand All @@ -80,7 +80,7 @@ def Thing2Literal(o, d):
MySQL-3.23 or newer, string_literal() is a method of the
_mysql.MYSQL object, and this function will be overridden with
that method when the connection is created."""
return string_literal(o, d)
return string_literal(o)

def Decimal2Literal(o, d):
return format(o, 'f')
Expand Down
90 changes: 28 additions & 62 deletions MySQLdb/cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@
NotSupportedError, ProgrammingError)


PY2 = sys.version_info[0] == 2
if PY2:
text_type = unicode
else:
text_type = str


#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
Expand Down Expand Up @@ -95,31 +88,28 @@ def __exit__(self, *exc_info):
del exc_info
self.close()

def _ensure_bytes(self, x, encoding=None):
if isinstance(x, text_type):
x = x.encode(encoding)
elif isinstance(x, (tuple, list)):
x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
return x

def _escape_args(self, args, conn):
ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
encoding = conn.encoding
literal = conn.literal

def ensure_bytes(x):
if isinstance(x, unicode):
return x.encode(encoding)
elif isinstance(x, tuple):
return tuple(map(ensure_bytes, x))
elif isinstance(x, list):
return list(map(ensure_bytes, x))
return x

if isinstance(args, (tuple, list)):
if PY2:
args = tuple(map(ensure_bytes, args))
return tuple(conn.literal(arg) for arg in args)
return tuple(literal(ensure_bytes(arg)) for arg in args)
elif isinstance(args, dict):
if PY2:
args = dict((ensure_bytes(key), ensure_bytes(val)) for
(key, val) in args.items())
return dict((key, conn.literal(val)) for (key, val) in args.items())
return {ensure_bytes(key): literal(ensure_bytes(val))
for (key, val) in args.items()}
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
if PY2:
args = ensure_bytes(args)
return conn.literal(args)
return literal(ensure_bytes(args))

def _check_executed(self):
if not self._executed:
Expand Down Expand Up @@ -186,31 +176,20 @@ def execute(self, query, args=None):
pass
db = self._get_db()

# NOTE:
# Python 2: query should be bytes when executing %.
# All unicode in args should be encoded to bytes on Python 2.
# Python 3: query should be str (unicode) when executing %.
# All bytes in args should be decoded with ascii and surrogateescape on Python 3.
# db.literal(obj) always returns str.

if PY2 and isinstance(query, unicode):
if isinstance(query, unicode):
query = query.encode(db.encoding)

if args is not None:
if isinstance(args, dict):
args = dict((key, db.literal(item)) for key, item in args.items())
else:
args = tuple(map(db.literal, args))
if not PY2 and isinstance(query, (bytes, bytearray)):
query = query.decode(db.encoding)
try:
query = query % args
except TypeError as m:
raise ProgrammingError(str(m))

if isinstance(query, unicode):
query = query.encode(db.encoding, 'surrogateescape')

assert isinstance(query, (bytes, bytearray))
res = self._query(query)
return res

Expand Down Expand Up @@ -247,29 +226,19 @@ def executemany(self, query, args):
def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
conn = self._get_db()
escape = self._escape_args
if isinstance(prefix, text_type):
if isinstance(prefix, unicode):
prefix = prefix.encode(encoding)
if PY2 and isinstance(values, text_type):
if isinstance(values, unicode):
values = values.encode(encoding)
if isinstance(postfix, text_type):
if isinstance(postfix, unicode):
postfix = postfix.encode(encoding)
sql = bytearray(prefix)
args = iter(args)
v = values % escape(next(args), conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
v = v.encode(encoding, 'surrogateescape')
sql += v
rows = 0
for arg in args:
v = values % escape(arg, conn)
if isinstance(v, text_type):
if PY2:
v = v.encode(encoding)
else:
v = v.encode(encoding, 'surrogateescape')
if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
rows += self.execute(sql + postfix)
sql = bytearray(prefix)
Expand Down Expand Up @@ -308,22 +277,19 @@ def callproc(self, procname, args=()):
to advance through all result sets; otherwise you may get
disconnected.
"""

db = self._get_db()
if isinstance(procname, unicode):
procname = procname.encode(db.encoding)
if args:
fmt = '@_{0}_%d=%s'.format(procname)
q = 'SET %s' % ','.join(fmt % (index, db.literal(arg))
for index, arg in enumerate(args))
if isinstance(q, unicode):
q = q.encode(db.encoding, 'surrogateescape')
fmt = b'@_' + procname + b'_%d=%s'
q = b'SET %s' % b','.join(fmt % (index, db.literal(arg))
for index, arg in enumerate(args))
self._query(q)
self.nextset()

q = "CALL %s(%s)" % (procname,
','.join(['@_%s_%d' % (procname, i)
for i in range(len(args))]))
if isinstance(q, unicode):
q = q.encode(db.encoding, 'surrogateescape')
q = b"CALL %s(%s)" % (procname,
b','.join([b'@_%s_%d' % (procname, i)
for i in range(len(args))]))
self._query(q)
return args

Expand Down
4 changes: 2 additions & 2 deletions MySQLdb/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,11 @@ def Date_or_None(s):

def DateTime2literal(d, c):
"""Format a DateTime object as an ISO timestamp."""
return string_literal(format_TIMESTAMP(d), c)
return string_literal(format_TIMESTAMP(d))

def DateTimeDelta2literal(d, c):
"""Format a DateTimeDelta object as a time."""
return string_literal(format_TIMEDELTA(d),c)
return string_literal(format_TIMEDELTA(d))

def mysql_timestamp_converter(s):
"""Convert a MySQL TIMESTAMP to a Timestamp object."""
Expand Down