diff --git a/spec/API_specification/dataframe_api/__init__.py b/spec/API_specification/dataframe_api/__init__.py index 24df3ed4..74dc4346 100644 --- a/spec/API_specification/dataframe_api/__init__.py +++ b/spec/API_specification/dataframe_api/__init__.py @@ -11,35 +11,36 @@ from .dtypes import * if TYPE_CHECKING: - from .typing import DType + from .typing import DType, Scalar __all__ = [ - "__dataframe_api_version__", - "DataFrame", - "Column", - "column_from_sequence", - "column_from_1d_array", - "concat", - "dataframe_from_columns", - "dataframe_from_2d_array", - "is_null", - "null", - "Int64", - "Int32", - "Int16", - "Int8", - "UInt64", - "UInt32", - "UInt16", - "UInt8", - "Float64", - "Float32", "Bool", + "Column", + "DataFrame", "Date", "Datetime", "Duration", + "Float32", + "Float64", + "Int16", + "Int32", + "Int64", + "Int8", "String", + "UInt16", + "UInt32", + "UInt64", + "UInt8", + "__dataframe_api_version__", + "column_from_1d_array", + "column_from_sequence", + "concat", + "dataframe_from_2d_array", + "dataframe_from_columns", + "date", "is_dtype", + "is_null", + "null", ] @@ -234,3 +235,21 @@ def is_dtype(dtype: DType, kind: str | tuple[str, ...]) -> bool: ------- bool """ + +def date(year: int, month: int, day: int) -> Scalar: + """ + Create date object which can be used for filtering. + + The full 32-bit signed integer range of days since epoch should be supported (between -5877641-06-23 and 5881580-07-11 inclusive). + + Examples + -------- + >>> df: DataFrame + >>> namespace = df.__dataframe_namespace__() + >>> mask = ( + ... (df.get_column_by_name('date') >= namespace.date(2020, 1, 1)) + ... & (df.get_column_by_name('date') < namespace.date(2021, 1, 1)) + ... ) + >>> df.filter(mask) + """ + diff --git a/spec/API_specification/dataframe_api/typing.py b/spec/API_specification/dataframe_api/typing.py index 2e011a02..969faa0f 100644 --- a/spec/API_specification/dataframe_api/typing.py +++ b/spec/API_specification/dataframe_api/typing.py @@ -144,6 +144,9 @@ def is_null(value: object, /) -> bool: def is_dtype(dtype: Any, kind: str | tuple[str, ...]) -> bool: ... + @staticmethod + def date(year: int, month: int, day: int) -> Scalar: + ... class SupportsDataFrameAPI(Protocol): def __dataframe_consortium_standard__( diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index cdca0806..0ed69b67 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -58,8 +58,8 @@ def query( == result.get_column_by_name("s_nationkey") ) & (result.get_column_by_name("r_name") == "ASIA") - & (result.get_column_by_name("o_orderdate") >= namespace.date(1994, 1, 1)) # type: ignore - & (result.get_column_by_name("o_orderdate") < namespace.date(1995, 1, 1)) # type: ignore + & (result.get_column_by_name("o_orderdate") >= namespace.date(1994, 1, 1)) + & (result.get_column_by_name("o_orderdate") < namespace.date(1995, 1, 1)) ) result = result.filter(mask)