@@ -943,7 +943,7 @@ _mysql_escape_string(
943
943
{
944
944
PyObject * str ;
945
945
char * in , * out ;
946
- int len ;
946
+ unsigned long len ;
947
947
Py_ssize_t size ;
948
948
if (!PyArg_ParseTuple (args , "s#:escape_string" , & in , & size )) return NULL ;
949
949
str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 1 );
@@ -980,35 +980,52 @@ _mysql_string_literal(
980
980
_mysql_ConnectionObject * self ,
981
981
PyObject * o )
982
982
{
983
- PyObject * str , * s ;
984
- char * in , * out ;
985
- unsigned long len ;
986
- Py_ssize_t size ;
983
+ PyObject * s ; // input string or bytes. need to decref.
987
984
988
985
if (self && PyModule_Check ((PyObject * )self ))
989
986
self = NULL ;
990
987
991
988
if (PyBytes_Check (o )) {
992
989
s = o ;
993
990
Py_INCREF (s );
994
- } else {
995
- s = PyObject_Str (o );
996
- if (!s ) return NULL ;
997
- {
998
- PyObject * t = PyUnicode_AsASCIIString (s );
999
- Py_DECREF (s );
1000
- if (!t ) return NULL ;
991
+ }
992
+ else {
993
+ PyObject * t = PyObject_Str (o );
994
+ if (!t ) return NULL ;
995
+
996
+ const char * encoding = (self && self -> open ) ?
997
+ _get_encoding (& self -> connection ) : utf8 ;
998
+ if (encoding == utf8 ) {
1001
999
s = t ;
1002
1000
}
1001
+ else {
1002
+ s = PyUnicode_AsEncodedString (t , encoding , "strict" );
1003
+ Py_DECREF (t );
1004
+ if (!s ) return NULL ;
1005
+ }
1003
1006
}
1004
- in = PyBytes_AsString (s );
1005
- size = PyBytes_GET_SIZE (s );
1006
- str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 3 );
1007
+
1008
+ // Prepare input string (in, size)
1009
+ const char * in ;
1010
+ Py_ssize_t size ;
1011
+ if (PyUnicode_Check (s )) {
1012
+ in = PyUnicode_AsUTF8AndSize (s , & size );
1013
+ } else {
1014
+ assert (PyBytes_Check (s ));
1015
+ in = PyBytes_AsString (s );
1016
+ size = PyBytes_GET_SIZE (s );
1017
+ }
1018
+
1019
+ // Prepare output buffer (str, out)
1020
+ PyObject * str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 3 );
1007
1021
if (!str ) {
1008
1022
Py_DECREF (s );
1009
1023
return PyErr_NoMemory ();
1010
1024
}
1011
- out = PyBytes_AS_STRING (str );
1025
+ char * out = PyBytes_AS_STRING (str );
1026
+
1027
+ // escape
1028
+ unsigned long len ;
1012
1029
if (self && self -> open ) {
1013
1030
#if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION ) && !defined(MARIADB_VERSION_ID )
1014
1031
len = mysql_real_escape_string_quote (& (self -> connection ), out + 1 , in , size , '\'' );
@@ -1018,10 +1035,14 @@ _mysql_string_literal(
1018
1035
} else {
1019
1036
len = mysql_escape_string (out + 1 , in , size );
1020
1037
}
1021
- * out = * (out + len + 1 ) = '\'' ;
1022
- if (_PyBytes_Resize (& str , len + 2 ) < 0 ) return NULL ;
1038
+
1023
1039
Py_DECREF (s );
1024
- return (str );
1040
+ * out = * (out + len + 1 ) = '\'' ;
1041
+ if (_PyBytes_Resize (& str , len + 2 ) < 0 ) {
1042
+ Py_DECREF (str );
1043
+ return NULL ;
1044
+ }
1045
+ return str ;
1025
1046
}
1026
1047
1027
1048
static PyObject *
@@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard(
1499
1520
// do nothing
1500
1521
}
1501
1522
Py_END_ALLOW_THREADS
1502
- if (mysql_errno (self -> conn )) {
1503
- return _mysql_Exception (self -> conn );
1523
+ _mysql_ConnectionObject * conn = (_mysql_ConnectionObject * )self -> conn ;
1524
+ if (mysql_errno (& conn -> connection )) {
1525
+ return _mysql_Exception (conn );
1504
1526
}
1505
1527
Py_RETURN_NONE ;
1506
1528
}
0 commit comments