Skip to content

Commit

Permalink
[alt] typing: accept buffers in IO.write
Browse files Browse the repository at this point in the history
Co-authored-by: JelleZijlstra <jelle.zijlstra@gmail.com>
  • Loading branch information
AlexWaygood committed Mar 10, 2023
1 parent 390058e commit d555ade
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 12 deletions.
3 changes: 2 additions & 1 deletion stdlib/codecs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,9 @@ class StreamRecoder(BinaryIO):
def readlines(self, sizehint: int | None = None) -> list[bytes]: ...
def __next__(self) -> bytes: ...
def __iter__(self) -> Self: ...
# Base class accepts more types than just bytes
def write(self, data: bytes) -> None: ... # type: ignore[override]
def writelines(self, list: Iterable[bytes]) -> None: ...
def writelines(self, list: Iterable[bytes]) -> None: ... # type: ignore[override]
def reset(self) -> None: ...
def __getattr__(self, name: str) -> Any: ...
def __enter__(self) -> Self: ...
Expand Down
2 changes: 1 addition & 1 deletion stdlib/http/client.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class HTTPMessage(email.message.Message):

def parse_headers(fp: io.BufferedIOBase, _class: Callable[[], email.message.Message] = ...) -> HTTPMessage: ...

class HTTPResponse(io.BufferedIOBase, BinaryIO):
class HTTPResponse(io.BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible method definitions in the base classes
msg: HTTPMessage
headers: HTTPMessage
version: int
Expand Down
12 changes: 6 additions & 6 deletions stdlib/io.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class BufferedIOBase(IOBase):
def read(self, __size: int | None = ...) -> bytes: ...
def read1(self, __size: int = ...) -> bytes: ...

class FileIO(RawIOBase, BinaryIO):
class FileIO(RawIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
mode: str
name: FileDescriptorOrPath # type: ignore[assignment]
def __init__(
Expand All @@ -102,7 +102,7 @@ class FileIO(RawIOBase, BinaryIO):
def read(self, __size: int = -1) -> bytes: ...
def __enter__(self) -> Self: ...

class BytesIO(BufferedIOBase, BinaryIO):
class BytesIO(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __init__(self, initial_bytes: ReadableBuffer = ...) -> None: ...
# BytesIO does not contain a "name" field. This workaround is necessary
# to allow BytesIO sub-classes to add this field, as it is defined
Expand All @@ -113,17 +113,17 @@ class BytesIO(BufferedIOBase, BinaryIO):
def getbuffer(self) -> memoryview: ...
def read1(self, __size: int | None = -1) -> bytes: ...

class BufferedReader(BufferedIOBase, BinaryIO):
class BufferedReader(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def peek(self, __size: int = 0) -> bytes: ...

class BufferedWriter(BufferedIOBase, BinaryIO):
class BufferedWriter(BufferedIOBase, BinaryIO): # type: ignore[misc] # incompatible definitions of writelines in the base classes
def __enter__(self) -> Self: ...
def __init__(self, raw: RawIOBase, buffer_size: int = ...) -> None: ...
def write(self, __buffer: ReadableBuffer) -> int: ...

class BufferedRandom(BufferedReader, BufferedWriter):
class BufferedRandom(BufferedReader, BufferedWriter): # type: ignore[misc] # incompatible definitions of methods in the base classes
def __enter__(self) -> Self: ...
def seek(self, __target: int, __whence: int = 0) -> int: ... # stubtest needs this

Expand All @@ -144,7 +144,7 @@ class TextIOBase(IOBase):
def readlines(self, __hint: int = -1) -> list[str]: ... # type: ignore[override]
def read(self, __size: int | None = ...) -> str: ...

class TextIOWrapper(TextIOBase, TextIO):
class TextIOWrapper(TextIOBase, TextIO): # type: ignore[misc] # incompatible definitions of write in the base classes
def __init__(
self,
buffer: IO[bytes],
Expand Down
2 changes: 1 addition & 1 deletion stdlib/lzma.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ class LZMACompressor:

class LZMAError(Exception): ...

class LZMAFile(io.BufferedIOBase, IO[bytes]):
class LZMAFile(io.BufferedIOBase, IO[bytes]): # type: ignore[misc] # incompatible definitions of writelines in the base classes
def __init__(
self,
filename: _PathOrFile | None = None,
Expand Down
24 changes: 22 additions & 2 deletions stdlib/tempfile.pyi
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import io
import sys
from _typeshed import BytesPath, GenericPath, StrPath, WriteableBuffer
from _typeshed import BytesPath, GenericPath, ReadableBuffer, StrPath, WriteableBuffer
from collections.abc import Iterable, Iterator
from types import TracebackType
from typing import IO, Any, AnyStr, Generic, overload
Expand Down Expand Up @@ -215,7 +215,17 @@ class _TemporaryFileWrapper(Generic[AnyStr], IO[AnyStr]):
def tell(self) -> int: ...
def truncate(self, size: int | None = ...) -> int: ...
def writable(self) -> bool: ...
@overload
def write(self: _TemporaryFileWrapper[str], s: str) -> int: ...
@overload
def write(self: _TemporaryFileWrapper[bytes], s: ReadableBuffer) -> int: ...
@overload
def write(self, s: AnyStr) -> int: ...
@overload
def writelines(self: _TemporaryFileWrapper[str], lines: Iterable[str]) -> None: ...
@overload
def writelines(self: _TemporaryFileWrapper[bytes], lines: Iterable[ReadableBuffer]) -> None: ...
@overload
def writelines(self, lines: Iterable[AnyStr]) -> None: ...

if sys.version_info >= (3, 11):
Expand Down Expand Up @@ -392,8 +402,18 @@ class SpooledTemporaryFile(IO[AnyStr], _SpooledTemporaryFileBase):
def seek(self, offset: int, whence: int = ...) -> int: ...
def tell(self) -> int: ...
def truncate(self, size: int | None = None) -> None: ... # type: ignore[override]
@overload
def write(self: SpooledTemporaryFile[str], s: str) -> int: ...
@overload
def write(self: SpooledTemporaryFile[bytes], s: ReadableBuffer) -> int: ...
@overload
def write(self, s: AnyStr) -> int: ...
def writelines(self, iterable: Iterable[AnyStr]) -> None: ... # type: ignore[override]
@overload
def writelines(self: SpooledTemporaryFile[str], lines: Iterable[str]) -> None: ...
@overload
def writelines(self: SpooledTemporaryFile[bytes], lines: Iterable[ReadableBuffer]) -> None: ...
@overload
def writelines(self, lines: Iterable[AnyStr]) -> None: ...
def __iter__(self) -> Iterator[AnyStr]: ... # type: ignore[override]
# These exist at runtime only on 3.11+.
def readable(self) -> bool: ...
Expand Down
16 changes: 15 additions & 1 deletion stdlib/typing.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import collections # Needed by aliases like DefaultDict, see mypy issue 2986
import sys
import typing_extensions
from _collections_abc import dict_items, dict_keys, dict_values
from _typeshed import IdentityFunction, Incomplete, SupportsKeysAndGetItem
from _typeshed import IdentityFunction, Incomplete, ReadableBuffer, SupportsKeysAndGetItem
from abc import ABCMeta, abstractmethod
from contextlib import AbstractAsyncContextManager, AbstractContextManager
from re import Match as Match, Pattern as Pattern
Expand Down Expand Up @@ -687,8 +687,22 @@ class IO(Iterator[AnyStr], Generic[AnyStr]):
@abstractmethod
def writable(self) -> bool: ...
@abstractmethod
@overload
def write(self: IO[str], __s: str) -> int: ...
@abstractmethod
@overload
def write(self: IO[bytes], __s: ReadableBuffer) -> int: ...
@abstractmethod
@overload
def write(self, __s: AnyStr) -> int: ...
@abstractmethod
@overload
def writelines(self: IO[str], __lines: Iterable[str]) -> None: ...
@abstractmethod
@overload
def writelines(self: IO[bytes], __lines: Iterable[ReadableBuffer]) -> None: ...
@abstractmethod
@overload
def writelines(self, __lines: Iterable[AnyStr]) -> None: ...
@abstractmethod
def __next__(self) -> AnyStr: ...
Expand Down
21 changes: 21 additions & 0 deletions test_cases/stdlib/typing/check_io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

import mmap
from typing import IO, AnyStr


def check_write(io_bytes: IO[bytes], io_str: IO[str], io_anystr: IO[AnyStr], any_str: AnyStr, buf: mmap.mmap) -> None:
io_bytes.write(b"")
io_bytes.write(buf)
io_bytes.write("") # type: ignore
io_bytes.write(any_str) # type: ignore

io_str.write(b"") # type: ignore
io_str.write(buf) # type: ignore
io_str.write("")
io_str.write(any_str) # type: ignore

io_anystr.write(b"") # type: ignore
io_anystr.write(buf) # type: ignore
io_anystr.write("") # type: ignore
io_anystr.write(any_str)

0 comments on commit d555ade

Please sign in to comment.