diff --git a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs index bee8e65e3..42fb132ad 100644 --- a/src/MySqlConnector/MySqlClient/MySqlDataReader.cs +++ b/src/MySqlConnector/MySqlClient/MySqlDataReader.cs @@ -71,7 +71,14 @@ private void ActivateResultSet(ResultSet resultSet) { if (resultSet.ReadResultSetHeaderException != null) { - throw resultSet.ReadResultSetHeaderException is MySqlException mySqlException ? + var mySqlException = resultSet.ReadResultSetHeaderException as MySqlException; + + // for any exception not created from an ErrorPayload, mark the session as failed (because we can't guarantee that all data + // has been read from the connection and that the socket is still usable) + if (mySqlException?.SqlState == null) + Command.Connection.Session.SetFailed(); + + throw mySqlException != null ? new MySqlException(mySqlException.Number, mySqlException.SqlState, mySqlException.Message, mySqlException) : resultSet.ReadResultSetHeaderException; } diff --git a/src/MySqlConnector/MySqlClient/Results/ResultSet.cs b/src/MySqlConnector/MySqlClient/Results/ResultSet.cs index 8961ae3bd..f427e30a1 100644 --- a/src/MySqlConnector/MySqlClient/Results/ResultSet.cs +++ b/src/MySqlConnector/MySqlClient/Results/ResultSet.cs @@ -79,6 +79,9 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) { var reader = new ByteArrayReader(payload.ArraySegment); var columnCount = (int) reader.ReadLengthEncodedInteger(); + if (reader.BytesRemaining != 0) + throw new MySqlException("Unexpected data at end of column_count packet; see https://github.com/mysql-net/MySqlConnector/issues/324"); + ColumnDefinitions = new ColumnDefinitionPayload[columnCount]; m_dataOffsets = new int[columnCount]; m_dataLengths = new int[columnCount]; @@ -89,8 +92,11 @@ public async Task ReadResultSetHeaderAsync(IOBehavior ioBehavior) ColumnDefinitions[column] = ColumnDefinitionPayload.Create(payload); } - payload = await Session.ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); - EofPayload.Create(payload); + if (!Session.SupportsDeprecateEof) + { + payload = await Session.ReceiveReplyAsync(ioBehavior, CancellationToken.None).ConfigureAwait(false); + EofPayload.Create(payload); + } LastInsertId = -1; State = ResultSetState.ReadResultSetHeader; @@ -189,12 +195,22 @@ async Task ScanRowAsyncAwaited(Task payloadTask, CancellationT Row ScanRowAsyncRemainder(PayloadData payload) { - if (EofPayload.IsEof(payload)) + if (payload.HeaderByte == EofPayload.Signature) { - var eof = EofPayload.Create(payload); - BufferState = (eof.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; - m_rowBuffered = null; - return null; + if (Session.SupportsDeprecateEof && OkPayload.IsOk(payload, Session.SupportsDeprecateEof)) + { + var ok = OkPayload.Create(payload, Session.SupportsDeprecateEof); + BufferState = (ok.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; + m_rowBuffered = null; + return null; + } + if (!Session.SupportsDeprecateEof && EofPayload.IsEof(payload)) + { + var eof = EofPayload.Create(payload); + BufferState = (eof.ServerStatus & ServerStatus.MoreResultsExist) == 0 ? ResultSetState.NoMoreData : ResultSetState.HasMoreData; + m_rowBuffered = null; + return null; + } } var reader = new ByteArrayReader(payload.ArraySegment); diff --git a/src/MySqlConnector/Serialization/EofPayload.cs b/src/MySqlConnector/Serialization/EofPayload.cs index 54f6ef0ab..608dc6ecd 100644 --- a/src/MySqlConnector/Serialization/EofPayload.cs +++ b/src/MySqlConnector/Serialization/EofPayload.cs @@ -1,4 +1,4 @@ -using System; +using System; namespace MySql.Data.Serialization { @@ -32,7 +32,7 @@ public static EofPayload Create(PayloadData payload) public static bool IsEof(PayloadData payload) => payload.ArraySegment.Count > 0 && payload.ArraySegment.Count < 9 && payload.ArraySegment.Array[payload.ArraySegment.Offset] == Signature; - private const byte Signature = 0xFE; + public const byte Signature = 0xFE; private EofPayload(int warningCount, ServerStatus status) { diff --git a/src/MySqlConnector/Serialization/HandshakeResponse41Packet.cs b/src/MySqlConnector/Serialization/HandshakeResponse41Packet.cs index 14170c8d4..0a4c23332 100644 --- a/src/MySqlConnector/Serialization/HandshakeResponse41Packet.cs +++ b/src/MySqlConnector/Serialization/HandshakeResponse41Packet.cs @@ -1,4 +1,4 @@ -namespace MySql.Data.Serialization +namespace MySql.Data.Serialization { internal sealed class HandshakeResponse41Packet { @@ -14,12 +14,12 @@ private static PayloadWriter CreateCapabilitiesPayload(ProtocolCapabilities serv (serverCapabilities & ProtocolCapabilities.PluginAuthLengthEncodedClientData) | ProtocolCapabilities.MultiStatements | ProtocolCapabilities.MultiResults | - ProtocolCapabilities.PreparedStatementMultiResults | ProtocolCapabilities.LocalFiles | (string.IsNullOrWhiteSpace(cs.Database) ? 0 : ProtocolCapabilities.ConnectWithDatabase) | (cs.UseAffectedRows ? 0 : ProtocolCapabilities.FoundRows) | (useCompression ? ProtocolCapabilities.Compress : ProtocolCapabilities.None) | (serverCapabilities & ProtocolCapabilities.ConnectionAttributes) | + (serverCapabilities & ProtocolCapabilities.DeprecateEof) | additionalCapabilities)); writer.WriteInt32(0x4000_0000); writer.WriteByte((byte) CharacterSet.Utf8Mb4Binary); diff --git a/src/MySqlConnector/Serialization/MySqlSession.cs b/src/MySqlConnector/Serialization/MySqlSession.cs index 17d3ce3ec..67f8fb625 100644 --- a/src/MySqlConnector/Serialization/MySqlSession.cs +++ b/src/MySqlConnector/Serialization/MySqlSession.cs @@ -46,6 +46,7 @@ public MySqlSession(ConnectionPool pool, int poolGeneration, int id) public string DatabaseOverride { get; set; } public IPAddress IPAddress => (m_tcpClient?.Client.RemoteEndPoint as IPEndPoint)?.Address; public WeakReference OwningConnection { get; set; } + public bool SupportsDeprecateEof => m_supportsDeprecateEof; public void ReturnToPool() { @@ -142,8 +143,10 @@ public void FinishQuerying() lock (m_lock) { - VerifyState(State.Querying, State.ClearingPendingCancellation); - m_state = State.Connected; + if (m_state == State.Querying || m_state == State.ClearingPendingCancellation) + m_state = State.Connected; + else + VerifyState(State.Failed); m_activeCommandId = 0; } } @@ -234,6 +237,8 @@ public async Task ConnectAsync(ConnectionSettings cs, IOBehavior ioBehavior, Can if (m_supportsConnectionAttributes && s_connectionAttributes == null) s_connectionAttributes = CreateConnectionAttributes(); + m_supportsDeprecateEof = (initialHandshake.ProtocolCapabilities & ProtocolCapabilities.DeprecateEof) != 0; + var response = HandshakeResponse41Packet.Create(initialHandshake, cs, m_useCompression, m_supportsConnectionAttributes ? s_connectionAttributes : null); payload = new PayloadData(new ArraySegment(response)); await SendReplyAsync(payload, ioBehavior, cancellationToken).ConfigureAwait(false); @@ -788,7 +793,7 @@ private PayloadData TryAsyncContinuation(Task> task) return payload; } - private void SetFailed() + internal void SetFailed() { lock (m_lock) m_state = State.Failed; @@ -894,5 +899,6 @@ private enum State bool m_useCompression; bool m_isSecureConnection; bool m_supportsConnectionAttributes; + bool m_supportsDeprecateEof; } } diff --git a/src/MySqlConnector/Serialization/OkPayload.cs b/src/MySqlConnector/Serialization/OkPayload.cs index ef289060c..2587200a5 100644 --- a/src/MySqlConnector/Serialization/OkPayload.cs +++ b/src/MySqlConnector/Serialization/OkPayload.cs @@ -1,4 +1,6 @@ -namespace MySql.Data.Serialization +using System; + +namespace MySql.Data.Serialization { internal sealed class OkPayload { @@ -9,10 +11,24 @@ internal sealed class OkPayload public const byte Signature = 0x00; - public static OkPayload Create(PayloadData payload) + /* See + * http://web.archive.org/web/20160604101747/http://dev.mysql.com/doc/internals/en/packet-OK_Packet.html + * https://mariadb.com/kb/en/the-mariadb-library/resultset/ + * https://github.com/MariaDB/mariadb-connector-j/blob/5fa814ac6e1b4c9cb6d141bd221cbd5fc45c8a78/src/main/java/org/mariadb/jdbc/internal/com/read/resultset/SelectResultSet.java#L443-L444 + */ + public static bool IsOk(PayloadData payload, bool deprecateEof) => + payload.ArraySegment.Array != null && payload.ArraySegment.Count > 0 && + ((payload.ArraySegment.Count > 6 && payload.ArraySegment.Array[payload.ArraySegment.Offset] == Signature) || + (deprecateEof && payload.ArraySegment.Count < 0xFF_FFFF && payload.ArraySegment.Array[payload.ArraySegment.Offset] == EofPayload.Signature)); + + public static OkPayload Create(PayloadData payload) => Create(payload, false); + + public static OkPayload Create(PayloadData payload, bool deprecateEof) { var reader = new ByteArrayReader(payload.ArraySegment); - reader.ReadByte(Signature); + var signature = reader.ReadByte(); + if (signature != Signature && (!deprecateEof || signature != EofPayload.Signature)) + throw new FormatException("Expected to read 0x00 or 0xFE but got 0x{0:X2}".FormatInvariant(signature)); var affectedRowCount = checked((int) reader.ReadLengthEncodedInteger()); var lastInsertId = checked((long) reader.ReadLengthEncodedInteger()); var serverStatus = (ServerStatus) reader.ReadUInt16();