Skip to content

Commit 139d8ed

Browse files
authored
ARROW-55 Allow specifying a custom type for converting ExtensionScalar to python object (#120)
1 parent b9c5164 commit 139d8ed

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

bindings/python/pymongoarrow/types.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from bson import Decimal128, Int64, ObjectId
2121
from pyarrow import DataType as _ArrowDataType
2222
from pyarrow import (
23+
ExtensionScalar,
2324
PyExtensionType,
2425
binary,
2526
bool_,
@@ -48,6 +49,11 @@ class _BsonArrowTypes(enum.Enum):
4849
# for details.
4950

5051

52+
class ObjectIdScalar(ExtensionScalar):
    """Extension scalar that converts to a ``bson.ObjectId`` in Python."""

    def as_py(self, **kwargs):
        """Return this scalar as an ``ObjectId``, or ``None`` if it is null.

        Any keyword arguments are forwarded unchanged to the storage
        scalar's ``as_py`` so that callers which pass conversion options
        (as newer pyarrow releases do) keep working.
        """
        storage = self.value
        # ``ExtensionScalar.value`` is None for a null entry; calling
        # ``.as_py()`` on it would raise AttributeError, so mirror the base
        # class and surface nulls as None.
        if storage is None:
            return None
        return ObjectId(storage.as_py(**kwargs))
55+
56+
5157
class ObjectIdType(PyExtensionType):
5258
_type_marker = _BsonArrowTypes.objectid
5359

@@ -57,6 +63,14 @@ def __init__(self):
5763
def __reduce__(self):
5864
return ObjectIdType, ()
5965

66+
def __arrow_ext_scalar_class__(self):
    """Return the scalar class pyarrow should use for values of this type."""
    return ObjectIdScalar
68+
69+
70+
class Decimal128Scalar(ExtensionScalar):
    """Extension scalar that converts to a ``bson.Decimal128`` in Python."""

    def as_py(self, **kwargs):
        """Return this scalar as a ``Decimal128``, or ``None`` if it is null.

        Any keyword arguments are forwarded unchanged to the storage
        scalar's ``as_py`` so that callers which pass conversion options
        (as newer pyarrow releases do) keep working.
        """
        storage = self.value
        # ``ExtensionScalar.value`` is None for a null entry; calling
        # ``.as_py()`` on it would raise AttributeError, so mirror the base
        # class and surface nulls as None.
        if storage is None:
            return None
        return Decimal128(storage.as_py(**kwargs))
73+
6074

6175
class Decimal128StringType(PyExtensionType):
6276
_type_marker = _BsonArrowTypes.decimal128_str
@@ -67,6 +81,9 @@ def __init__(self):
6781
def __reduce__(self):
6882
return Decimal128StringType, ()
6983

84+
def __arrow_ext_scalar_class__(self):
    """Return the scalar class pyarrow should use for values of this type."""
    return Decimal128Scalar
86+
7087

7188
# Internal Type Handling.
7289

bindings/python/test/test_schema.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,38 @@
1414
from datetime import datetime
1515
from unittest import TestCase
1616

17-
from bson import Int64
18-
from pyarrow import float64, int64, timestamp
17+
from bson import Decimal128, Int64, ObjectId
18+
from pyarrow import Table, float64, int64
19+
from pyarrow import schema as ArrowSchema
20+
from pyarrow import timestamp
1921
from pymongoarrow.schema import Schema
22+
from pymongoarrow.types import _TYPE_NORMALIZER_FACTORY
2023

2124

2225
class TestSchema(TestCase):
26+
def test_as_py(self):
    """Check ``Table.to_pylist`` output for every type in _TYPE_NORMALIZER_FACTORY."""
    oid = ObjectId()
    dec = Decimal128("1.000")

    # Values handed to from_pydict. Some classes want special things in
    # their constructors, and the extension types must be supplied as
    # strings or binary rather than as the bson objects themselves.
    raw_inputs = {
        datetime: datetime(1, 1, 1),
        str: "hell0",
        bool: True,
        Decimal128: str(dec),
        ObjectId: oid.binary,
    }
    # For the extension types the value read back is the original bson
    # object, not the raw string/binary that was written.
    expected_overrides = {Decimal128: dec, ObjectId: oid}

    for py_type, make_arrow_type in _TYPE_NORMALIZER_FACTORY.items():
        # Four-element column holding either the prepared value or 1.
        cells = [raw_inputs.get(py_type, 1) for _ in range(4)]
        arrow_schema = ArrowSchema([("value", make_arrow_type(True))])
        table = Table.from_pydict({"value": cells}, arrow_schema)

        expected_rows = [
            {"value": expected_overrides.get(py_type, cell)} for cell in cells
        ]
        self.assertEqual(table.to_pylist(), expected_rows)
48+
2349
def test_initialization(self):
2450
dict_schema = Schema({"field1": int, "field2": datetime, "field3": float})
2551
self.assertEqual(

0 commit comments

Comments
 (0)