diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index d35334e5..c1fa6567 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -19,7 +19,7 @@ cdef enum ProtocolState: PROTOCOL_CANCELLED = 3 PROTOCOL_AUTH = 10 - PROTOCOL_PREPARE = 11 + PROTOCOL_PARSE_DESCRIBE = 11 PROTOCOL_BIND_EXECUTE = 12 PROTOCOL_BIND_EXECUTE_MANY = 13 PROTOCOL_CLOSE_STMT_PORTAL = 14 @@ -105,7 +105,7 @@ cdef class CoreProtocol: bint result_execute_completed cdef _process__auth(self, char mtype) - cdef _process__prepare(self, char mtype) + cdef _process__parse_describe(self, char mtype) cdef _process__bind_execute(self, char mtype) cdef _process__bind_execute_many(self, char mtype) cdef _process__close_stmt_portal(self, char mtype) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 3ac317bb..21498e7d 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -44,21 +44,56 @@ cdef class CoreProtocol: if mtype == b'S': # ParameterStatus self._parse_msg_parameter_status() - continue + elif mtype == b'A': # NotificationResponse self._parse_msg_notification() - continue + elif mtype == b'N': # 'N' - NoticeResponse self._on_notice(self._parse_msg_error_response(False)) - continue - if state == PROTOCOL_AUTH: + elif mtype == b'E': + # ErrorResponse + self._parse_msg_error_response(True) + # In all cases, except Auth, ErrorResponse will + # be followed by a ReadyForQuery, which is when + # _push_result() will be called. + if state == PROTOCOL_AUTH: + self._push_result() + + elif mtype == b'Z': + # ReadyForQuery + self._parse_msg_ready_for_query() + + if state != PROTOCOL_BIND_EXECUTE_MANY: + self._push_result() + + else: + if self.result_type == RESULT_FAILED: + self._push_result() + else: + try: + buf = next(self._execute_iter) + except StopIteration: + self._push_result() + except Exception as e: + self.result_type = RESULT_FAILED + self.result = e + self._push_result() + else: + # Next iteration over the executemany() + # arg sequence. + self._send_bind_message( + self._execute_portal_name, + self._execute_stmt_name, + buf, 0) + + elif state == PROTOCOL_AUTH: self._process__auth(mtype) - elif state == PROTOCOL_PREPARE: - self._process__prepare(mtype) + elif state == PROTOCOL_PARSE_DESCRIBE: + self._process__parse_describe(mtype) elif state == PROTOCOL_BIND_EXECUTE: self._process__bind_execute(mtype) @@ -93,42 +128,26 @@ cdef class CoreProtocol: elif state == PROTOCOL_CANCELLED: # discard all messages until the sync message - if mtype == b'E': - self._parse_msg_error_response(True) - elif mtype == b'Z': - self._parse_msg_ready_for_query() - self._push_result() - else: - self.buffer.consume_message() + self.buffer.consume_message() elif state == PROTOCOL_ERROR_CONSUME: # Error in protocol (on asyncpg side); # discard all messages until sync message - - if mtype == b'Z': - # Sync point, self to push the result - if self.result_type != RESULT_FAILED: - self.result_type = RESULT_FAILED - self.result = apg_exc.InternalClientError( - 'unknown error in protocol implementation') - - self._push_result() - - else: - self.buffer.consume_message() + self.buffer.consume_message() else: raise apg_exc.InternalClientError( 'protocol is in an unknown state {}'.format(state)) except Exception as ex: + self.state = PROTOCOL_ERROR_CONSUME self.result_type = RESULT_FAILED self.result = ex if mtype == b'Z': + # This should only happen if _parse_msg_ready_for_query() + # has failed. self._push_result() - else: - self.state = PROTOCOL_ERROR_CONSUME finally: if self._skip_discard: @@ -153,43 +172,27 @@ cdef class CoreProtocol: # BackendKeyData self._parse_msg_backend_key_data() - elif mtype == b'E': - # ErrorResponse - self.con_status = CONNECTION_BAD - self._parse_msg_error_response(True) - self._push_result() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self.con_status = CONNECTION_OK - self._push_result() - - cdef _process__prepare(self, char mtype): - if mtype == b't': - # Parameters description - self.result_param_desc = self.buffer.consume_message().as_bytes() + # push_result() will be initiated by handling + # ReadyForQuery or ErrorResponse in the main loop. - elif mtype == b'1': + cdef _process__parse_describe(self, char mtype): + if mtype == b'1': # ParseComplete self.buffer.consume_message() + elif mtype == b't': + # ParameterDescription + self.result_param_desc = self.buffer.consume_message().as_bytes() + elif mtype == b'T': - # Row description + # RowDescription self.result_row_desc = self.buffer.consume_message().as_bytes() - - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() elif mtype == b'n': # NoData self.buffer.consume_message() + self._push_result() cdef _process__bind_execute(self, char mtype): if mtype == b'D': @@ -199,28 +202,22 @@ cdef class CoreProtocol: elif mtype == b's': # PortalSuspended self.buffer.consume_message() + self._push_result() elif mtype == b'C': # CommandComplete self.result_execute_completed = True self._parse_msg_command_complete() - - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) + self._push_result() elif mtype == b'2': # BindComplete self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - elif mtype == b'I': # EmptyQueryResponse self.buffer.consume_message() + self._push_result() cdef _process__bind_execute_many(self, char mtype): cdef WriteBuffer buf @@ -237,64 +234,24 @@ cdef class CoreProtocol: # CommandComplete self._parse_msg_command_complete() - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - elif mtype == b'2': # BindComplete self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - if self.result_type == RESULT_FAILED: - self._push_result() - else: - try: - buf = next(self._execute_iter) - except StopIteration: - self._push_result() - except Exception as e: - self.result_type = RESULT_FAILED - self.result = e - self._push_result() - else: - # Next iteration over the executemany() arg sequence - self._send_bind_message( - self._execute_portal_name, self._execute_stmt_name, - buf, 0) - elif mtype == b'I': # EmptyQueryResponse self.buffer.consume_message() cdef _process__bind(self, char mtype): - if mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'2': + if mtype == b'2': # BindComplete self.buffer.consume_message() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__close_stmt_portal(self, char mtype): - if mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'3': + if mtype == b'3': # CloseComplete self.buffer.consume_message() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__simple_query(self, char mtype): @@ -304,42 +261,21 @@ cdef class CoreProtocol: # 'T' - RowDescription self.buffer.consume_message() - elif mtype == b'E': - # ErrorResponse - self._parse_msg_error_response(True) - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - elif mtype == b'C': # CommandComplete self._parse_msg_command_complete() - else: # We don't really care about COPY IN etc self.buffer.consume_message() cdef _process__copy_out(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'H': + if mtype == b'H': # CopyOutResponse self._set_state(PROTOCOL_COPY_OUT_DATA) self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - cdef _process__copy_out_data(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'd': + if mtype == b'd': # CopyData self._parse_copy_data_msgs() @@ -351,37 +287,18 @@ cdef class CoreProtocol: elif mtype == b'C': # CommandComplete self._parse_msg_command_complete() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _process__copy_in(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'G': + if mtype == b'G': # CopyInResponse self._set_state(PROTOCOL_COPY_IN_DATA) self.buffer.consume_message() - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() - self._push_result() - cdef _process__copy_in_data(self, char mtype): - if mtype == b'E': - self._parse_msg_error_response(True) - - elif mtype == b'C': + if mtype == b'C': # CommandComplete self._parse_msg_command_complete() - - elif mtype == b'Z': - # ReadyForQuery - self._parse_msg_ready_for_query() self._push_result() cdef _parse_msg_command_complete(self): @@ -739,7 +656,7 @@ cdef class CoreProtocol: WriteBuffer buf self._ensure_connected() - self._set_state(PROTOCOL_PREPARE) + self._set_state(PROTOCOL_PARSE_DESCRIBE) buf = WriteBuffer.new_message(b'P') buf.write_str(stmt_name, self.encoding) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index e137c74b..3c222db8 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -713,6 +713,7 @@ cdef class BaseProtocol(CoreProtocol): return self.waiter cdef _on_result__connect(self, object waiter): + self.con_status = CONNECTION_OK waiter.set_result(True) cdef _on_result__prepare(self, object waiter): @@ -790,6 +791,10 @@ cdef class BaseProtocol(CoreProtocol): self.result, query=self.last_query) else: exc = self.result + + if self.state == PROTOCOL_AUTH: + self.con_status = CONNECTION_BAD + waiter.set_exception(exc) return @@ -797,7 +802,7 @@ cdef class BaseProtocol(CoreProtocol): if self.state == PROTOCOL_AUTH: self._on_result__connect(waiter) - elif self.state == PROTOCOL_PREPARE: + elif self.state == PROTOCOL_PARSE_DESCRIBE: self._on_result__prepare(waiter) elif self.state == PROTOCOL_BIND_EXECUTE: @@ -847,11 +852,12 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_waiter = None if self.waiter is not None and self.waiter.done(): self.waiter = None - if self.waiter is None: - return try: - self._dispatch_result() + if self.waiter is not None: + # _on_result() might be called several times in the + # process, or the waiter might have been cancelled. + self._dispatch_result() finally: self.statement = None self.last_query = None