Skip to content

Commit b162ddd

Browse files
authored
Fix Connection.escape() with Unicode input (#608)
After aed1dd2, Connection.escape() used ASCII to escape Unicode input. This commit makes it uses connection encoding instead.
1 parent 44d0f7a commit b162ddd

File tree

1 file changed

+43
-21
lines changed

1 file changed

+43
-21
lines changed

src/MySQLdb/_mysql.c

+43-21
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ _mysql_escape_string(
943943
{
944944
PyObject *str;
945945
char *in, *out;
946-
int len;
946+
unsigned long len;
947947
Py_ssize_t size;
948948
if (!PyArg_ParseTuple(args, "s#:escape_string", &in, &size)) return NULL;
949949
str = PyBytes_FromStringAndSize((char *) NULL, size*2+1);
@@ -980,35 +980,52 @@ _mysql_string_literal(
980980
_mysql_ConnectionObject *self,
981981
PyObject *o)
982982
{
983-
PyObject *str, *s;
984-
char *in, *out;
985-
unsigned long len;
986-
Py_ssize_t size;
983+
PyObject *s; // input string or bytes. need to decref.
987984

988985
if (self && PyModule_Check((PyObject*)self))
989986
self = NULL;
990987

991988
if (PyBytes_Check(o)) {
992989
s = o;
993990
Py_INCREF(s);
994-
} else {
995-
s = PyObject_Str(o);
996-
if (!s) return NULL;
997-
{
998-
PyObject *t = PyUnicode_AsASCIIString(s);
999-
Py_DECREF(s);
1000-
if (!t) return NULL;
991+
}
992+
else {
993+
PyObject *t = PyObject_Str(o);
994+
if (!t) return NULL;
995+
996+
const char *encoding = (self && self->open) ?
997+
_get_encoding(&self->connection) : utf8;
998+
if (encoding == utf8) {
1001999
s = t;
10021000
}
1001+
else {
1002+
s = PyUnicode_AsEncodedString(t, encoding, "strict");
1003+
Py_DECREF(t);
1004+
if (!s) return NULL;
1005+
}
10031006
}
1004-
in = PyBytes_AsString(s);
1005-
size = PyBytes_GET_SIZE(s);
1006-
str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);
1007+
1008+
// Prepare input string (in, size)
1009+
const char *in;
1010+
Py_ssize_t size;
1011+
if (PyUnicode_Check(s)) {
1012+
in = PyUnicode_AsUTF8AndSize(s, &size);
1013+
} else {
1014+
assert(PyBytes_Check(s));
1015+
in = PyBytes_AsString(s);
1016+
size = PyBytes_GET_SIZE(s);
1017+
}
1018+
1019+
// Prepare output buffer (str, out)
1020+
PyObject *str = PyBytes_FromStringAndSize((char *) NULL, size*2+3);
10071021
if (!str) {
10081022
Py_DECREF(s);
10091023
return PyErr_NoMemory();
10101024
}
1011-
out = PyBytes_AS_STRING(str);
1025+
char *out = PyBytes_AS_STRING(str);
1026+
1027+
// escape
1028+
unsigned long len;
10121029
if (self && self->open) {
10131030
#if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION) && !defined(MARIADB_VERSION_ID)
10141031
len = mysql_real_escape_string_quote(&(self->connection), out+1, in, size, '\'');
@@ -1018,10 +1035,14 @@ _mysql_string_literal(
10181035
} else {
10191036
len = mysql_escape_string(out+1, in, size);
10201037
}
1021-
*out = *(out+len+1) = '\'';
1022-
if (_PyBytes_Resize(&str, len+2) < 0) return NULL;
1038+
10231039
Py_DECREF(s);
1024-
return (str);
1040+
*out = *(out+len+1) = '\'';
1041+
if (_PyBytes_Resize(&str, len+2) < 0) {
1042+
Py_DECREF(str);
1043+
return NULL;
1044+
}
1045+
return str;
10251046
}
10261047

10271048
static PyObject *
@@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard(
14991520
// do nothing
15001521
}
15011522
Py_END_ALLOW_THREADS
1502-
if (mysql_errno(self->conn)) {
1503-
return _mysql_Exception(self->conn);
1523+
_mysql_ConnectionObject *conn = (_mysql_ConnectionObject *)self->conn;
1524+
if (mysql_errno(&conn->connection)) {
1525+
return _mysql_Exception(conn);
15041526
}
15051527
Py_RETURN_NONE;
15061528
}

0 commit comments

Comments
 (0)