Skip to content

Commit 46ef481

Browse files
author
takashi hashida
committed
Support ListType serialization / deserialization
1 parent 54cb9c0 commit 46ef481

File tree

9 files changed

+237
-55
lines changed

9 files changed

+237
-55
lines changed

csharp/src/Apache.Arrow/Arrays/ListArray.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,14 @@ public ListArray(IArrowType dataType, int length,
3232
: this(new ArrayData(dataType, length, nullCount, offset,
3333
new[] {nullBitmapBuffer, valueOffsetsBuffer}, new[] {values.Data}))
3434
{
35-
Values = values;
3635
}
3736

3837
public ListArray(ArrayData data)
3938
: base(data)
4039
{
4140
data.EnsureBufferCount(2);
4241
data.EnsureDataType(ArrowTypeId.List);
42+
Values = ArrowArrayFactory.BuildArray(data.Children[0]);
4343
}
4444

4545
public override void Accept(IArrowArrayVisitor visitor) => Accept(this, visitor);

csharp/src/Apache.Arrow/Ipc/ArrowReaderImplementation.cs

Lines changed: 71 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919
using System.Collections.Generic;
2020
using System.Diagnostics;
2121
using System.IO;
22+
using System.Linq;
2223
using System.Threading;
2324
using System.Threading.Tasks;
25+
using Apache.Arrow.Types;
2426

2527
namespace Apache.Arrow.Ipc
2628
{
@@ -113,33 +115,35 @@ private List<IArrowArray> BuildArrays(
113115
ByteBuffer messageBuffer,
114116
Flatbuf.RecordBatch recordBatchMessage)
115117
{
116-
var arrays = new List<IArrowArray>(recordBatchMessage.NodesLength);
117-
int bufferIndex = 0;
118+
return CreateInner().ToList();
118119

119-
for (var n = 0; n < recordBatchMessage.NodesLength; n++)
120+
IEnumerable<IArrowArray> CreateInner()
120121
{
121-
Field field = schema.GetFieldByIndex(n);
122-
Flatbuf.FieldNode fieldNode = recordBatchMessage.Nodes(n).GetValueOrDefault();
122+
var recordBatchManipulator = new RecordBatchManipulator(in recordBatchMessage);
123123

124-
ArrayData arrayData = field.DataType.IsFixedPrimitive() ?
125-
LoadPrimitiveField(field, fieldNode, recordBatchMessage, messageBuffer, ref bufferIndex) :
126-
LoadVariableField(field, fieldNode, recordBatchMessage, messageBuffer, ref bufferIndex);
124+
while (!recordBatchManipulator.IsAllNodeRead)
125+
{
126+
var field = schema.GetFieldByIndex(recordBatchManipulator.CurrentNodeIndex);
127+
Flatbuf.FieldNode fieldNode = recordBatchManipulator.UnshiftNode();
127128

128-
arrays.Add(ArrowArrayFactory.BuildArray(arrayData));
129-
}
129+
var arrayData = field.DataType.IsFixedPrimitive() ?
130+
LoadPrimitiveField(recordBatchManipulator, field, in fieldNode, messageBuffer) :
131+
LoadVariableField(recordBatchManipulator, field, in fieldNode, messageBuffer);
130132

131-
return arrays;
133+
yield return ArrowArrayFactory.BuildArray(arrayData);
134+
}
135+
}
132136
}
133137

138+
134139
private ArrayData LoadPrimitiveField(
140+
RecordBatchManipulator recordBatchManipulator,
135141
Field field,
136-
Flatbuf.FieldNode fieldNode,
137-
Flatbuf.RecordBatch recordBatch,
138-
ByteBuffer bodyData,
139-
ref int bufferIndex)
142+
in Flatbuf.FieldNode fieldNode,
143+
ByteBuffer bodyData)
140144
{
141-
var nullBitmapBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
142-
var valueBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
145+
var nullBitmapBuffer = recordBatchManipulator.UnshiftBuffer();
146+
var valueBuffer = recordBatchManipulator.UnshiftBuffer();
143147

144148
ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, nullBitmapBuffer);
145149
ArrowBuffer valueArrowBuffer = BuildArrowBuffer(bodyData, valueBuffer);
@@ -158,20 +162,21 @@ private ArrayData LoadPrimitiveField(
158162
}
159163

160164
var arrowBuff = new[] { nullArrowBuffer, valueArrowBuffer };
165+
var offspring = GetOffspring(recordBatchManipulator, field, bodyData);
161166

162-
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff);
167+
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, offspring.ToArray());
163168
}
164169

170+
165171
private ArrayData LoadVariableField(
172+
RecordBatchManipulator recordBatchManipulator,
166173
Field field,
167-
Flatbuf.FieldNode fieldNode,
168-
Flatbuf.RecordBatch recordBatch,
169-
ByteBuffer bodyData,
170-
ref int bufferIndex)
174+
in Flatbuf.FieldNode fieldNode,
175+
ByteBuffer bodyData)
171176
{
172-
var nullBitmapBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
173-
var offsetBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
174-
var valueBuffer = recordBatch.Buffers(bufferIndex++).GetValueOrDefault();
177+
var nullBitmapBuffer = recordBatchManipulator.UnshiftBuffer();
178+
var offsetBuffer = recordBatchManipulator.UnshiftBuffer();
179+
var valueBuffer = recordBatchManipulator.UnshiftBuffer();
175180

176181
ArrowBuffer nullArrowBuffer = BuildArrowBuffer(bodyData, nullBitmapBuffer);
177182
ArrowBuffer offsetArrowBuffer = BuildArrowBuffer(bodyData, offsetBuffer);
@@ -191,8 +196,24 @@ private ArrayData LoadVariableField(
191196
}
192197

193198
var arrowBuff = new[] { nullArrowBuffer, offsetArrowBuffer, valueArrowBuffer };
199+
var offspring = GetOffspring(recordBatchManipulator, field, bodyData);
200+
201+
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff, offspring.ToArray());
202+
}
194203

195-
return new ArrayData(field.DataType, fieldLength, fieldNullCount, 0, arrowBuff);
204+
private IEnumerable<ArrayData> GetOffspring(
205+
RecordBatchManipulator recordBatchManipulator,
206+
Field field,
207+
ByteBuffer bodyData)
208+
{
209+
if (!(field.DataType is NestedType type)) yield break;
210+
foreach (var childField in type.Children)
211+
{
212+
Flatbuf.FieldNode childFieldNode = recordBatchManipulator.UnshiftNode();
213+
yield return childField.DataType.IsFixedPrimitive()
214+
? LoadPrimitiveField(recordBatchManipulator, childField, in childFieldNode, bodyData)
215+
: LoadVariableField(recordBatchManipulator, childField, in childFieldNode, bodyData);
216+
}
196217
}
197218

198219
private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer)
@@ -209,4 +230,28 @@ private ArrowBuffer BuildArrowBuffer(ByteBuffer bodyData, Flatbuf.Buffer buffer)
209230
return new ArrowBuffer(data);
210231
}
211232
}
233+
234+
internal class RecordBatchManipulator
235+
{
236+
private int CurrentBufferIndex { get; set; }
237+
private Flatbuf.RecordBatch RecordBatch { get; }
238+
internal int CurrentNodeIndex { get; set; }
239+
internal bool IsAllNodeRead => CurrentNodeIndex >= RecordBatch.NodesLength;
240+
241+
internal RecordBatchManipulator(in Flatbuf.RecordBatch recordBatch)
242+
{
243+
RecordBatch = recordBatch;
244+
}
245+
246+
internal Flatbuf.Buffer UnshiftBuffer()
247+
{
248+
return RecordBatch.Buffers(CurrentBufferIndex++).GetValueOrDefault();
249+
}
250+
251+
internal Flatbuf.FieldNode UnshiftNode()
252+
{
253+
return RecordBatch.Nodes(CurrentNodeIndex++).GetValueOrDefault();
254+
}
255+
256+
}
212257
}

csharp/src/Apache.Arrow/Ipc/ArrowStreamWriter.cs

Lines changed: 67 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
using System.Buffers.Binary;
1919
using System.Collections.Generic;
2020
using System.IO;
21+
using System.Linq;
2122
using System.Threading;
2223
using System.Threading.Tasks;
24+
using Apache.Arrow.Types;
2325
using FlatBuffers;
2426

2527
namespace Apache.Arrow.Ipc
@@ -108,7 +110,7 @@ private void CreateBuffers(BooleanArray array)
108110
}
109111

110112
private void CreateBuffers<T>(PrimitiveArray<T> array)
111-
where T: struct
113+
where T : struct
112114
{
113115
_buffers.Add(CreateBuffer(array.NullBitmapBuffer));
114116
_buffers.Add(CreateBuffer(array.ValueBuffer));
@@ -175,6 +177,49 @@ public ArrowStreamWriter(Stream baseStream, Schema schema, bool leaveOpen, IpcOp
175177
_options = options ?? IpcOptions.Default;
176178
}
177179

180+
private VectorOffset CreateFieldVector(IEnumerable<IArrowArray> fieldArrayList)
181+
{
182+
var allArrowArrayList = GetAll(fieldArrayList).ToList();
183+
184+
Flatbuf.RecordBatch.StartNodesVector(Builder, allArrowArrayList.Count);
185+
186+
foreach (var array in allArrowArrayList)
187+
{
188+
Flatbuf.FieldNode.CreateFieldNode(Builder, array.Length, array.NullCount);
189+
}
190+
191+
return Builder.EndVector();
192+
193+
194+
//Inner methods
195+
196+
IEnumerable<IArrowArray> GetAll(IEnumerable<IArrowArray> targetArrayList)
197+
{
198+
foreach (var arrowArray in targetArrayList)
199+
{
200+
foreach (var arr in GetSelfAndOffspring(arrowArray))
201+
{
202+
yield return arr;
203+
}
204+
}
205+
}
206+
207+
IEnumerable<IArrowArray> GetSelfAndOffspring(IArrowArray targetArray)
208+
{
209+
if (targetArray.Data.DataType is NestedType)
210+
{
211+
foreach (var child in targetArray.Data.Children)
212+
{
213+
foreach (var offspring in GetSelfAndOffspring(ArrowArrayFactory.BuildArray(child)))
214+
{
215+
yield return offspring;
216+
}
217+
}
218+
}
219+
yield return targetArray;
220+
}
221+
}
222+
178223
private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBatch,
179224
CancellationToken cancellationToken = default)
180225
{
@@ -189,23 +234,13 @@ private protected async Task WriteRecordBatchInternalAsync(RecordBatch recordBat
189234
Builder.Clear();
190235

191236
// Serialize field nodes
192-
193-
var fieldCount = Schema.Fields.Count;
194-
195-
Flatbuf.RecordBatch.StartNodesVector(Builder, fieldCount);
196-
197237
// flatbuffer struct vectors have to be created in reverse order
198-
for (var i = fieldCount - 1; i >= 0; i--)
199-
{
200-
var fieldArray = recordBatch.Column(i);
201-
Flatbuf.FieldNode.CreateFieldNode(Builder, fieldArray.Length, fieldArray.NullCount);
202-
}
203-
204-
var fieldNodesVectorOffset = Builder.EndVector();
238+
var fieldNodesVectorOffset = CreateFieldVector(recordBatch.Arrays.Reverse());
205239

206240
// Serialize buffers
207-
208241
var recordBatchBuilder = new ArrowRecordBatchFlatBufferBuilder();
242+
243+
var fieldCount = Schema.Fields.Count;
209244
for (var i = 0; i < fieldCount; i++)
210245
{
211246
var fieldArray = recordBatch.Column(i);
@@ -285,7 +320,7 @@ public virtual Task WriteRecordBatchAsync(RecordBatch recordBatch, CancellationT
285320
{
286321
return WriteRecordBatchInternalAsync(recordBatch, cancellationToken);
287322
}
288-
323+
289324
public async Task WriteEndAsync(CancellationToken cancellationToken = default)
290325
{
291326
if (!HasWrittenEnd)
@@ -307,14 +342,14 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca
307342
// Build fields
308343

309344
var fieldOffsets = new Offset<Flatbuf.Field>[schema.Fields.Count];
310-
var fieldChildren = new List<Offset<Flatbuf.Field>>();
311345

312346
for (var i = 0; i < fieldOffsets.Length; i++)
313347
{
314348
var field = schema.GetFieldByIndex(i);
315349
var fieldNameOffset = Builder.CreateString(field.Name);
316350
var fieldType = _fieldTypeBuilder.BuildFieldType(field);
317351

352+
var fieldChildren = GetChildrenFieldOffset(field).ToArray();
318353
var fieldChildrenOffsets = Builder.CreateVectorOfTables(fieldChildren.ToArray());
319354

320355
fieldOffsets[i] = Flatbuf.Field.CreateField(Builder,
@@ -332,6 +367,21 @@ private ValueTask WriteBufferAsync(ArrowBuffer arrowBuffer, CancellationToken ca
332367
Builder, endianness, fieldsVectorOffset);
333368
}
334369

370+
private protected IEnumerable<Offset<Flatbuf.Field>> GetChildrenFieldOffset(Field field)
371+
{
372+
if (!(field.DataType is NestedType type)) yield break;
373+
foreach (var child in type.Children)
374+
{
375+
var fieldNameOffset = Builder.CreateString(child.Name);
376+
var fieldType = _fieldTypeBuilder.BuildFieldType(child);
377+
var fieldChildrenOffsets = Builder.CreateVectorOfTables(GetChildrenFieldOffset(child).ToArray());
378+
379+
yield return Flatbuf.Field.CreateField(Builder,
380+
fieldNameOffset, child.IsNullable, fieldType.Type, fieldType.Offset,
381+
default, fieldChildrenOffsets, default);
382+
}
383+
}
384+
335385
private async ValueTask<Offset<Flatbuf.Schema>> WriteSchemaAsync(Schema schema, CancellationToken cancellationToken)
336386
{
337387
Builder.Clear();
@@ -357,7 +407,7 @@ await WriteMessageAsync(Flatbuf.MessageHeader.Schema, schemaOffset, 0, cancellat
357407
private async ValueTask<long> WriteMessageAsync<T>(
358408
Flatbuf.MessageHeader headerType, Offset<T> headerOffset, int bodyLength,
359409
CancellationToken cancellationToken)
360-
where T: struct
410+
where T : struct
361411
{
362412
var messageOffset = Flatbuf.Message.CreateMessage(
363413
Builder, CurrentMetadataVersion, headerType, headerOffset.Value,

csharp/src/Apache.Arrow/Ipc/ArrowTypeFlatbufferBuilder.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,10 @@ public void Visit(BinaryType type)
9898

9999
public void Visit(ListType type)
100100
{
101-
throw new NotImplementedException();
101+
Flatbuf.List.StartList(Builder);
102+
Result = FieldType.Build(
103+
Flatbuf.Type.List,
104+
Flatbuf.List.EndList(Builder));
102105
}
103106

104107
public void Visit(UnionType type)

csharp/src/Apache.Arrow/Ipc/MessageSerializer.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ private static Types.IArrowType GetFieldArrowType(Flatbuf.Field field)
125125
return new Types.StringType();
126126
case Flatbuf.Type.Binary:
127127
return Types.BinaryType.Default;
128+
case Flatbuf.Type.List:
129+
if (field.ChildrenLength != 1)
130+
{
131+
throw new InvalidDataException($"List type must have only one child.");
132+
}
133+
return new Types.ListType(GetFieldArrowType(field.Children(0).GetValueOrDefault()));
128134
default:
129135
throw new InvalidDataException($"Arrow primitive '{field.TypeType}' is unsupported.");
130136
}

csharp/src/Apache.Arrow/Types/ListType.cs

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,17 @@
1717

1818
namespace Apache.Arrow.Types
1919
{
20-
public sealed class ListType : ArrowType
20+
public sealed class ListType : NestedType
2121
{
2222
public override ArrowTypeId TypeId => ArrowTypeId.List;
2323
public override string Name => "list";
2424

25-
public Field ValueField { get; }
26-
public IArrowType ValueDataType { get; }
25+
public Field ValueField => Child(0);
26+
27+
public IArrowType ValueDataType => Child(0).DataType;
2728

2829
public ListType(Field valueField)
29-
{
30-
ValueField = valueField ?? throw new ArgumentNullException(nameof(valueField));
31-
ValueDataType = ValueField.DataType;
32-
}
30+
: base(valueField){ }
3331

3432
public ListType(IArrowType valueDataType)
3533
: this(new Field("item", valueDataType, true)) { }

0 commit comments

Comments
 (0)