diff --git a/MySQLdb/_mysql.c b/MySQLdb/_mysql.c index b45682df..6c12b923 100644 --- a/MySQLdb/_mysql.c +++ b/MySQLdb/_mysql.c @@ -915,14 +915,15 @@ _mysql.string_literal(obj) cannot handle character sets."; static PyObject * _mysql_string_literal( _mysql_ConnectionObject *self, - PyObject *args) + PyObject *o) { - PyObject *str, *s, *o, *d; + PyObject *str, *s; char *in, *out; int len, size; + if (self && PyModule_Check((PyObject*)self)) self = NULL; - if (!PyArg_ParseTuple(args, "O|O:string_literal", &o, &d)) return NULL; + if (PyBytes_Check(o)) { s = o; Py_INCREF(s); @@ -965,33 +966,25 @@ static PyObject *_mysql_NULL; static PyObject * _escape_item( + PyObject *self, PyObject *item, PyObject *d) { PyObject *quoted=NULL, *itemtype, *itemconv; - if (!(itemtype = PyObject_Type(item))) - goto error; + if (!(itemtype = PyObject_Type(item))) { + return NULL; + } itemconv = PyObject_GetItem(d, itemtype); Py_DECREF(itemtype); if (!itemconv) { PyErr_Clear(); - itemconv = PyObject_GetItem(d, -#ifdef IS_PY3K - (PyObject *) &PyUnicode_Type); -#else - (PyObject *) &PyString_Type); -#endif - } - if (!itemconv) { - PyErr_SetString(PyExc_TypeError, - "no default type converter defined"); - goto error; + return _mysql_string_literal((_mysql_ConnectionObject*)self, item); } Py_INCREF(d); quoted = PyObject_CallFunction(itemconv, "OO", item, d); Py_DECREF(d); Py_DECREF(itemconv); -error: + return quoted; } @@ -1013,14 +1006,14 @@ _mysql_escape( "argument 2 must be a mapping"); return NULL; } - return _escape_item(o, d); + return _escape_item(self, o, d); } else { if (!self) { PyErr_SetString(PyExc_TypeError, "argument 2 must be a mapping"); return NULL; } - return _escape_item(o, + return _escape_item(self, o, ((_mysql_ConnectionObject *) self)->converter); } } @@ -2264,7 +2257,7 @@ static PyMethodDef _mysql_ConnectionObject_methods[] = { { "string_literal", (PyCFunction)_mysql_string_literal, - METH_VARARGS, + METH_O, _mysql_string_literal__doc__}, { "thread_id", @@ -2587,7 +2580,7 @@ _mysql_methods[] = { { "string_literal", (PyCFunction)_mysql_string_literal, - METH_VARARGS, + METH_O, _mysql_string_literal__doc__ }, { diff --git a/MySQLdb/connections.py b/MySQLdb/connections.py index 96b01528..27a8d437 100644 --- a/MySQLdb/connections.py +++ b/MySQLdb/connections.py @@ -16,18 +16,6 @@ ) -if not PY2: - if sys.version_info[:2] < (3, 6): - # See http://bugs.python.org/issue24870 - _surrogateescape_table = [chr(i) if i < 0x80 else chr(i + 0xdc00) for i in range(256)] - - def _fast_surrogateescape(s): - return s.decode('latin1').translate(_surrogateescape_table) - else: - def _fast_surrogateescape(s): - return s.decode('ascii', 'surrogateescape') - - re_numeric_part = re.compile(r"^(\d+)") def numeric_part(s): @@ -183,21 +171,8 @@ class object, used to create cursors (keyword only) self.encoding = 'ascii' # overridden in set_character_set() db = proxy(self) - # Note: string_literal() is called for bytes object on Python 3 (via bytes_literal) - def string_literal(obj, dummy=None): - return db.string_literal(obj) - - if PY2: - # unicode_literal is called for only unicode object. - def unicode_literal(u, dummy=None): - return db.string_literal(u.encode(db.encoding)) - else: - # unicode_literal() is called for arbitrary object. - def unicode_literal(u, dummy=None): - return db.string_literal(str(u).encode(db.encoding)) - - def bytes_literal(obj, dummy=None): - return b'_binary' + db.string_literal(obj) + def unicode_literal(u, dummy=None): + return db.string_literal(u.encode(db.encoding)) def string_decoder(s): return s.decode(db.encoding) @@ -214,7 +189,6 @@ def string_decoder(s): FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.LONG_BLOB, FIELD_TYPE.BLOB): self.converter[t].append((None, string_decoder)) - self.encoders[bytes] = string_literal self.encoders[unicode] = unicode_literal self._transactional = self.server_capabilities & CLIENT.TRANSACTIONS if self._transactional: @@ -250,7 +224,7 @@ def _bytes_literal(self, bs): return x def _tuple_literal(self, t): - return "(%s)" % (','.join(map(self.literal, t))) + return b"(%s)" % (b','.join(map(self.literal, t))) def literal(self, o): """If o is a single object, returns an SQL literal as a string. @@ -260,7 +234,9 @@ def literal(self, o): Non-standard. For internal use; do not use this in your applications. """ - if isinstance(o, bytearray): + if isinstance(o, unicode): + s = self.string_literal(o.encode(self.encoding)) + elif isinstance(o, bytearray): s = self._bytes_literal(o) elif not PY2 and isinstance(o, bytes): s = self._bytes_literal(o) @@ -268,13 +244,9 @@ def literal(self, o): s = self._tuple_literal(o) else: s = self.escape(o, self.encoders) - # Python 3(~3.4) doesn't support % operation for bytes object. - # We should decode it before using %. - # Decoding with ascii and surrogateescape allows convert arbitrary - # bytes to unicode and back again. - # See http://python.org/dev/peps/pep-0383/ - if not PY2 and isinstance(s, (bytes, bytearray)): - return _fast_surrogateescape(s) + if isinstance(s, unicode): + s = s.encode(self.encoding) + assert isinstance(s, bytes) return s def begin(self): @@ -282,7 +254,7 @@ def begin(self): This method is not used when autocommit=False (default). """ - self.query("BEGIN") + self.query(b"BEGIN") if not hasattr(_mysql.connection, 'warning_count'): diff --git a/MySQLdb/converters.py b/MySQLdb/converters.py index c13e4265..20d919f7 100644 --- a/MySQLdb/converters.py +++ b/MySQLdb/converters.py @@ -53,7 +53,7 @@ def Str2Set(s): def Set2Str(s, d): # Only support ascii string. Not tested. - return string_literal(','.join(s), d) + return string_literal(','.join(s)) def Thing2Str(s, d): """Convert something into a string via str().""" @@ -80,7 +80,7 @@ def Thing2Literal(o, d): MySQL-3.23 or newer, string_literal() is a method of the _mysql.MYSQL object, and this function will be overridden with that method when the connection is created.""" - return string_literal(o, d) + return string_literal(o) def Decimal2Literal(o, d): return format(o, 'f') diff --git a/MySQLdb/cursors.py b/MySQLdb/cursors.py index 97908249..9a5e76f9 100644 --- a/MySQLdb/cursors.py +++ b/MySQLdb/cursors.py @@ -15,13 +15,6 @@ NotSupportedError, ProgrammingError) -PY2 = sys.version_info[0] == 2 -if PY2: - text_type = unicode -else: - text_type = str - - #: Regular expression for :meth:`Cursor.executemany`. #: executemany only supports simple bulk insert. #: You can use it to load large dataset. @@ -95,31 +88,28 @@ def __exit__(self, *exc_info): del exc_info self.close() - def _ensure_bytes(self, x, encoding=None): - if isinstance(x, text_type): - x = x.encode(encoding) - elif isinstance(x, (tuple, list)): - x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x) - return x - def _escape_args(self, args, conn): - ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) + encoding = conn.encoding + literal = conn.literal + + def ensure_bytes(x): + if isinstance(x, unicode): + return x.encode(encoding) + elif isinstance(x, tuple): + return tuple(map(ensure_bytes, x)) + elif isinstance(x, list): + return list(map(ensure_bytes, x)) + return x if isinstance(args, (tuple, list)): - if PY2: - args = tuple(map(ensure_bytes, args)) - return tuple(conn.literal(arg) for arg in args) + return tuple(literal(ensure_bytes(arg)) for arg in args) elif isinstance(args, dict): - if PY2: - args = dict((ensure_bytes(key), ensure_bytes(val)) for - (key, val) in args.items()) - return dict((key, conn.literal(val)) for (key, val) in args.items()) + return {ensure_bytes(key): literal(ensure_bytes(val)) + for (key, val) in args.items()} else: # If it's not a dictionary let's try escaping it anyways. # Worst case it will throw a Value error - if PY2: - args = ensure_bytes(args) - return conn.literal(args) + return literal(ensure_bytes(args)) def _check_executed(self): if not self._executed: @@ -186,14 +176,7 @@ def execute(self, query, args=None): pass db = self._get_db() - # NOTE: - # Python 2: query should be bytes when executing %. - # All unicode in args should be encoded to bytes on Python 2. - # Python 3: query should be str (unicode) when executing %. - # All bytes in args should be decoded with ascii and surrogateescape on Python 3. - # db.literal(obj) always returns str. - - if PY2 and isinstance(query, unicode): + if isinstance(query, unicode): query = query.encode(db.encoding) if args is not None: @@ -201,16 +184,12 @@ def execute(self, query, args=None): args = dict((key, db.literal(item)) for key, item in args.items()) else: args = tuple(map(db.literal, args)) - if not PY2 and isinstance(query, (bytes, bytearray)): - query = query.decode(db.encoding) try: query = query % args except TypeError as m: raise ProgrammingError(str(m)) - if isinstance(query, unicode): - query = query.encode(db.encoding, 'surrogateescape') - + assert isinstance(query, (bytes, bytearray)) res = self._query(query) return res @@ -247,29 +226,19 @@ def executemany(self, query, args): def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): conn = self._get_db() escape = self._escape_args - if isinstance(prefix, text_type): + if isinstance(prefix, unicode): prefix = prefix.encode(encoding) - if PY2 and isinstance(values, text_type): + if isinstance(values, unicode): values = values.encode(encoding) - if isinstance(postfix, text_type): + if isinstance(postfix, unicode): postfix = postfix.encode(encoding) sql = bytearray(prefix) args = iter(args) v = values % escape(next(args), conn) - if isinstance(v, text_type): - if PY2: - v = v.encode(encoding) - else: - v = v.encode(encoding, 'surrogateescape') sql += v rows = 0 for arg in args: v = values % escape(arg, conn) - if isinstance(v, text_type): - if PY2: - v = v.encode(encoding) - else: - v = v.encode(encoding, 'surrogateescape') if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: rows += self.execute(sql + postfix) sql = bytearray(prefix) @@ -308,22 +277,19 @@ def callproc(self, procname, args=()): to advance through all result sets; otherwise you may get disconnected. """ - db = self._get_db() + if isinstance(procname, unicode): + procname = procname.encode(db.encoding) if args: - fmt = '@_{0}_%d=%s'.format(procname) - q = 'SET %s' % ','.join(fmt % (index, db.literal(arg)) - for index, arg in enumerate(args)) - if isinstance(q, unicode): - q = q.encode(db.encoding, 'surrogateescape') + fmt = b'@_' + procname + b'_%d=%s' + q = b'SET %s' % b','.join(fmt % (index, db.literal(arg)) + for index, arg in enumerate(args)) self._query(q) self.nextset() - q = "CALL %s(%s)" % (procname, - ','.join(['@_%s_%d' % (procname, i) - for i in range(len(args))])) - if isinstance(q, unicode): - q = q.encode(db.encoding, 'surrogateescape') + q = b"CALL %s(%s)" % (procname, + b','.join([b'@_%s_%d' % (procname, i) + for i in range(len(args))])) self._query(q) return args diff --git a/MySQLdb/times.py b/MySQLdb/times.py index 510d1c7c..a7eaa53b 100644 --- a/MySQLdb/times.py +++ b/MySQLdb/times.py @@ -124,11 +124,11 @@ def Date_or_None(s): def DateTime2literal(d, c): """Format a DateTime object as an ISO timestamp.""" - return string_literal(format_TIMESTAMP(d), c) + return string_literal(format_TIMESTAMP(d)) def DateTimeDelta2literal(d, c): """Format a DateTimeDelta object as a time.""" - return string_literal(format_TIMEDELTA(d),c) + return string_literal(format_TIMEDELTA(d)) def mysql_timestamp_converter(s): """Convert a MySQL TIMESTAMP to a Timestamp object."""