diff --git a/pandas/_typing.py b/pandas/_typing.py index 68ec331c2781f..85e29681285f4 100644 --- a/pandas/_typing.py +++ b/pandas/_typing.py @@ -7,7 +7,6 @@ BufferedIOBase, RawIOBase, TextIOBase, - TextIOWrapper, ) from mmap import mmap from os import PathLike @@ -170,7 +169,7 @@ PythonFuncType = Callable[[Any], Any] # filenames and file-like-objects -Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, TextIOWrapper, mmap] +Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, mmap] FileOrBuffer = Union[str, Buffer[AnyStr]] FilePathOrBuffer = Union["PathLike[str]", FileOrBuffer[AnyStr]] diff --git a/pandas/io/common.py b/pandas/io/common.py index 12c7afc8ee2e4..1aacbfa2bfb64 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -23,8 +23,11 @@ IO, Any, AnyStr, + Generic, + Literal, Mapping, cast, + overload, ) from urllib.parse import ( urljoin, @@ -78,7 +81,7 @@ class IOArgs: @dataclasses.dataclass -class IOHandles: +class IOHandles(Generic[AnyStr]): """ Return value of io/common.py:get_handle @@ -92,7 +95,7 @@ class IOHandles: is_wrapped: Whether a TextIOWrapper needs to be detached. """ - handle: Buffer + handle: Buffer[AnyStr] compression: CompressionDict created_handles: list[Buffer] = dataclasses.field(default_factory=list) is_wrapped: bool = False @@ -118,7 +121,7 @@ def close(self) -> None: self.created_handles = [] self.is_wrapped = False - def __enter__(self) -> IOHandles: + def __enter__(self) -> IOHandles[AnyStr]: return self def __exit__(self, *args: Any) -> None: @@ -533,16 +536,47 @@ def check_parent_directory(path: Path | str) -> None: raise OSError(fr"Cannot save file into a non-existent directory: '{parent}'") +@overload def get_handle( path_or_buf: FilePathOrBuffer, mode: str, + *, + encoding: str | None = ..., + compression: CompressionOptions = ..., + memory_map: bool = ..., + is_text: Literal[False], + errors: str | None = ..., + storage_options: StorageOptions = ..., +) -> IOHandles[bytes]: + ... 
+ + +@overload +def get_handle( + path_or_buf: FilePathOrBuffer, + mode: str, + *, + encoding: str | None = ..., + compression: CompressionOptions = ..., + memory_map: bool = ..., + is_text: Literal[True] = True, + errors: str | None = ..., + storage_options: StorageOptions = ..., +) -> IOHandles[str]: + ... + + +def get_handle( + path_or_buf: FilePathOrBuffer, + mode: str, + *, encoding: str | None = None, compression: CompressionOptions = None, memory_map: bool = False, is_text: bool = True, errors: str | None = None, storage_options: StorageOptions = None, -) -> IOHandles: +) -> IOHandles[str] | IOHandles[bytes]: """ Get file handle for given path/buffer and mode. diff --git a/pandas/io/excel/_base.py b/pandas/io/excel/_base.py index ed79a5ad98ab9..22fbaaaa8b2f8 100644 --- a/pandas/io/excel/_base.py +++ b/pandas/io/excel/_base.py @@ -941,7 +941,9 @@ def __init__( mode = mode.replace("a", "r+") # cast ExcelWriter to avoid adding 'if self.handles is not None' - self.handles = IOHandles(cast(Buffer, path), compression={"copression": None}) + self.handles = IOHandles( + cast(Buffer[bytes], path), compression={"method": None} + ) if not isinstance(path, ExcelWriter): self.handles = get_handle( path, mode, storage_options=storage_options, is_text=False diff --git a/pandas/io/json/_json.py b/pandas/io/json/_json.py index b9bdfb91ca154..8c44b54e75a3f 100644 --- a/pandas/io/json/_json.py +++ b/pandas/io/json/_json.py @@ -661,7 +661,7 @@ def __init__( self.nrows_seen = 0 self.nrows = nrows self.encoding_errors = encoding_errors - self.handles: IOHandles | None = None + self.handles: IOHandles[str] | None = None if self.chunksize is not None: self.chunksize = validate_integer("chunksize", self.chunksize, 1) diff --git a/pandas/io/parquet.py b/pandas/io/parquet.py index e92afd4e35ca1..2eb1dd2d44d65 100644 --- a/pandas/io/parquet.py +++ b/pandas/io/parquet.py @@ -74,7 +74,7 @@ def _get_path_or_handle( storage_options: StorageOptions = None, mode: str = "rb", is_dir: bool
= False, -) -> tuple[FilePathOrBuffer, IOHandles | None, Any]: +) -> tuple[FilePathOrBuffer, IOHandles[bytes] | None, Any]: """File handling for PyArrow.""" path_or_handle = stringify_path(path) if is_fsspec_url(path_or_handle) and fs is None: diff --git a/pandas/io/parsers/base_parser.py b/pandas/io/parsers/base_parser.py index 043eb34e18798..42b9c8c9f10fe 100644 --- a/pandas/io/parsers/base_parser.py +++ b/pandas/io/parsers/base_parser.py @@ -212,7 +212,7 @@ def __init__(self, kwds): self.usecols, self.usecols_dtype = self._validate_usecols_arg(kwds["usecols"]) - self.handles: IOHandles | None = None + self.handles: IOHandles[str] | None = None # Fallback to error to pass a sketchy test(test_override_set_noconvert_columns) # Normally, this arg would get pre-processed earlier on diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 9803a2e4e3309..013f17580600d 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -2294,7 +2294,7 @@ def __init__( self._value_labels: list[StataValueLabel] = [] self._has_value_labels = np.array([], dtype=bool) self._compression = compression - self._output_file: Buffer | None = None + self._output_file: Buffer[bytes] | None = None self._converted_names: dict[Hashable, str] = {} # attach nobs, nvars, data, varlist, typlist self._prepare_pandas(data) diff --git a/pandas/tests/io/xml/test_to_xml.py b/pandas/tests/io/xml/test_to_xml.py index b8d146c597d2c..c257b61db296e 100644 --- a/pandas/tests/io/xml/test_to_xml.py +++ b/pandas/tests/io/xml/test_to_xml.py @@ -1287,7 +1287,8 @@ def test_compression_output(parser, comp): output = equalize_decl(output) - assert geom_xml == output.strip() + # error: Item "None" of "Union[str, bytes, None]" has no attribute "strip" + assert geom_xml == output.strip() # type: ignore[union-attr] @pytest.mark.parametrize("comp", ["bz2", "gzip", "xz", "zip"]) @@ -1305,7 +1306,8 @@ def test_filename_and_suffix_comp(parser, comp, compfile): output = equalize_decl(output) - assert geom_xml == 
output.strip() + # error: Item "None" of "Union[str, bytes, None]" has no attribute "strip" + assert geom_xml == output.strip() # type: ignore[union-attr] def test_unsuported_compression(datapath, parser):