diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py
index e207dac71752e..593e42f7ed749 100644
--- a/pandas/core/dtypes/common.py
+++ b/pandas/core/dtypes/common.py
@@ -1413,6 +1413,33 @@ def is_extension_type(arr) -> bool:
     return False
 
 
+def is_1d_only_ea_obj(obj: Any) -> bool:
+    """
+    ExtensionArray that does not support 2D, or more specifically that does
+    not use HybridBlock.
+    """
+    from pandas.core.arrays import (
+        DatetimeArray,
+        ExtensionArray,
+        TimedeltaArray,
+    )
+
+    return isinstance(obj, ExtensionArray) and not isinstance(
+        obj, (DatetimeArray, TimedeltaArray)
+    )
+
+
+def is_1d_only_ea_dtype(dtype: Optional[DtypeObj]) -> bool:
+    """
+    Analogue to is_extension_array_dtype but excluding DatetimeTZDtype.
+    """
+    # Note: if other EA dtypes are ever held in HybridBlock, exclude those
+    # here too.
+    # NB: need to check DatetimeTZDtype and not is_datetime64tz_dtype
+    # to exclude ArrowTimestampUSDtype
+    return isinstance(dtype, ExtensionDtype) and not isinstance(dtype, DatetimeTZDtype)
+
+
 def is_extension_array_dtype(arr_or_dtype) -> bool:
     """
     Check if an object is a pandas extension array type.
diff --git a/pandas/core/dtypes/concat.py b/pandas/core/dtypes/concat.py
index cfadb3e9f45c5..b0d00775bbed1 100644
--- a/pandas/core/dtypes/concat.py
+++ b/pandas/core/dtypes/concat.py
@@ -113,11 +113,15 @@ def is_nonempty(x) -> bool:
         to_concat = non_empties
 
     kinds = {obj.dtype.kind for obj in to_concat}
+    contains_datetime = any(kind in ["m", "M"] for kind in kinds)
     all_empty = not len(non_empties)
     single_dtype = len({x.dtype for x in to_concat}) == 1
     any_ea = any(isinstance(x.dtype, ExtensionDtype) for x in to_concat)
 
+    if contains_datetime:
+        return _concat_datetime(to_concat, axis=axis)
+
     if any_ea:
         # we ignore axis here, as internally concatting with EAs is always
         # for axis=0
@@ -131,9 +135,6 @@ def is_nonempty(x) -> bool:
         else:
             return np.concatenate(to_concat)
 
-    elif any(kind in ["m", "M"] for kind in kinds):
-        return _concat_datetime(to_concat, axis=axis)
-
     elif all_empty:
         # we have all empties, but may need to coerce the result dtype to
         # object if we have non-numeric type operands (numpy would otherwise
@@ -349,14 +350,5 @@ def _concat_datetime(to_concat, axis=0):
         # in Timestamp/Timedelta
         return _concatenate_2d([x.astype(object) for x in to_concat], axis=axis)
 
-    if axis == 1:
-        # TODO(EA2D): kludge not necessary with 2D EAs
-        to_concat = [x.reshape(1, -1) if x.ndim == 1 else x for x in to_concat]
-
     result = type(to_concat[0])._concat_same_type(to_concat, axis=axis)
-
-    if result.ndim == 2 and isinstance(result.dtype, ExtensionDtype):
-        # TODO(EA2D): kludge not necessary with 2D EAs
-        assert result.shape[0] == 1
-        result = result[0]
     return result
diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index 4a7ed2bfc18df..7f970a72cb12c 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -98,6 +98,7 @@
 from pandas.core.dtypes.common import (
     ensure_platform_int,
     infer_dtype_from_object,
+    is_1d_only_ea_dtype,
     is_bool_dtype,
     is_dataclass,
     is_datetime64_any_dtype,
@@ -845,7 +846,9 @@ def _can_fast_transpose(self) -> bool:
         if len(blocks) != 1:
             return False
 
-        return not self._mgr.any_extension_types
+        dtype = blocks[0].dtype
+        # TODO(EA2D) special case would be unnecessary with 2D EAs
+        return not is_1d_only_ea_dtype(dtype)
 
     # ----------------------------------------------------------------------
     # Rendering Methods
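# --- Illustrative sketch (editor's note, not part of the patch) ---
# How the new helpers added in pandas/core/dtypes/common.py above are expected
# to classify dtypes and arrays on this branch; behavior inferred from the
# definitions in the diff.
import numpy as np
import pandas as pd
from pandas.core.dtypes.common import is_1d_only_ea_dtype, is_1d_only_ea_obj

# Categorical is an ExtensionDtype with no 2D ("hybrid") block, so it is 1d-only.
assert is_1d_only_ea_dtype(pd.CategoricalDtype())
# tz-aware datetimes are ExtensionDtype, but are excluded because their values
# can be held in 2D-capable blocks.
assert not is_1d_only_ea_dtype(pd.DatetimeTZDtype(tz="UTC"))
# Plain numpy dtypes are not extension dtypes at all.
assert not is_1d_only_ea_dtype(np.dtype("int64"))

# The object-level twin checks the array itself.
assert is_1d_only_ea_obj(pd.Categorical(["a", "b"]))
assert not is_1d_only_ea_obj(pd.array(pd.date_range("2021", periods=2, tz="UTC")))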
diff --git a/pandas/core/internals/api.py b/pandas/core/internals/api.py
index d6b76510c68ab..2f8686fd38929 100644
--- a/pandas/core/internals/api.py
+++ b/pandas/core/internals/api.py
@@ -6,7 +6,7 @@
 2) Use only functions exposed here (or in core.internals)
 
 """
-from typing import Optional
+from __future__ import annotations
 
 import numpy as np
 
@@ -23,6 +23,7 @@
     Block,
     DatetimeTZBlock,
     check_ndim,
+    ensure_block_shape,
     extract_pandas_array,
     get_block_type,
     maybe_coerce_values,
@@ -30,7 +31,7 @@
 
 
 def make_block(
-    values, placement, klass=None, ndim=None, dtype: Optional[Dtype] = None
+    values, placement, klass=None, ndim=None, dtype: Dtype | None = None
 ) -> Block:
     """
     This is a pseudo-public analogue to blocks.new_block.
@@ -48,6 +49,7 @@ def make_block(
 
     values, dtype = extract_pandas_array(values, dtype, ndim)
 
+    needs_reshape = False
    if klass is None:
         dtype = dtype or values.dtype
         klass = get_block_type(values, dtype)
@@ -55,17 +57,21 @@ def make_block(
     elif klass is DatetimeTZBlock and not is_datetime64tz_dtype(values.dtype):
         # pyarrow calls get here
         values = DatetimeArray._simple_new(values, dtype=dtype)
+        needs_reshape = True
 
     if not isinstance(placement, BlockPlacement):
         placement = BlockPlacement(placement)
 
     ndim = maybe_infer_ndim(values, placement, ndim)
+
+    if needs_reshape:
+        values = ensure_block_shape(values, ndim)
+
     check_ndim(values, placement, ndim)
     values = maybe_coerce_values(values)
     return klass(values, ndim=ndim, placement=placement)
 
 
-def maybe_infer_ndim(values, placement: BlockPlacement, ndim: Optional[int]) -> int:
+def maybe_infer_ndim(values, placement: BlockPlacement, ndim: int | None) -> int:
     """
     If `ndim` is not provided, infer it from placment and values.
     """
diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py
index 603cc6a6ff1f2..4276aadd8edd6 100644
--- a/pandas/core/internals/blocks.py
+++ b/pandas/core/internals/blocks.py
@@ -42,6 +42,8 @@
     soft_convert_objects,
 )
 from pandas.core.dtypes.common import (
+    is_1d_only_ea_dtype,
+    is_1d_only_ea_obj,
     is_categorical_dtype,
     is_dtype_equal,
     is_extension_array_dtype,
@@ -224,7 +226,6 @@ def get_values(self, dtype: DtypeObj | None = None) -> np.ndarray:
         # expected "ndarray")
         return self.values  # type: ignore[return-value]
 
-    @final
     def get_block_values_for_json(self) -> np.ndarray:
         """
         This is used in the JSON C code.
@@ -415,7 +416,11 @@ def _split_op_result(self, result) -> list[Block]:
             # if we get a 2D ExtensionArray, we need to split it into 1D pieces
             nbs = []
             for i, loc in enumerate(self._mgr_locs):
-                vals = result[i]
+                if not is_1d_only_ea_obj(result):
+                    vals = result[i : i + 1]
+                else:
+                    vals = result[i]
+
                 block = self.make_block(values=vals, placement=loc)
                 nbs.append(block)
             return nbs
@@ -1670,7 +1675,7 @@ class NumericBlock(NumpyBlock):
     is_numeric = True
 
 
-class NDArrayBackedExtensionBlock(EABackedBlock):
+class NDArrayBackedExtensionBlock(libinternals.Block, EABackedBlock):
     """
     Block backed by an NDArrayBackedExtensionArray
     """
@@ -1683,11 +1688,6 @@ def is_view(self) -> bool:
         # check the ndarray values of the DatetimeIndex values
         return self.values._ndarray.base is not None
 
-    def iget(self, key):
-        # GH#31649 we need to wrap scalars in Timestamp/Timedelta
-        # TODO(EA2D): this can be removed if we ever have 2D EA
-        return self.values.reshape(self.shape)[key]
-
     def setitem(self, indexer, value):
         if not self._can_hold_element(value):
             # TODO: general case needs casting logic.
@@ -1707,24 +1707,21 @@ def putmask(self, mask, new) -> list[Block]:
         if not self._can_hold_element(new):
             return self.astype(object).putmask(mask, new)
 
-        # TODO(EA2D): reshape unnecessary with 2D EAs
-        arr = self.values.reshape(self.shape)
+        arr = self.values
         arr.T.putmask(mask, new)
         return [self]
 
     def where(self, other, cond, errors="raise") -> list[Block]:
         # TODO(EA2D): reshape unnecessary with 2D EAs
-        arr = self.values.reshape(self.shape)
+        arr = self.values
 
         cond = extract_bool_array(cond)
 
         try:
             res_values = arr.T.where(cond, other).T
         except (ValueError, TypeError):
-            return super().where(other, cond, errors=errors)
+            return Block.where(self, other, cond, errors=errors)
 
-        # TODO(EA2D): reshape not needed with 2D EAs
-        res_values = res_values.reshape(self.values.shape)
         nb = self.make_block_same_class(res_values)
         return [nb]
@@ -1748,15 +1745,13 @@ def diff(self, n: int, axis: int = 0) -> list[Block]:
         The arguments here are mimicking shift so they are called correctly
         by apply.
         """
-        # TODO(EA2D): reshape not necessary with 2D EAs
-        values = self.values.reshape(self.shape)
+        values = self.values
         new_values = values - values.shift(n, axis=axis)
         return [self.make_block(new_values)]
 
     def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Block]:
-        # TODO(EA2D) this is unnecessary if these blocks are backed by 2D EAs
-        values = self.values.reshape(self.shape)
+        values = self.values
         new_values = values.shift(periods, fill_value=fill_value, axis=axis)
         return [self.make_block_same_class(new_values)]
@@ -1776,31 +1771,27 @@ def fillna(
         return [self.make_block_same_class(values=new_values)]
 
 
-class DatetimeLikeBlock(libinternals.Block, NDArrayBackedExtensionBlock):
+class DatetimeLikeBlock(NDArrayBackedExtensionBlock):
     """Block for datetime64[ns], timedelta64[ns]."""
 
     __slots__ = ()
     is_numeric = False
     values: DatetimeArray | TimedeltaArray
 
+    def get_block_values_for_json(self):
+        # Not necessary to override, but helps perf
+        return self.values._ndarray
 
-class DatetimeTZBlock(ExtensionBlock, NDArrayBackedExtensionBlock):
+
+class DatetimeTZBlock(DatetimeLikeBlock):
     """ implement a datetime64 block with a tz attribute """
 
     values: DatetimeArray
 
     __slots__ = ()
     is_extension = True
-    is_numeric = False
-
-    diff = NDArrayBackedExtensionBlock.diff
-    where = NDArrayBackedExtensionBlock.where
-    putmask = NDArrayBackedExtensionBlock.putmask
-    fillna = NDArrayBackedExtensionBlock.fillna
-
-    get_values = NDArrayBackedExtensionBlock.get_values
-
-    is_view = NDArrayBackedExtensionBlock.is_view
+    _validate_ndim = True
+    _can_consolidate = False
 
 
 class ObjectBlock(NumpyBlock):
@@ -1967,7 +1958,7 @@ def check_ndim(values, placement: BlockPlacement, ndim: int):
             f"values.ndim > ndim [{values.ndim} > {ndim}]"
         )
 
-    elif isinstance(values.dtype, np.dtype):
+    elif not is_1d_only_ea_dtype(values.dtype):
         # TODO(EA2D): special case not needed with 2D EAs
         if values.ndim != ndim:
             raise ValueError(
@@ -1981,7 +1972,7 @@
         )
     elif ndim == 2 and len(placement) != 1:
         # TODO(EA2D): special case unnecessary with 2D EAs
-        raise AssertionError("block.size != values.size")
+        raise ValueError("need to split")
 
 
 def extract_pandas_array(
@@ -2026,8 +2017,9 @@ def ensure_block_shape(values: ArrayLike, ndim: int = 1) -> ArrayLike:
     """
     Reshape if possible to have values.ndim == ndim.
""" + if values.ndim < ndim: - if not is_extension_array_dtype(values.dtype): + if not is_1d_only_ea_dtype(values.dtype): # TODO(EA2D): https://github.com/pandas-dev/pandas/issues/23023 # block.shape is incorrect for "2D" ExtensionArrays # We can't, and don't need to, reshape. diff --git a/pandas/core/internals/concat.py b/pandas/core/internals/concat.py index 0b0013eeb7147..51a381a1b7f4f 100644 --- a/pandas/core/internals/concat.py +++ b/pandas/core/internals/concat.py @@ -5,6 +5,7 @@ from typing import ( TYPE_CHECKING, Sequence, + cast, ) import numpy as np @@ -23,6 +24,8 @@ find_common_type, ) from pandas.core.dtypes.common import ( + is_1d_only_ea_dtype, + is_1d_only_ea_obj, is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, @@ -210,8 +213,8 @@ def concatenate_managers( values = np.concatenate(vals, axis=blk.ndim - 1) else: # TODO(EA2D): special-casing not needed with 2D EAs - values = concat_compat(vals) - values = ensure_block_shape(values, ndim=2) + values = concat_compat(vals, axis=1) + values = ensure_block_shape(values, blk.ndim) values = ensure_wrapped_if_datetimelike(values) @@ -412,13 +415,16 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike: fill_value = None if is_datetime64tz_dtype(empty_dtype): - # TODO(EA2D): special case unneeded with 2D EAs - i8values = np.full(self.shape[1], fill_value.value) + i8values = np.full(self.shape, fill_value.value) return DatetimeArray(i8values, dtype=empty_dtype) + elif is_extension_array_dtype(blk_dtype): pass - elif isinstance(empty_dtype, ExtensionDtype): + + elif is_1d_only_ea_dtype(empty_dtype): + empty_dtype = cast(ExtensionDtype, empty_dtype) cls = empty_dtype.construct_array_type() + missing_arr = cls._from_sequence([], dtype=empty_dtype) ncols, nrows = self.shape assert ncols == 1, ncols @@ -429,6 +435,7 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike: else: # NB: we should never get here with empty_dtype integer or bool; # if we did, the missing_arr.fill would cast to gibberish + empty_dtype = cast(np.dtype, empty_dtype) missing_arr = np.empty(self.shape, dtype=empty_dtype) missing_arr.fill(fill_value) @@ -493,15 +500,17 @@ def _concatenate_join_units( concat_values = concat_values.copy() else: concat_values = concat_values.copy() - elif any(isinstance(t, ExtensionArray) and t.ndim == 1 for t in to_concat): + + elif any(is_1d_only_ea_obj(t) for t in to_concat): + # TODO(EA2D): special case not needed if all EAs used HybridBlocks + # NB: we are still assuming here that Hybrid blocks have shape (1, N) # concatting with at least one EA means we are concatting a single column # the non-EA values are 2D arrays with shape (1, n) + # error: Invalid index type "Tuple[int, slice]" for # "Union[ExtensionArray, ndarray]"; expected type "Union[int, slice, ndarray]" to_concat = [ - t - if (isinstance(t, ExtensionArray) and t.ndim == 1) - else t[0, :] # type: ignore[index] + t if is_1d_only_ea_obj(t) else t[0, :] # type: ignore[index] for t in to_concat ] concat_values = concat_compat(to_concat, axis=0, ea_compat_axis=True) diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index 2960fb292818a..83ecdbce5fa80 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -32,6 +32,7 @@ maybe_upcast, ) from pandas.core.dtypes.common import ( + is_1d_only_ea_dtype, is_datetime64tz_dtype, is_dtype_equal, is_extension_array_dtype, @@ -55,7 +56,8 @@ ) from pandas.core.arrays import ( 
diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py
index 2960fb292818a..83ecdbce5fa80 100644
--- a/pandas/core/internals/construction.py
+++ b/pandas/core/internals/construction.py
@@ -32,6 +32,7 @@
     maybe_upcast,
 )
 from pandas.core.dtypes.common import (
+    is_1d_only_ea_dtype,
     is_datetime64tz_dtype,
     is_dtype_equal,
     is_extension_array_dtype,
@@ -55,7 +56,8 @@
 )
 from pandas.core.arrays import (
     Categorical,
-    DatetimeArray,
+    ExtensionArray,
+    TimedeltaArray,
 )
 from pandas.core.construction import (
     extract_array,
@@ -259,7 +261,8 @@ def ndarray_to_mgr(
     if not len(values) and columns is not None and len(columns):
         values = np.empty((0, 1), dtype=object)
 
-    if is_extension_array_dtype(values) or isinstance(dtype, ExtensionDtype):
+    vdtype = getattr(values, "dtype", None)
+    if is_1d_only_ea_dtype(vdtype) or isinstance(dtype, ExtensionDtype):
         # GH#19157
 
         if isinstance(values, np.ndarray) and values.ndim > 1:
@@ -274,9 +277,18 @@ def ndarray_to_mgr(
 
         return arrays_to_mgr(values, columns, index, columns, dtype=dtype, typ=typ)
 
-    # by definition an array here
-    # the dtypes will be coerced to a single dtype
-    values = _prep_ndarray(values, copy=copy)
+    if is_extension_array_dtype(vdtype) and not is_1d_only_ea_dtype(vdtype):
+        # i.e. Datetime64TZ
+        values = extract_array(values, extract_numpy=True)
+        if copy:
+            values = values.copy()
+        if values.ndim == 1:
+            values = values.reshape(-1, 1)
+
+    else:
+        # by definition an array here
+        # the dtypes will be coerced to a single dtype
+        values = _prep_ndarray(values, copy=copy)
 
     if dtype is not None and not is_dtype_equal(values.dtype, dtype):
         shape = values.shape
@@ -320,7 +332,6 @@ def ndarray_to_mgr(
             dvals_list = [ensure_block_shape(dval, 2) for dval in dvals_list]
 
             # TODO: What about re-joining object columns?
-            dvals_list = [maybe_squeeze_dt64tz(x) for x in dvals_list]
             block_values = [
                 new_block(dvals_list[n], placement=n, ndim=2)
                 for n in range(len(dvals_list))
@@ -328,12 +339,10 @@ def ndarray_to_mgr(
 
         else:
             datelike_vals = maybe_infer_to_datetimelike(values)
-            datelike_vals = maybe_squeeze_dt64tz(datelike_vals)
             nb = new_block(datelike_vals, placement=slice(len(columns)), ndim=2)
             block_values = [nb]
     else:
-        new_values = maybe_squeeze_dt64tz(values)
-        nb = new_block(new_values, placement=slice(len(columns)), ndim=2)
+        nb = new_block(values, placement=slice(len(columns)), ndim=2)
         block_values = [nb]
 
     if len(columns) == 0:
@@ -360,20 +369,6 @@ def _check_values_indices_shape_match(
         raise ValueError(f"Shape of passed values is {passed}, indices imply {implied}")
 
 
-def maybe_squeeze_dt64tz(dta: ArrayLike) -> ArrayLike:
-    """
-    If we have a tzaware DatetimeArray with shape (1, N), squeeze to (N,)
-    """
-    # TODO(EA2D): kludge not needed with 2D EAs
-    if isinstance(dta, DatetimeArray) and dta.ndim == 2 and dta.tz is not None:
-        assert dta.shape[0] == 1
-        # error: Incompatible types in assignment (expression has type
-        # "Union[DatetimeLikeArrayMixin, Union[Any, NaTType]]", variable has
-        # type "Union[ExtensionArray, ndarray]")
-        dta = dta[0]  # type: ignore[assignment]
-    return dta
-
-
 def dict_to_mgr(
     data: dict,
     index,
@@ -396,7 +391,6 @@ def dict_to_mgr(
 
         arrays = Series(data, index=columns, dtype=object)
         data_names = arrays.index
-        missing = arrays.isna()
 
         if index is None:
             # GH10856
@@ -481,13 +475,23 @@ def treat_as_nested(data) -> bool:
     """
     Check if we should use nested_data_to_arrays.
     """
-    return len(data) > 0 and is_list_like(data[0]) and getattr(data[0], "ndim", 1) == 1
+    return (
+        len(data) > 0
+        and is_list_like(data[0])
+        and getattr(data[0], "ndim", 1) == 1
+        and not (isinstance(data, ExtensionArray) and data.ndim == 2)
+    )
 
 
 # ---------------------------------------------------------------------
 
 
 def _prep_ndarray(values, copy: bool = True) -> np.ndarray:
+    if isinstance(values, TimedeltaArray):
+        # On older numpy, np.asarray below apparently does not call __array__,
+        # so nanoseconds get dropped.
+        values = values._ndarray
+
     if not isinstance(values, (np.ndarray, ABCSeries, Index)):
         if len(values) == 0:
             return np.empty((0, 0), dtype=object)
diff --git a/pandas/core/internals/managers.py b/pandas/core/internals/managers.py
index 97d605e2fa2d1..5db6592ba77f9 100644
--- a/pandas/core/internals/managers.py
+++ b/pandas/core/internals/managers.py
@@ -9,6 +9,7 @@
     Hashable,
     Sequence,
     TypeVar,
+    cast,
 )
 import warnings
 
@@ -32,6 +33,7 @@
 from pandas.core.dtypes.cast import infer_dtype_from_scalar
 from pandas.core.dtypes.common import (
     ensure_platform_int,
+    is_1d_only_ea_dtype,
     is_dtype_equal,
     is_extension_array_dtype,
     is_list_like,
@@ -47,6 +49,7 @@
 )
 
 import pandas.core.algorithms as algos
+from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
 from pandas.core.arrays.sparse import SparseDtype
 from pandas.core.construction import (
     ensure_wrapped_if_datetimelike,
@@ -1048,6 +1051,19 @@ def __init__(
                     f"Number of Block dimensions ({block.ndim}) must equal "
                     f"number of axes ({self.ndim})"
                 )
+            if isinstance(block, DatetimeTZBlock) and block.values.ndim == 1:
+                # TODO: remove once fastparquet no longer needs this
+                # error: Incompatible types in assignment (expression has type
+                # "Union[ExtensionArray, ndarray]", variable has type
+                # "DatetimeArray")
+                block.values = ensure_block_shape(  # type: ignore[assignment]
+                    block.values, self.ndim
+                )
+                try:
+                    block._cache.clear()
+                except AttributeError:
+                    # _cache not initialized
+                    pass
 
         self._verify_integrity()
 
@@ -1149,7 +1165,8 @@ def iset(self, loc: int | slice | np.ndarray, value: ArrayLike):
             self._rebuild_blknos_and_blklocs()
 
         # Note: we exclude DTA/TDA here
-        value_is_extension_type = is_extension_array_dtype(value)
+        vdtype = getattr(value, "dtype", None)
+        value_is_extension_type = is_1d_only_ea_dtype(vdtype)
 
         # categorical/sparse/datetimetz
         if value_is_extension_type:
@@ -1780,7 +1797,12 @@ def _form_blocks(
 
     if len(items_dict["DatetimeTZBlock"]):
         dttz_blocks = [
-            new_block(array, klass=DatetimeTZBlock, placement=i, ndim=2)
+            new_block(
+                ensure_block_shape(extract_array(array), 2),
+                klass=DatetimeTZBlock,
+                placement=i,
+                ndim=2,
+            )
             for i, array in items_dict["DatetimeTZBlock"]
         ]
         blocks.extend(dttz_blocks)
@@ -1917,11 +1939,19 @@ def _merge_blocks(
         # TODO: optimization potential in case all mgrs contain slices and
         # combination of those slices is a slice, too.
         new_mgr_locs = np.concatenate([b.mgr_locs.as_array for b in blocks])
-        # error: List comprehension has incompatible type List[Union[ndarray,
-        # ExtensionArray]]; expected List[Union[complex, generic, Sequence[Union[int,
-        # float, complex, str, bytes, generic]], Sequence[Sequence[Any]],
-        # _SupportsArray]]
-        new_values = np.vstack([b.values for b in blocks])  # type: ignore[misc]
+
+        new_values: ArrayLike
+
+        if isinstance(blocks[0].dtype, np.dtype):
+            # error: List comprehension has incompatible type List[Union[ndarray,
+            # ExtensionArray]]; expected List[Union[complex, generic,
+            # Sequence[Union[int, float, complex, str, bytes, generic]],
+            # Sequence[Sequence[Any]], SupportsArray]]
+            new_values = np.vstack([b.values for b in blocks])  # type: ignore[misc]
+        else:
+            bvals = [blk.values for blk in blocks]
+            bvals2 = cast(Sequence[NDArrayBackedExtensionArray], bvals)
+            new_values = bvals2[0]._concat_same_type(bvals2, axis=0)
 
         argsort = np.argsort(new_mgr_locs)
         new_values = new_values[argsort]
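# --- Illustrative sketch (editor's note, not part of the patch) ---
# Expected net effect of the BlockManager changes above: a tz-aware column is
# held in a single 2D DatetimeTZBlock. Assumes the default BlockManager (not
# ArrayManager) on this branch; internal layout is not a public guarantee.
import pandas as pd

df = pd.DataFrame({"ts": pd.date_range("2021-01-01", periods=3, tz="US/Pacific")})
blk = df._mgr.blocks[0]
assert type(blk).__name__ == "DatetimeTZBlock"
assert blk.values.ndim == 2  # previously these blocks carried 1D DatetimeArrays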
diff --git a/pandas/core/reshape/reshape.py b/pandas/core/reshape/reshape.py
index d889e84cb9045..1a4d8dbe2885e 100644
--- a/pandas/core/reshape/reshape.py
+++ b/pandas/core/reshape/reshape.py
@@ -16,6 +16,7 @@
 from pandas.core.dtypes.cast import maybe_promote
 from pandas.core.dtypes.common import (
     ensure_platform_int,
+    is_1d_only_ea_dtype,
     is_bool_dtype,
     is_extension_array_dtype,
     is_integer,
@@ -438,7 +439,7 @@ def unstack(obj, level, fill_value=None):
             f"index must be a MultiIndex to unstack, {type(obj.index)} was passed"
         )
     else:
-        if is_extension_array_dtype(obj.dtype):
+        if is_1d_only_ea_dtype(obj.dtype):
             return _unstack_extension_series(obj, level, fill_value)
         unstacker = _Unstacker(
             obj.index, level=level, constructor=obj._constructor_expanddim
diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py
index 8e6c330475e68..b9c1113e7f441 100644
--- a/pandas/tests/arrays/test_datetimes.py
+++ b/pandas/tests/arrays/test_datetimes.py
@@ -226,6 +226,16 @@ def test_fillna_2d(self):
         res4 = dta2.fillna(method="backfill")
         tm.assert_extension_array_equal(res4, expected2)
 
+        # test the DataFrame method while we're here
+        df = pd.DataFrame(dta)
+        res = df.fillna(method="pad")
+        expected = pd.DataFrame(expected1)
+        tm.assert_frame_equal(res, expected)
+
+        res = df.fillna(method="backfill")
+        expected = pd.DataFrame(expected2)
+        tm.assert_frame_equal(res, expected)
+
     def test_array_interface_tz(self):
         tz = "US/Central"
         data = DatetimeArray(pd.date_range("2017", periods=2, tz=tz))
diff --git a/pandas/tests/extension/base/constructors.py b/pandas/tests/extension/base/constructors.py
index 56c3f8216f033..6e4ed7b77cad8 100644
--- a/pandas/tests/extension/base/constructors.py
+++ b/pandas/tests/extension/base/constructors.py
@@ -3,7 +3,10 @@
 
 import pandas as pd
 from pandas.api.extensions import ExtensionArray
-from pandas.core.internals import ExtensionBlock
+from pandas.core.internals.blocks import (
+    DatetimeTZBlock,
+    ExtensionBlock,
+)
 from pandas.tests.extension.base.base import BaseExtensionTests
 
 
@@ -26,14 +29,14 @@ def test_series_constructor(self, data):
         assert result.dtype == data.dtype
         assert len(result) == len(data)
         if hasattr(result._mgr, "blocks"):
-            assert isinstance(result._mgr.blocks[0], ExtensionBlock)
+            assert isinstance(result._mgr.blocks[0], (ExtensionBlock, DatetimeTZBlock))
         assert result._mgr.array is data
 
         # Series[EA] is unboxed / boxed correctly
         result2 = pd.Series(result)
         assert result2.dtype == data.dtype
         if hasattr(result._mgr, "blocks"):
-            assert isinstance(result2._mgr.blocks[0], ExtensionBlock)
+            assert isinstance(result2._mgr.blocks[0], (ExtensionBlock, DatetimeTZBlock))
 
     def test_series_constructor_no_data_with_index(self, dtype, na_value):
         result = pd.Series(index=[1, 2, 3], dtype=dtype)
@@ -68,7 +71,7 @@ def test_dataframe_constructor_from_dict(self, data, from_series):
         assert result.dtypes["A"] == data.dtype
         assert result.shape == (len(data), 1)
         if hasattr(result._mgr, "blocks"):
-            assert isinstance(result._mgr.blocks[0], ExtensionBlock)
+            assert isinstance(result._mgr.blocks[0], (ExtensionBlock, DatetimeTZBlock))
         assert isinstance(result._mgr.arrays[0], ExtensionArray)
 
     def test_dataframe_from_series(self, data):
@@ -76,7 +79,7 @@ def test_dataframe_from_series(self, data):
         result = pd.DataFrame(pd.Series(data))
         assert result.dtypes[0] == data.dtype
         assert result.shape == (len(data), 1)
         if hasattr(result._mgr, "blocks"):
-            assert isinstance(result._mgr.blocks[0], ExtensionBlock)
+            assert isinstance(result._mgr.blocks[0], (ExtensionBlock, DatetimeTZBlock))
         assert isinstance(result._mgr.arrays[0], ExtensionArray)
 
     def test_series_given_mismatched_index_raises(self, data):
diff --git a/pandas/tests/frame/methods/test_set_index.py b/pandas/tests/frame/methods/test_set_index.py
index 430abd9700a23..62dc400f8de9f 100644
--- a/pandas/tests/frame/methods/test_set_index.py
+++ b/pandas/tests/frame/methods/test_set_index.py
@@ -96,15 +96,18 @@ def test_set_index_cast_datetimeindex(self):
         idf = df.set_index("A")
         assert isinstance(idf.index, DatetimeIndex)
 
-    def test_set_index_dst(self):
+    def test_set_index_dst(self, using_array_manager):
         di = date_range("2006-10-29 00:00:00", periods=3, freq="H", tz="US/Pacific")
 
         df = DataFrame(data={"a": [0, 1, 2], "b": [3, 4, 5]}, index=di).reset_index()
         # single level
         res = df.set_index("index")
         exp = DataFrame(
-            data={"a": [0, 1, 2], "b": [3, 4, 5]}, index=Index(di, name="index")
+            data={"a": [0, 1, 2], "b": [3, 4, 5]},
+            index=Index(di, name="index"),
         )
+        if not using_array_manager:
+            exp.index = exp.index._with_freq(None)
         tm.assert_frame_equal(res, exp)
 
         # GH#12920
diff --git a/pandas/tests/frame/test_block_internals.py b/pandas/tests/frame/test_block_internals.py
index 748aa462cddae..ba0acdc4f947b 100644
--- a/pandas/tests/frame/test_block_internals.py
+++ b/pandas/tests/frame/test_block_internals.py
@@ -45,7 +45,7 @@ def test_setitem_invalidates_datetime_index_freq(self):
         ts = dti[1]
 
         df = DataFrame({"B": dti})
-        assert df["B"]._values.freq == "D"
+        assert df["B"]._values.freq is None
         df.iloc[1, 0] = pd.NaT
         assert df["B"]._values.freq is None
diff --git a/pandas/tests/internals/test_internals.py b/pandas/tests/internals/test_internals.py
index a1c5810ba8bb8..3299503dbc3a4 100644
--- a/pandas/tests/internals/test_internals.py
+++ b/pandas/tests/internals/test_internals.py
@@ -545,7 +545,7 @@ def test_astype(self, t):
         mgr = create_mgr("a,b: object; c: bool; d: datetime; e: f4; f: f2; g: f8")
 
         t = np.dtype(t)
-        with tm.assert_produces_warning(warn):
+        with tm.assert_produces_warning(warn, check_stacklevel=False):
             tmgr = mgr.astype(t, errors="ignore")
         assert tmgr.iget(2).dtype.type == t
         assert tmgr.iget(4).dtype.type == t
@@ -618,10 +618,10 @@ def _compare(old_mgr, new_mgr):
         assert new_mgr.iget(8).dtype == np.float16
 
     def test_invalid_ea_block(self):
-        with pytest.raises(AssertionError, match="block.size != values.size"):
+        with pytest.raises(ValueError, match="need to split"):
             create_mgr("a: category; b: category")
 
-        with pytest.raises(AssertionError, match="block.size != values.size"):
+        with pytest.raises(ValueError, match="need to split"):
split"): create_mgr("a: category2; b: category2") def test_interleave(self): diff --git a/pandas/tests/series/test_constructors.py b/pandas/tests/series/test_constructors.py index 82961a42e4ff0..67649e6e37b35 100644 --- a/pandas/tests/series/test_constructors.py +++ b/pandas/tests/series/test_constructors.py @@ -1341,7 +1341,7 @@ def test_constructor_dtype_timedelta64(self): # td.astype('m8[%s]' % t) # valid astype - with tm.assert_produces_warning(FutureWarning): + with tm.assert_produces_warning(FutureWarning, check_stacklevel=False): # astype(int64) deprecated td.astype("int64")