From 88974c5e9a7d2eadec8f5389847db4dfae68402d Mon Sep 17 00:00:00 2001 From: Inada Naoki Date: Sat, 20 Oct 2018 11:55:54 +0900 Subject: [PATCH] Add missing checks for connection before calling mysql APIs Fixes #270 --- _mysql.c | 102 ++++++++++++++++++++++++++++++------------------------- 1 file changed, 55 insertions(+), 47 deletions(-) diff --git a/_mysql.c b/_mysql.c index a4f6976f..cea98203 100644 --- a/_mysql.c +++ b/_mysql.c @@ -69,9 +69,14 @@ typedef struct { PyObject *converter; } _mysql_ConnectionObject; -#define check_connection(c) if (!(c->open)) return _mysql_Exception(c) +#define check_connection(c, func) \ + if (!(c->open)) { \ + PyErr_SetString(_mysql_ProgrammingError, func "() is called for closed connection"); \ + return NULL; \ + }; + #define result_connection(r) ((_mysql_ConnectionObject *)r->conn) -#define check_result_connection(r) check_connection(result_connection(r)) +#define check_result_connection(r, func) check_connection(result_connection(r), func) extern PyTypeObject _mysql_ConnectionObject_Type; @@ -750,6 +755,7 @@ static PyObject * _mysql_ConnectionObject_fileno( _mysql_ConnectionObject *self) { + check_connection(self, "fileno"); return PyInt_FromLong(self->connection.net.fd); } @@ -761,16 +767,11 @@ _mysql_ConnectionObject_close( _mysql_ConnectionObject *self, PyObject *noargs) { - if (self->open) { - Py_BEGIN_ALLOW_THREADS - mysql_close(&(self->connection)); - Py_END_ALLOW_THREADS - self->open = 0; - } else { - PyErr_SetString(_mysql_ProgrammingError, - "closing a closed connection"); - return NULL; - } + check_connection(self, "close"); + Py_BEGIN_ALLOW_THREADS + mysql_close(&(self->connection)); + Py_END_ALLOW_THREADS + self->open = 0; _mysql_ConnectionObject_clear(self); Py_RETURN_NONE; } @@ -786,7 +787,7 @@ _mysql_ConnectionObject_affected_rows( PyObject *noargs) { my_ulonglong ret; - check_connection(self); + check_connection(self, "affected_rows"); ret = mysql_affected_rows(&(self->connection)); if (ret == (my_ulonglong)-1) return PyInt_FromLong(-1); @@ -823,7 +824,7 @@ _mysql_ConnectionObject_dump_debug_info( PyObject *noargs) { int err; - check_connection(self); + check_connection(self, "dump_debug_info"); Py_BEGIN_ALLOW_THREADS err = mysql_dump_debug_info(&(self->connection)); Py_END_ALLOW_THREADS @@ -842,6 +843,7 @@ _mysql_ConnectionObject_autocommit( { int flag, err; if (!PyArg_ParseTuple(args, "i", &flag)) return NULL; + check_connection(self, "autocommit"); Py_BEGIN_ALLOW_THREADS err = mysql_autocommit(&(self->connection), flag); Py_END_ALLOW_THREADS @@ -858,6 +860,7 @@ _mysql_ConnectionObject_get_autocommit( _mysql_ConnectionObject *self, PyObject *args) { + check_connection(self, "get_autocommit"); if (self->connection.server_status & SERVER_STATUS_AUTOCOMMIT) { Py_RETURN_TRUE; } @@ -873,6 +876,7 @@ _mysql_ConnectionObject_commit( PyObject *noargs) { int err; + check_connection(self, "commit"); Py_BEGIN_ALLOW_THREADS err = mysql_commit(&(self->connection)); Py_END_ALLOW_THREADS @@ -890,13 +894,13 @@ _mysql_ConnectionObject_rollback( PyObject *noargs) { int err; + check_connection(self, "rollback"); Py_BEGIN_ALLOW_THREADS err = mysql_rollback(&(self->connection)); Py_END_ALLOW_THREADS if (err) return _mysql_Exception(self); - Py_INCREF(Py_None); - return Py_None; -} + Py_RETURN_NONE; +} static char _mysql_ConnectionObject_next_result__doc__[] = "If more query results exist, next_result() reads the next query\n\ @@ -917,6 +921,7 @@ _mysql_ConnectionObject_next_result( PyObject *noargs) { int err; + check_connection(self, "next_result"); Py_BEGIN_ALLOW_THREADS err = mysql_next_result(&(self->connection)); Py_END_ALLOW_THREADS @@ -939,6 +944,7 @@ _mysql_ConnectionObject_set_server_option( int err, flags=0; if (!PyArg_ParseTuple(args, "i", &flags)) return NULL; + check_connection(self, "set_server_option"); Py_BEGIN_ALLOW_THREADS err = mysql_set_server_option(&(self->connection), flags); Py_END_ALLOW_THREADS @@ -963,6 +969,7 @@ _mysql_ConnectionObject_sqlstate( _mysql_ConnectionObject *self, PyObject *noargs) { + check_connection(self, "sqlstate"); return PyString_FromString(mysql_sqlstate(&(self->connection))); } @@ -977,6 +984,7 @@ _mysql_ConnectionObject_warning_count( _mysql_ConnectionObject *self, PyObject *noargs) { + check_connection(self, "warning_count"); return PyInt_FromLong(mysql_warning_count(&(self->connection))); } @@ -991,7 +999,7 @@ _mysql_ConnectionObject_errno( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "errno"); return PyInt_FromLong((long)mysql_errno(&(self->connection))); } @@ -1006,7 +1014,7 @@ _mysql_ConnectionObject_error( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "error"); return PyString_FromString(mysql_error(&(self->connection))); } @@ -1249,7 +1257,7 @@ _mysql_ResultObject_describe( PyObject *d; MYSQL_FIELD *fields; unsigned int i, n; - check_result_connection(self); + check_result_connection(self, "describe"); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -1284,7 +1292,7 @@ _mysql_ResultObject_field_flags( PyObject *d; MYSQL_FIELD *fields; unsigned int i, n; - check_result_connection(self); + check_result_connection(self, "field_flags"); n = mysql_num_fields(self->result); fields = mysql_fetch_fields(self->result); if (!(d = PyTuple_New(n))) return NULL; @@ -1523,7 +1531,7 @@ _mysql_ResultObject_fetch_row( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|ii:fetch_row", kwlist, &maxrows, &how)) return NULL; - check_result_connection(self); + check_result_connection(self, "fetch_row"); if (how >= (int)sizeof(row_converters)) { PyErr_SetString(PyExc_ValueError, "how out of range"); return NULL; @@ -1592,7 +1600,7 @@ _mysql_ConnectionObject_change_user( if (!PyArg_ParseTupleAndKeywords(args, kwargs, "s|ss:change_user", kwlist, &user, &pwd, &db)) return NULL; - check_connection(self); + check_connection(self, "change_user"); Py_BEGIN_ALLOW_THREADS r = mysql_change_user(&(self->connection), user, pwd, db); Py_END_ALLOW_THREADS @@ -1612,7 +1620,7 @@ _mysql_ConnectionObject_character_set_name( PyObject *noargs) { const char *s; - check_connection(self); + check_connection(self, "character_set_name"); s = mysql_character_set_name(&(self->connection)); return PyString_FromString(s); } @@ -1630,7 +1638,7 @@ _mysql_ConnectionObject_set_character_set( const char *s; int err; if (!PyArg_ParseTuple(args, "s", &s)) return NULL; - check_connection(self); + check_connection(self, "set_character_set"); Py_BEGIN_ALLOW_THREADS err = mysql_set_character_set(&(self->connection), s); Py_END_ALLOW_THREADS @@ -1669,7 +1677,7 @@ _mysql_ConnectionObject_get_character_set_info( PyObject *result; MY_CHARSET_INFO cs; - check_connection(self); + check_connection(self, "get_character_set_info"); mysql_get_character_set_info(&(self->connection), &cs); if (!(result = PyDict_New())) return NULL; if (cs.csname) @@ -1701,7 +1709,7 @@ _mysql_ConnectionObject_get_native_connection( PyObject *noargs) { PyObject *result; - check_connection(self); + check_connection(self, "_get_native_connection"); result = PyCapsule_New(&(self->connection), "_mysql.connection.native_connection", NULL); return result; @@ -1730,7 +1738,7 @@ _mysql_ConnectionObject_get_host_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "get_host_info"); return PyString_FromString(mysql_get_host_info(&(self->connection))); } @@ -1744,7 +1752,7 @@ _mysql_ConnectionObject_get_proto_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "get_proto_info"); return PyInt_FromLong((long)mysql_get_proto_info(&(self->connection))); } @@ -1758,7 +1766,7 @@ _mysql_ConnectionObject_get_server_info( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "get_server_info"); return PyString_FromString(mysql_get_server_info(&(self->connection))); } @@ -1774,7 +1782,7 @@ _mysql_ConnectionObject_info( PyObject *noargs) { const char *s; - check_connection(self); + check_connection(self, "info"); s = mysql_info(&(self->connection)); if (s) return PyString_FromString(s); Py_INCREF(Py_None); @@ -1808,7 +1816,7 @@ _mysql_ConnectionObject_insert_id( PyObject *noargs) { my_ulonglong r; - check_connection(self); + check_connection(self, "insert_id"); Py_BEGIN_ALLOW_THREADS r = mysql_insert_id(&(self->connection)); Py_END_ALLOW_THREADS @@ -1827,7 +1835,7 @@ _mysql_ConnectionObject_kill( unsigned long pid; int r; if (!PyArg_ParseTuple(args, "k:kill", &pid)) return NULL; - check_connection(self); + check_connection(self, "kill"); Py_BEGIN_ALLOW_THREADS r = mysql_kill(&(self->connection), pid); Py_END_ALLOW_THREADS @@ -1847,7 +1855,7 @@ _mysql_ConnectionObject_field_count( _mysql_ConnectionObject *self, PyObject *noargs) { - check_connection(self); + check_connection(self, "field_count"); return PyInt_FromLong((long)mysql_field_count(&(self->connection))); } @@ -1859,7 +1867,7 @@ _mysql_ResultObject_num_fields( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); + check_result_connection(self, "num_fields"); return PyInt_FromLong((long)mysql_num_fields(self->result)); } @@ -1874,7 +1882,7 @@ _mysql_ResultObject_num_rows( _mysql_ResultObject *self, PyObject *noargs) { - check_result_connection(self); + check_result_connection(self, "num_rows"); return PyLong_FromUnsignedLongLong(mysql_num_rows(self->result)); } @@ -1904,7 +1912,7 @@ _mysql_ConnectionObject_ping( { int r, reconnect = -1; if (!PyArg_ParseTuple(args, "|I", &reconnect)) return NULL; - check_connection(self); + check_connection(self, "ping"); if (reconnect != -1) { my_bool recon = (my_bool)reconnect; mysql_options(&self->connection, MYSQL_OPT_RECONNECT, &recon); @@ -1931,7 +1939,7 @@ _mysql_ConnectionObject_query( char *query; int len, r; if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + check_connection(self, "query"); Py_BEGIN_ALLOW_THREADS r = mysql_real_query(&(self->connection), query, len); @@ -1955,7 +1963,7 @@ _mysql_ConnectionObject_send_query( int len, r; MYSQL *mysql = &(self->connection); if (!PyArg_ParseTuple(args, "s#:query", &query, &len)) return NULL; - check_connection(self); + check_connection(self, "send_query"); Py_BEGIN_ALLOW_THREADS r = mysql_send_query(mysql, query, len); @@ -1976,7 +1984,7 @@ _mysql_ConnectionObject_read_query_result( { int r; MYSQL *mysql = &(self->connection); - check_connection(self); + check_connection(self, "reqd_query_result"); Py_BEGIN_ALLOW_THREADS r = (int)mysql_read_query_result(mysql); @@ -2006,7 +2014,7 @@ _mysql_ConnectionObject_select_db( char *db; int r; if (!PyArg_ParseTuple(args, "s:select_db", &db)) return NULL; - check_connection(self); + check_connection(self, "select_db"); Py_BEGIN_ALLOW_THREADS r = mysql_select_db(&(self->connection), db); Py_END_ALLOW_THREADS @@ -2026,7 +2034,7 @@ _mysql_ConnectionObject_shutdown( PyObject *noargs) { int r; - check_connection(self); + check_connection(self, "shutdown"); Py_BEGIN_ALLOW_THREADS r = mysql_shutdown(&(self->connection), SHUTDOWN_DEFAULT); Py_END_ALLOW_THREADS @@ -2048,7 +2056,7 @@ _mysql_ConnectionObject_stat( PyObject *noargs) { const char *s; - check_connection(self); + check_connection(self, "stat"); Py_BEGIN_ALLOW_THREADS s = mysql_stat(&(self->connection)); Py_END_ALLOW_THREADS @@ -2070,7 +2078,7 @@ _mysql_ConnectionObject_store_result( PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + check_connection(self, "store_result"); arglist = Py_BuildValue("(OiO)", self, 0, self->converter); if (!arglist) goto error; kwarglist = PyDict_New(); @@ -2108,7 +2116,7 @@ _mysql_ConnectionObject_thread_id( PyObject *noargs) { unsigned long pid; - check_connection(self); + check_connection(self, "thread_id"); Py_BEGIN_ALLOW_THREADS pid = mysql_thread_id(&(self->connection)); Py_END_ALLOW_THREADS @@ -2129,7 +2137,7 @@ _mysql_ConnectionObject_use_result( PyObject *arglist=NULL, *kwarglist=NULL, *result=NULL; _mysql_ResultObject *r=NULL; - check_connection(self); + check_connection(self, "use_result"); arglist = Py_BuildValue("(OiO)", self, 1, self->converter); if (!arglist) return NULL; kwarglist = PyDict_New(); @@ -2187,7 +2195,7 @@ _mysql_ResultObject_data_seek( { unsigned int row; if (!PyArg_ParseTuple(args, "i:data_seek", &row)) return NULL; - check_result_connection(self); + check_result_connection(self, "data_seek"); mysql_data_seek(self->result, row); Py_INCREF(Py_None); return Py_None;