Skip to content

Commit

Permalink
[BE] Import Literal, Protocol, and Final from standard library …
Browse files Browse the repository at this point in the history
…`typing` as of Python 3.8+ (pytorch#94490)

Changes:

1. `typing_extensions -> typing-extentions` in dependency. Use dash rather than underline to fit the [PEP 503: Normalized Names](https://peps.python.org/pep-0503/#normalized-names) convention.

```python
import re

def normalize(name):
    return re.sub(r"[-_.]+", "-", name).lower()
```

2. Import `Literal`, `Protocal`, and `Final` from standard library as of Python 3.8+
3. Replace `Union[Literal[XXX], Literal[YYY]]` to `Literal[XXX, YYY]`.

Pull Request resolved: pytorch#94490
Approved by: https://github.com/ezyang, https://github.com/albanD
  • Loading branch information
XuehaiPan authored and pytorchmergebot committed Feb 9, 2023
1 parent 527b646 commit 69e0bda
Show file tree
Hide file tree
Showing 47 changed files with 76 additions and 102 deletions.
4 changes: 2 additions & 2 deletions .ci/docker/common/install_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
conda_install numpy=1.18.5 ${CONDA_COMMON_DEPS} llvmdev=8.0.0
else
# Install `typing_extensions` for 3.7
conda_install numpy=1.18.5 ${CONDA_COMMON_DEPS} typing_extensions
# Install `typing-extensions` for 3.7
conda_install numpy=1.18.5 ${CONDA_COMMON_DEPS} typing-extensions
fi

# Use conda cmake in some cases. Conda cmake will be newer than our supported
Expand Down
2 changes: 1 addition & 1 deletion .circleci/config.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion .circleci/scripts/binary_ios_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ export PATH="~/anaconda/bin:${PATH}"
source ~/anaconda/bin/activate

# Install dependencies
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake requests typing_extensions --yes
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake requests typing-extensions --yes
conda install -c conda-forge valgrind --yes
export CMAKE_PREFIX_PATH=${CONDA_PREFIX:-"$(dirname $(which conda))/../"}

Expand Down
2 changes: 1 addition & 1 deletion .circleci/verbatim-sources/job-specs/job-specs-custom.yml
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@
$* || (sleep 1 && $*) || (sleep 2 && $*) || (sleep 4 && $*) || (sleep 8 && $*)
}
retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake requests typing_extensions --yes
retry conda install numpy ninja pyyaml mkl mkl-include setuptools cmake requests typing-extensions --yes
# sync submodules
cd ${PROJ_ROOT}
Expand Down
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-Linux-X64
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ numpy=1.23.3
pyyaml=6.0
requests=2.28.1
setuptools=65.5.0
typing_extensions=4.3.0
typing-extensions=4.3.0
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-iOS
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ numpy=1.23.3
pyyaml=6.0
requests=2.28.1
setuptools=63.4.1
typing_extensions=4.3.0
typing-extensions=4.3.0
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-macOS-ARM64
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ numpy=1.22.3
pyyaml=6.0
setuptools=61.2.0
cmake=3.22.*
typing_extensions=4.3.0
typing-extensions=4.3.0
dataclasses=0.8
pip=22.2.2
six=1.16.0
Expand Down
2 changes: 1 addition & 1 deletion .github/requirements/conda-env-macOS-X64
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ numpy=1.18.5
pyyaml=5.3
setuptools=46.0.0
cmake=3.22.*
typing_extensions=4.3.0
typing-extensions=4.3.0
dataclasses=0.8
pip=22.2.2
six=1.16.0
Expand Down
2 changes: 1 addition & 1 deletion .github/requirements/regenerate-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
typing_extensions
typing-extensions
jinja2
4 changes: 2 additions & 2 deletions .github/scripts/generate_ci_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Dict, Set, List, Iterable
from typing import Dict, Set, List, Literal, Iterable

import jinja2

import os
import sys
from typing_extensions import Literal, TypedDict
from typing_extensions import TypedDict # Python 3.11+

import generate_binary_build_matrix # type: ignore[import]

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/run_torchbench.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
. "${SETUP_SCRIPT}"
conda activate pr-ci
conda install -y numpy="${NUMPY_VERSION}" requests ninja pyyaml mkl mkl-include \
setuptools cmake=3.22.* typing_extensions boto3 \
setuptools cmake=3.22.* typing-extensions boto3 \
six pillow pytest tabulate gitpython git-lfs tqdm psutil
pip install --pre torch torchvision torchtext -f https://download.pytorch.org/whl/nightly/cu116/torch_nightly.html
- name: Setup TorchBench branch
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/dynamo/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pull-deps: clone-deps
(cd ../../../torchvision && git pull && git submodule update --init --recursive)
(cd ../../../torchdata && git pull && git submodule update --init --recursive)
(cd ../../../torchtext && git pull && git submodule update --init --recursive)
(cd ../../../torchaudio && git pull && git submodule update --init --recursive)
(cd ../../../torchaudio && git pull && git submodule update --init --recursive)
(cd ../../../detectron2 && git pull && git submodule update --init --recursive)
(cd ../../../torchbenchmark && git pull && git submodule update --init --recursive)
(cd ../../../triton && git fetch && git checkout $(TRITON_VERSION) && git submodule update --init --recursive)
Expand All @@ -28,7 +28,7 @@ build-deps: clone-deps
# conda create --name torchdynamo -y python=3.8
# conda activate torchdynamo
conda install -y astunparse numpy scipy ninja pyyaml mkl mkl-include setuptools cmake \
typing_extensions six requests protobuf numba cython scikit-learn
typing-extensions six requests protobuf numba cython scikit-learn
conda install -y -c pytorch magma-cuda116
conda install -y -c conda-forge librosa
(cd ../../../torchvision && python setup.py clean && python setup.py develop)
Expand Down
7 changes: 1 addition & 6 deletions docs/source/jit.rst
Original file line number Diff line number Diff line change
Expand Up @@ -831,12 +831,7 @@ New API:

::

try:
from typing_extensions import Final
except:
# If you don't have `typing_extensions` installed, you can use a
# polyfill from `torch.jit`.
from torch.jit import Final
from typing import Final

class MyModule(torch.nn.Module):

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ requires = [
"pyyaml",
"setuptools",
"cmake",
"typing_extensions",
"typing-extensions",
"six",
"requests",
]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pyyaml
requests
setuptools
types-dataclasses
typing_extensions
typing-extensions
sympy
filelock
networkx
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,7 +1014,7 @@ def main():
# the list of runtime dependencies required by this built package
install_requires = [
'filelock',
'typing_extensions',
'typing-extensions',
'sympy',
'networkx',
]
Expand Down
2 changes: 1 addition & 1 deletion tools/extract_scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Dict, Optional

import yaml
from typing_extensions import TypedDict
from typing_extensions import TypedDict # Python 3.11+

Step = Dict[str, Any]

Expand Down
2 changes: 1 addition & 1 deletion tools/fast_nvcc/fast_nvcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import time
from typing import Awaitable, cast, DefaultDict, Dict, List, Match, Optional, Set

from typing_extensions import TypedDict
from typing_extensions import TypedDict # Python 3.11+

help_msg = """fast_nvcc [OPTION]... -- [NVCC_ARG]...
Expand Down
5 changes: 2 additions & 3 deletions tools/jit/gen_unboxing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pathlib
import sys
from dataclasses import dataclass
from typing import List, Sequence, Union
from typing import List, Literal, Sequence, Union

import yaml

Expand All @@ -17,13 +17,12 @@
from torchgen.model import Argument, NativeFunction, NativeFunctionsGroup, Variant
from torchgen.selective_build.selector import SelectiveBuilder
from torchgen.utils import FileManager, make_file_manager, mapMaybe, Target
from typing_extensions import Literal


# Generates UnboxingFunctions.h & UnboxingFunctions.cpp.
@dataclass(frozen=True)
class ComputeUnboxingFunctions:
target: Union[Literal[Target.DECLARATION], Literal[Target.DEFINITION]]
target: Literal[Target.DECLARATION, Target.DEFINITION]
selector: SelectiveBuilder

@method_with_native_function
Expand Down
2 changes: 1 addition & 1 deletion tools/onnx/sarif/gen_sarif.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ python -m jschema_to_python \
-vv

# Generate SARIF version file
echo "from typing_extensions import Final" > "${ROOT}/${SARIF_DIR}/version.py"
echo "from typing import Final" > "${ROOT}/${SARIF_DIR}/version.py"
echo "SARIF_VERSION: Final = \"${SARIF_VERSION}\"" >> "${ROOT}/${SARIF_DIR}/version.py"
echo "SARIF_SCHEMA_LINK: Final = \"${SARIF_SCHEMA_LINK}\"" >> "${ROOT}/${SARIF_DIR}/version.py"

Expand Down
3 changes: 1 addition & 2 deletions torch/_C/_VariableFunctions.pyi.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# ${generated_comment}

from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from typing_extensions import Literal
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
from torch._six import inf

from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout, SymInt, Device
Expand Down
3 changes: 1 addition & 2 deletions torch/_C/__init__.pyi.in
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ from pathlib import Path
from typing import (
Any, BinaryIO, Callable, ContextManager, Dict, Iterable, Iterator, List,
NamedTuple, Optional, overload, Sequence, Tuple, TypeVar, Type, Union,
Generic, Set, AnyStr)
from typing_extensions import Literal
Literal, Generic, Set, AnyStr)
from torch._six import inf

from torch.types import (
Expand Down
4 changes: 1 addition & 3 deletions torch/_C/_profiler.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
from enum import Enum
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

from torch._C import device, dtype, layout

from typing_extensions import Literal

# defined in torch/csrc/profiler/python/init.cpp

class RecordScope(Enum):
Expand Down
3 changes: 1 addition & 2 deletions torch/_C/return_types.pyi.in
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
# ${generated_comment}

from torch import Tensor, Generator, strided, memory_format, contiguous_format, strided
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, TypeVar
from typing_extensions import Literal
from typing import List, Tuple, Optional, Union, Any, ContextManager, Callable, overload, Iterator, NamedTuple, Sequence, Literal, TypeVar
from torch._six import inf

from torch.types import _int, _float, _bool, Number, _dtype, _device, _qscheme, _size, _layout
Expand Down
4 changes: 1 addition & 3 deletions torch/_dynamo/backends/registry.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import functools
from typing import Callable, Dict, List, Optional, Sequence, Tuple

from typing_extensions import Protocol
from typing import Callable, Dict, List, Optional, Protocol, Sequence, Tuple

import torch
from torch import fx
Expand Down
12 changes: 10 additions & 2 deletions torch/_dynamo/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
import dataclasses
import sys
import types
from typing import Callable, Dict, List, NamedTuple, Optional, OrderedDict, Union
from typing import (
Callable,
Dict,
List,
NamedTuple,
Optional,
OrderedDict,
Protocol,
Union,
)

from typing_extensions import Protocol

if sys.version_info >= (3, 11):
from torch._C._dynamo import eval_frame
Expand Down
4 changes: 1 addition & 3 deletions torch/_refs/fft.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import math

from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple, Union

from typing_extensions import Literal
from typing import Iterable, List, Literal, NamedTuple, Optional, Sequence, Tuple, Union

import torch
import torch._prims as prims
Expand Down
8 changes: 1 addition & 7 deletions torch/distributed/pipeline/sync/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
from contextlib import contextmanager
import threading
from typing import (
TYPE_CHECKING,
Any,
Deque,
Generator,
List,
Optional,
Protocol,
Union,
Sequence,
Tuple
Expand All @@ -60,12 +60,6 @@
RNGStates = Tuple[Tensor, Optional[Tensor]] # (cpu_rng_state, gpu_rng_state)


if TYPE_CHECKING:
from typing_extensions import Protocol
else:
Protocol = object


# Protocol with __call__ instead of Callable can be used as an attribute type.
# See: https://github.com/python/mypy/issues/708#issuecomment-561735949
class Function(Protocol):
Expand Down
2 changes: 1 addition & 1 deletion torch/nn/modules/lazy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
from typing_extensions import Protocol
import warnings
from typing import Protocol

import torch
from ..parameter import is_lazy
Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/_internal/diagnostics/infra/sarif/_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, List, Optional

from typing_extensions import Literal
from typing import Any, List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_content,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import List, Optional

from typing_extensions import Literal
from typing import List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_address,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import List, Optional

from typing_extensions import Literal
from typing import List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_exception,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import Optional

from typing_extensions import Literal
from typing import Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import _property_bag

Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/_internal/diagnostics/infra/sarif/_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, List, Optional

from typing_extensions import Literal
from typing import Any, List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_artifact_location,
Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/_internal/diagnostics/infra/sarif/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import Any, List, Optional

from typing_extensions import Literal
from typing import Any, List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_address,
Expand Down
4 changes: 1 addition & 3 deletions torch/onnx/_internal/diagnostics/infra/sarif/_sarif_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from __future__ import annotations

import dataclasses
from typing import List, Optional

from typing_extensions import Literal
from typing import List, Literal, Optional

from torch.onnx._internal.diagnostics.infra.sarif import (
_external_properties,
Expand Down
Loading

0 comments on commit 69e0bda

Please sign in to comment.