@@ -26,8 +26,9 @@ import bson
26
26
import numpy as np
27
27
from pyarrow import timestamp, struct , field
28
28
from pyarrow.lib import (
29
- tobytes, StructType, int32, int64, float64, string, bool_
29
+ tobytes, StructType, int32, int64, float64, string, bool_, list_
30
30
)
31
+
31
32
from pymongoarrow.errors import InvalidBSON, PyMongoArrowError
32
33
from pymongoarrow.context import PyMongoArrowContext
33
34
from pymongoarrow.types import _BsonArrowTypes, _atypes, ObjectIdType, Decimal128StringType
@@ -65,7 +66,8 @@ _builder_type_map = {
65
66
BSON_TYPE_UTF8: StringBuilder,
66
67
BSON_TYPE_BOOL: BoolBuilder,
67
68
BSON_TYPE_DOCUMENT: DocumentBuilder,
68
- BSON_TYPE_DECIMAL128: StringBuilder
69
+ BSON_TYPE_DECIMAL128: StringBuilder,
70
+ BSON_TYPE_ARRAY: ListBuilder,
69
71
}
70
72
71
73
_field_type_map = {
@@ -75,9 +77,26 @@ _field_type_map = {
75
77
BSON_TYPE_OID: ObjectIdType(),
76
78
BSON_TYPE_UTF8: string(),
77
79
BSON_TYPE_BOOL: bool_(),
78
- BSON_TYPE_DECIMAL128: Decimal128StringType()
80
+ BSON_TYPE_DECIMAL128: Decimal128StringType(),
79
81
}
80
82
83
+ cdef extract_field_dtype(bson_iter_t * doc_iter, bson_iter_t * child_iter, bson_type_t value_t, context):
84
+ """ Get the appropropriate data type for a specific field"""
85
+ if value_t in _field_type_map:
86
+ field_type = _field_type_map[value_t]
87
+ elif value_t == BSON_TYPE_ARRAY:
88
+ bson_iter_recurse(doc_iter, child_iter)
89
+ list_dtype = extract_array_dtype(child_iter, context)
90
+ field_type = list_(list_dtype)
91
+ elif value_t == BSON_TYPE_DOCUMENT:
92
+ bson_iter_recurse(doc_iter, child_iter)
93
+ field_type = extract_document_dtype(child_iter, context)
94
+ elif value_t == BSON_TYPE_DATE_TIME:
95
+ field_type = timestamp(' ms' , tz = context.tzinfo)
96
+ else :
97
+ raise PyMongoArrowError(' unknown value type {}' .format(value_t))
98
+ return field_type
99
+
81
100
82
101
cdef extract_document_dtype(bson_iter_t * doc_iter, context):
83
102
""" Get the appropropriate data type for a sub document"""
@@ -88,19 +107,21 @@ cdef extract_document_dtype(bson_iter_t * doc_iter, context):
88
107
while bson_iter_next(doc_iter):
89
108
key = bson_iter_key(doc_iter)
90
109
value_t = bson_iter_type(doc_iter)
91
- if value_t in _field_type_map:
92
- field_type = _field_type_map[value_t]
93
- elif value_t == BSON_TYPE_DOCUMENT:
94
- bson_iter_recurse(doc_iter, & child_iter)
95
- field_type = extract_document_dtype(& child_iter, context)
96
- elif value_t == BSON_TYPE_DATE_TIME:
97
- field_type = timestamp(' ms' , tz = context.tzinfo)
98
-
110
+ field_type = extract_field_dtype(doc_iter, & child_iter, value_t, context)
99
111
fields.append(field(key.decode(' utf-8' ), field_type))
100
112
return struct (fields)
101
113
114
+ cdef extract_array_dtype(bson_iter_t * doc_iter, context):
115
+ """ Get the appropropriate data type for a sub array"""
116
+ cdef const char * key
117
+ cdef bson_type_t value_t
118
+ cdef bson_iter_t child_iter
119
+ fields = []
120
+ first_item = bson_iter_next(doc_iter)
121
+ value_t = bson_iter_type(doc_iter)
122
+ return extract_field_dtype(doc_iter, & child_iter, value_t, context)
102
123
103
- def process_bson_stream (bson_stream , context ):
124
+ def process_bson_stream (bson_stream , context , arr_value_builder = None ):
104
125
""" Process a bson byte stream using a PyMongoArrowContext"""
105
126
cdef const uint8_t* docstream = < const uint8_t * > bson_stream
106
127
cdef size_t length = < size_t> PyBytes_Size(bson_stream)
@@ -110,6 +131,8 @@ def process_bson_stream(bson_stream, context):
110
131
cdef uint32_t str_len
111
132
cdef const uint8_t * doc_buf = NULL
112
133
cdef uint32_t doc_buf_len = 0 ;
134
+ cdef const uint8_t * arr_buf = NULL
135
+ cdef uint32_t arr_buf_len = 0 ;
113
136
cdef bson_decimal128_t dec128
114
137
cdef bson_type_t value_t
115
138
cdef const char * bson_str
@@ -131,12 +154,13 @@ def process_bson_stream(bson_stream, context):
131
154
t_string = _BsonArrowTypes.string
132
155
t_bool = _BsonArrowTypes.bool
133
156
t_document = _BsonArrowTypes.document
157
+ t_array = _BsonArrowTypes.array
158
+
134
159
135
160
# initialize count to current length of builders
136
161
for _, builder in builder_map.items():
137
162
count = len (builder)
138
163
break
139
-
140
164
try :
141
165
while True :
142
166
doc = bson_reader_read_safe(stream_reader)
@@ -146,7 +170,10 @@ def process_bson_stream(bson_stream, context):
146
170
raise InvalidBSON(" Could not read BSON document" )
147
171
while bson_iter_next(& doc_iter):
148
172
key = bson_iter_key(& doc_iter)
149
- builder = builder_map.get(key)
173
+ if arr_value_builder is not None :
174
+ builder = arr_value_builder
175
+ else :
176
+ builder = builder_map.get(key)
150
177
if builder is None :
151
178
builder = builder_map.get(key)
152
179
if builder is None and context.schema is None :
@@ -165,10 +192,15 @@ def process_bson_stream(bson_stream, context):
165
192
bson_iter_recurse(& doc_iter, & child_iter)
166
193
struct_dtype = extract_document_dtype(& child_iter, context)
167
194
builder = DocumentBuilder(struct_dtype, context.tzinfo)
195
+ elif builder_type == ListBuilder:
196
+ bson_iter_recurse(& doc_iter, & child_iter)
197
+ list_dtype = extract_array_dtype(& child_iter, context)
198
+ list_dtype = list_(list_dtype)
199
+ builder = ListBuilder(list_dtype, context.tzinfo, value_builder = arr_value_builder)
168
200
else :
169
201
builder = builder_type()
170
-
171
- builder_map[key] = builder
202
+ if arr_value_builder is None :
203
+ builder_map[key] = builder
172
204
for _ in range (count):
173
205
builder.append_null()
174
206
@@ -231,6 +263,14 @@ def process_bson_stream(bson_stream, context):
231
263
builder.append(< bytes> doc_buf[:doc_buf_len])
232
264
else :
233
265
builder.append_null()
266
+ elif ftype == t_array:
267
+ if value_t == BSON_TYPE_ARRAY:
268
+ bson_iter_array(& doc_iter, & doc_buf_len, & doc_buf)
269
+ if doc_buf_len <= 0 :
270
+ raise ValueError (" Subarray is invalid" )
271
+ builder.append(< bytes> doc_buf[:doc_buf_len])
272
+ else :
273
+ builder.append_null()
234
274
else :
235
275
raise PyMongoArrowError(' unknown ftype {}' .format(ftype))
236
276
count += 1
@@ -467,7 +507,11 @@ cdef class BoolBuilder(_ArrayBuilderBase):
467
507
cdef object get_field_builder(field, tzinfo):
468
508
""" "Find the appropriate field builder given a pyarrow field"""
469
509
cdef object field_builder
470
- field_type = field.type
510
+ cdef DataType field_type
511
+ if isinstance (field, DataType):
512
+ field_type = field
513
+ else :
514
+ field_type = field.type
471
515
if _atypes.is_int32(field_type):
472
516
field_builder = Int32Builder()
473
517
elif _atypes.is_int64(field_type):
@@ -484,6 +528,8 @@ cdef object get_field_builder(field, tzinfo):
484
528
field_builder = BoolBuilder()
485
529
elif _atypes.is_struct(field_type):
486
530
field_builder = DocumentBuilder(field_type, tzinfo)
531
+ elif _atypes.is_list(field_type):
532
+ field_builder = ListBuilder(field_type, tzinfo)
487
533
elif getattr (field_type, ' _type_marker' ) == _BsonArrowTypes.objectid:
488
534
field_builder = ObjectIdBuilder()
489
535
elif getattr (field_type, ' _type_marker' ) == _BsonArrowTypes.decimal128_str:
@@ -549,3 +595,55 @@ cdef class DocumentBuilder(_ArrayBuilderBase):
549
595
550
596
cdef shared_ptr[CStructBuilder] unwrap(self ):
551
597
return self .builder
598
+
599
+ cdef class ListBuilder(_ArrayBuilderBase):
600
+ type_marker = _BsonArrowTypes.array
601
+
602
+ cdef:
603
+ shared_ptr[CListBuilder] builder
604
+ _ArrayBuilderBase child_builder
605
+ object dtype
606
+ object context
607
+
608
+ def __cinit__ (self , DataType dtype , tzinfo = None , MemoryPool memory_pool = None , value_builder = None ):
609
+ cdef StringBuilder field_builder
610
+ cdef CMemoryPool* pool = maybe_unbox_memory_pool(memory_pool)
611
+ cdef shared_ptr[CArrayBuilder] grandchild_builder
612
+ self .dtype = dtype
613
+ if not _atypes.is_list(dtype):
614
+ raise ValueError (" dtype must be a list_()" )
615
+ self .context = context = PyMongoArrowContext(None , {})
616
+ self .context.tzinfo = tzinfo
617
+ field_builder = < StringBuilder> get_field_builder(self .dtype.value_type, tzinfo)
618
+ grandchild_builder = < shared_ptr[CArrayBuilder]> field_builder.builder
619
+ self .child_builder = field_builder
620
+ self .builder.reset(new CListBuilder(pool, grandchild_builder, pyarrow_unwrap_data_type(dtype)))
621
+
622
+
623
+ @property
624
+ def dtype (self ):
625
+ return self .dtype
626
+
627
+ cpdef append_null(self ):
628
+ self .builder.get().AppendNull()
629
+
630
+ def __len__ (self ):
631
+ return self .builder.get().length()
632
+
633
+ cpdef append(self , value):
634
+ if not isinstance (value, bytes):
635
+ value = bson.encode(value)
636
+ # Append an element to the array.
637
+ # arr_value_builder will be appended to by process_bson_stream.
638
+ self .builder.get().Append(True )
639
+ process_bson_stream(value, self .context, arr_value_builder = self .child_builder)
640
+
641
+
642
+ cpdef finish(self ):
643
+ cdef shared_ptr[CArray] out
644
+ with nogil:
645
+ self .builder.get().Finish(& out)
646
+ return pyarrow_wrap_array(out)
647
+
648
+ cdef shared_ptr[CListBuilder] unwrap(self ):
649
+ return self .builder
0 commit comments