Skip to content

Commit 77bc63c

Browse files
committed
Support VECTOR data type
Related to #1549 --- For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/mysql-net/MySqlConnector/issues/1549?shareId=XXXX-XXXX-XXXX-XXXX).
1 parent 2f977d8 commit 77bc63c

File tree

7 files changed

+128
-1
lines changed

7 files changed

+128
-1
lines changed

src/MySqlConnector/Core/ColumnTypeMetadata.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,8 @@ internal sealed class ColumnTypeMetadata(string dataTypeName, DbTypeMapping dbTy
1616

1717
public string CreateLookupKey() => CreateLookupKey(DataTypeName, IsUnsigned, Length);
1818
}
19+
20+
internal static class ColumnTypeMetadataExtensions
21+
{
22+
public static ColumnTypeMetadata Vector { get; } = new("VECTOR", new DbTypeMapping(typeof(float[]), new[] { DbType.Object }, convert: o => (float[])o), MySqlDbType.Vector, isUnsigned: false, binary: true, length: 0, simpleDataTypeName: "VECTOR", createFormat: "VECTOR({0})");
23+
}

src/MySqlConnector/Core/TypeMapper.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ private TypeMapper()
5555
AddColumnTypeMetadata(new("DOUBLE", typeDouble, MySqlDbType.Double));
5656
AddColumnTypeMetadata(new("FLOAT", typeFloat, MySqlDbType.Float));
5757

58+
// vector
59+
var typeFloatArray = AddDbTypeMapping(new(typeof(float[]), [DbType.Object], convert: static o => (float[])o));
60+
AddColumnTypeMetadata(new("VECTOR", typeFloatArray, MySqlDbType.Vector, binary: true, simpleDataTypeName: "VECTOR", createFormat: "VECTOR({0})"));
61+
5862
// string
5963
var typeFixedString = AddDbTypeMapping(new(typeof(string), [DbType.StringFixedLength, DbType.AnsiStringFixedLength], convert: Convert.ToString!));
6064
var typeString = AddDbTypeMapping(new(typeof(string), [DbType.String, DbType.AnsiString, DbType.Xml], convert: Convert.ToString!));
@@ -303,6 +307,9 @@ public static MySqlDbType ConvertToMySqlDbType(ColumnDefinitionPayload columnDef
303307
case ColumnType.Set:
304308
return MySqlDbType.Set;
305309

310+
case ColumnType.Vector:
311+
return MySqlDbType.Vector;
312+
306313
default:
307314
throw new NotImplementedException($"ConvertToMySqlDbType for {columnDefinition.ColumnType} is not implemented");
308315
}
@@ -339,6 +346,7 @@ public static ushort ConvertToColumnTypeAndFlags(MySqlDbType dbType, MySqlGuidFo
339346
MySqlDbType.NewDecimal => ColumnType.NewDecimal,
340347
MySqlDbType.Geometry => ColumnType.Geometry,
341348
MySqlDbType.Null => ColumnType.Null,
349+
MySqlDbType.Vector => ColumnType.Vector,
342350
_ => throw new NotImplementedException($"ConvertToColumnTypeAndFlags for {dbType} is not implemented"),
343351
};
344352
return (ushort) ((byte) columnType | (isUnsigned ? 0x8000 : 0));

src/MySqlConnector/MySqlDataReader.cs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,15 @@ public override long GetChars(int ordinal, long dataOffset, char[]? buffer, int
252252
#endif
253253
public override Type GetFieldType(int ordinal) => GetResultSet().GetFieldType(ordinal);
254254

255-
public override object GetValue(int ordinal) => GetResultSet().GetCurrentRow().GetValue(ordinal);
255+
public override object GetValue(int ordinal)
256+
{
257+
var resultSet = GetResultSet();
258+
if (resultSet.GetDataTypeName(ordinal) == "VECTOR")
259+
{
260+
return resultSet.GetCurrentRow().GetValue(ordinal);
261+
}
262+
return resultSet.GetCurrentRow().GetValue(ordinal);
263+
}
256264

257265
public override IEnumerator GetEnumerator() => new DbEnumerator(this, closeReader: false);
258266

@@ -428,6 +436,14 @@ public override T GetFieldValue<T>(int ordinal)
428436
return (T) (object) GetDateOnly(ordinal);
429437
if (typeof(T) == typeof(TimeOnly))
430438
return (T) (object) GetTimeOnly(ordinal);
439+
if (typeof(T) == typeof(ReadOnlySpan<float>))
440+
{
441+
var value = GetValue(ordinal);
442+
if (value is float[] floatArray)
443+
{
444+
return (T) (object) new ReadOnlySpan<float>(floatArray);
445+
}
446+
}
431447
#endif
432448

433449
return base.GetFieldValue<T>(ordinal);

src/MySqlConnector/MySqlDbType.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ namespace MySqlConnector;
22

33
#pragma warning disable CA1720 // Identifier contains type name
44

5+
/// <summary>
6+
/// Specifies the MySQL data type of a field, property, for use in a <see cref="MySqlParameter"/>.
7+
/// </summary>
58
public enum MySqlDbType
69
{
710
Bool = -1,
@@ -37,6 +40,10 @@ public enum MySqlDbType
3740
VarChar,
3841
String,
3942
Geometry,
43+
/// <summary>
44+
/// A MySQL VECTOR data type.
45+
/// </summary>
46+
Vector = 242,
4047
UByte = 501,
4148
UInt16,
4249
UInt32,

src/MySqlConnector/MySqlParameter.cs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -554,6 +554,17 @@ internal void AppendSqlString(ByteBufferWriter writer, StatementPreparerOptions
554554
{
555555
writer.WriteString((ulong) Value);
556556
}
557+
else if (Value is float[] floatArrayValue)
558+
{
559+
writer.Write((byte) '[');
560+
for (int i = 0; i < floatArrayValue.Length; i++)
561+
{
562+
if (i > 0)
563+
writer.Write((byte) ',');
564+
writer.WriteString(floatArrayValue[i]);
565+
}
566+
writer.Write((byte) ']');
567+
}
557568
else
558569
{
559570
throw new NotSupportedException($"Parameter type {Value.GetType().Name} is not supported; see https://mysqlconnector.net/param-type. Value: {Value}");
@@ -871,6 +882,14 @@ private void AppendBinary(ByteBufferWriter writer, object value, StatementPrepar
871882
{
872883
writer.Write((ulong) value);
873884
}
885+
else if (value is float[] floatArrayValue)
886+
{
887+
writer.WriteLengthEncodedInteger((ulong) (floatArrayValue.Length * 4));
888+
foreach (var floatValue in floatArrayValue)
889+
{
890+
writer.Write(BitConverter.GetBytes(floatValue));
891+
}
892+
}
874893
else
875894
{
876895
throw new NotSupportedException($"Parameter type {value.GetType().Name} is not supported; see https://mysqlconnector.net/param-type. Value: {value}");
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
using System;
2+
using System.Data;
3+
using MySqlConnector;
4+
using Xunit;
5+
6+
namespace MySqlConnector.Tests
7+
{
8+
public class MySqlDataReaderTests
9+
{
10+
[Fact]
11+
public void GetVectorDataType()
12+
{
13+
using var connection = new MySqlConnection("your_connection_string");
14+
connection.Open();
15+
16+
using var command = new MySqlCommand("SELECT CAST('[1.0, 2.0, 3.0]' AS VECTOR)", connection);
17+
using var reader = command.ExecuteReader();
18+
19+
Assert.True(reader.Read());
20+
var vector = reader.GetValue(0) as float[];
21+
Assert.NotNull(vector);
22+
Assert.Equal(new float[] { 1.0f, 2.0f, 3.0f }, vector);
23+
}
24+
25+
[Fact]
26+
public void GetReadOnlySpanFloat()
27+
{
28+
using var connection = new MySqlConnection("your_connection_string");
29+
connection.Open();
30+
31+
using var command = new MySqlCommand("SELECT CAST('[1.0, 2.0, 3.0]' AS VECTOR)", connection);
32+
using var reader = command.ExecuteReader();
33+
34+
Assert.True(reader.Read());
35+
var span = reader.GetFieldValue<ReadOnlySpan<float>>(0);
36+
Assert.Equal(new float[] { 1.0f, 2.0f, 3.0f }, span.ToArray());
37+
}
38+
}
39+
}

tests/MySqlConnector.Tests/MySqlParameterTests.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,4 +423,37 @@ public void ScaleMixed()
423423
((IDbDataParameter) parameter).Scale = 12;
424424
Assert.Equal((byte) 12, ((MySqlParameter) parameter).Scale);
425425
}
426+
427+
[Fact]
428+
public void SetValueToFloatArrayInfersType()
429+
{
430+
var parameter = new MySqlParameter { Value = new float[] { 1.0f, 2.0f, 3.0f } };
431+
Assert.Equal(DbType.Object, parameter.DbType);
432+
Assert.Equal(MySqlDbType.Vector, parameter.MySqlDbType);
433+
}
434+
435+
[Fact]
436+
public void ConstructorNameTypeVector()
437+
{
438+
var parameter = new MySqlParameter("@vector", MySqlDbType.Vector);
439+
Assert.Equal("@vector", parameter.ParameterName);
440+
Assert.Equal(MySqlDbType.Vector, parameter.MySqlDbType);
441+
Assert.Equal(DbType.Object, parameter.DbType);
442+
Assert.False(parameter.IsNullable);
443+
Assert.Null(parameter.Value);
444+
Assert.Equal(ParameterDirection.Input, parameter.Direction);
445+
Assert.Equal(0, parameter.Precision);
446+
Assert.Equal(0, parameter.Scale);
447+
Assert.Equal(0, parameter.Size);
448+
#if MYSQL_DATA
449+
Assert.Equal(DataRowVersion.Default, parameter.SourceVersion);
450+
#else
451+
Assert.Equal(DataRowVersion.Current, parameter.SourceVersion);
452+
#endif
453+
#if MYSQL_DATA
454+
Assert.Null(parameter.SourceColumn);
455+
#else
456+
Assert.Equal("", parameter.SourceColumn);
457+
#endif
458+
}
426459
}

0 commit comments

Comments
 (0)