Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[alt] typing: accept buffers in IO.write #9861

Merged
merged 1 commit into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion stdlib/codecs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ class StreamRecoder(BinaryIO):
def readlines(self, sizehint: int | None = None) -> list[bytes]: ...
def __next__(self) -> bytes: ...
def __iter__(self) -> Self: ...
# Base class accepts more types than just bytes
def write(self, data: bytes) -> None: ... # type: ignore[override]
def writelines(self, list: Iterable[bytes]) -> None: ...
def writelines(self, list: Iterable[bytes]) -> None: ... # type: ignore[override]
def reset(self) -> None: ...
def __getattr__(self, name: str) -> Any: ...
def __enter__(self) -> Self: ...
Expand Down
2 changes: 1 addition & 1 deletion stdlib/http/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class HTTPMessage(email.message.Message):

def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...

class HTTPResponse(io.BufferedIOBase, BinaryIO):
class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes
msg: HTTPMessage
headers: HTTPMessage
version: int
Expand Down
12 changes: 6 additions & 6 deletions stdlib/io.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BufferedIOBase(IOBase):
def read(self, __size: int | None = ...) -> bytes: ...
def read1(self, __size: int = ...) -> bytes: ...

class FileIO(RawIOBase, BinaryIO):
class FileIO(RawIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
mode: str
name: FileDescriptorOrPath # type: ignore[assignment]
def __init__(
Expand All @@ -102,7 +102,7 @@ class FileIO(RawIOBase, BinaryIO):
def read(self, __size: int = -1) -> bytes: ...
def __enter__(self) -> Self: ...

class BytesIO(BufferedIOBase, BinaryIO):
class BytesIO(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __init__(self, initial_bytes: ReadableBuffer = ...) -> None: ...
# BytesIO does not contain a "name" field. This workaround is necessary
# to allow BytesIO sub-classes to add this field, as it is defined
Expand All @@ -113,17 +113,17 @@ class BytesIO(BufferedIOBase, BinaryIO):
def getbuffer(self) -> memoryview: ...
def read1(self, __size: int | None = -1) -> bytes: ...

class BufferedReader(BufferedIOBase, BinaryIO):
class BufferedReader(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def peek(self, __size: int = 0) -> bytes: ...

class BufferedWriter(BufferedIOBase, BinaryIO):
class BufferedWriter(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def write(self, __buffer: ReadableBuffer) -> int: ...

class BufferedRandom(BufferedReader, BufferedWriter):
class BufferedRandom(BufferedReader, BufferedWriter): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __enter__(self) -> Self: ...
def seek(self, __target: int, __whence: int = 0) -> int: ... # stubtest needs this

Expand All @@ -144,7 +144,7 @@ class TextIOBase(IOBase):
def readlines(self, __hint: int = -1) -> list[str]: ... # type: ignore[override]
def read(self, __size: int | None = ...) -> str: ...

class TextIOWrapper(TextIOBase, TextIO):
class TextIOWrapper(TextIOBase, TextIO): # type: ignore[misc] # incompatible definitions of write in the base classes
def __init__(
self,
buffer: IO[bytes],
Expand Down
2 changes: 1 addition & 1 deletion stdlib/lzma.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LZMACompressor:

class LZMAError(Exception): ...

class LZMAFile(io.BufferedIOBase, IO[bytes]):
class LZMAFile(io.BufferedIOBase, IO[bytes]): # type: ignore[misc] # incompatible definitions of writelines in the base classes
def __init__(
self,
filename: _PathOrFile | None = None,
Expand Down
24 changes: 22 additions & 2 deletions stdlib/tempfile.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import sys
from _typeshed import BytesPath, GenericPath, StrPath, WriteableBuffer
from _typeshed import BytesPath, GenericPath, ReadableBuffer, StrPath, WriteableBuffer
from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import IO, Any, AnyStr, Generic, overload
Expand Down Expand Up @@ -215,7 +215,17 @@ class _TemporaryFileWrapper(Generic[AnyStr], IO[AnyStr]):
def tell(self) -> int: ...
def truncate(self, size: int | None = ...) -> int: ...
def writable(self) -> bool: ...
@overload
def write(self: _TemporaryFileWrapper[str], s: str) -> int: ...
@overload
def write(self: _TemporaryFileWrapper[bytes], s: ReadableBuffer) -> int: ...
@overload
def write(self, s: AnyStr) -> int: ...
Comment on lines +222 to 223
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the third overload needed? This should be covered by the preceding two overloads. (Same for further overloads below.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately the test cases I've added don't pass for pyright unless we include this third overload. This is due to a design decision by pyright in the way it handles contrained TypeVars that leads to different behaviour from mypy: https://github.com/microsoft/pyright/blob/main/docs/mypy-comparison.md#constrained-type-variables

It's the same issue that led to microsoft/pyright#4534 being filed, which we then fixed over at typeshed in #9592.

@overload
def writelines(self: _TemporaryFileWrapper[str], lines: Iterable[str]) -> None: ...
@overload
def writelines(self: _TemporaryFileWrapper[bytes], lines: Iterable[ReadableBuffer]) -> None: ...
@overload
def writelines(self, lines: Iterable[AnyStr]) -> None: ...

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -392,8 +402,18 @@ class SpooledTemporaryFile(IO[AnyStr], _SpooledTemporaryFileBase):
def seek(self, offset: int, whence: int = ...) -> int: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None) -> None: ... # type: ignore[override]
@overload
def write(self: SpooledTemporaryFile[str], s: str) -> int: ...
@overload
def write(self: SpooledTemporaryFile[bytes], s: ReadableBuffer) -> int: ...
@overload
def write(self, s: AnyStr) -> int: ...
def writelines(self, iterable: Iterable[AnyStr]) -> None: ... # type: ignore[override]
@overload
def writelines(self: SpooledTemporaryFile[str], iterable: Iterable[str]) -> None: ...
@overload
def writelines(self: SpooledTemporaryFile[bytes], iterable: Iterable[ReadableBuffer]) -> None: ...
@overload
def writelines(self, iterable: Iterable[AnyStr]) -> None: ...
def __iter__(self) -> Iterator[AnyStr]: ... # type: ignore[override]
# These exist at runtime only on 3.11+.
def readable(self) -> bool: ...
Expand Down
16 changes: 15 additions & 1 deletion stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986
import sys
import typing_extensions
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import IdentityFunction, Incomplete, SupportsKeysAndGetItem
from _typeshed import IdentityFunction, Incomplete, ReadableBuffer, SupportsKeysAndGetItem
from abc import ABCMeta, abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from re import Match as Match, Pattern as Pattern
Expand Down Expand Up @@ -687,8 +687,22 @@ class IO(Iterator[AnyStr], Generic[AnyStr]):
@abstractmethod
def writable(self) -> bool: ...
@abstractmethod
@overload
def write(self: IO[str], __s: str) -> int: ...
@abstractmethod
@overload
def write(self: IO[bytes], __s: ReadableBuffer) -> int: ...
@abstractmethod
@overload
def write(self, __s: AnyStr) -> int: ...
@abstractmethod
@overload
def writelines(self: IO[str], __lines: Iterable[str]) -> None: ...
@abstractmethod
@overload
def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ...
@abstractmethod
@overload
def writelines(self, __lines: Iterable[AnyStr]) -> None: ...
@abstractmethod
def __next__(self) -> AnyStr: ...
Expand Down
21 changes: 21 additions & 0 deletions test_cases/stdlib/typing/check_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

import mmap
from typing import IO, AnyStr


def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None:
io_bytes.write(b"")
io_bytes.write(buf)
io_bytes.write("") # type: ignore
io_bytes.write(any_str) # type: ignore

io_str.write(b"") # type: ignore
io_str.write(buf) # type: ignore
io_str.write("")
io_str.write(any_str) # type: ignore

io_anystr.write(b"") # type: ignore
io_anystr.write(buf) # type: ignore
io_anystr.write("") # type: ignore
io_anystr.write(any_str)