Skip to content

Commit

Permalink
Improve --update-data handler (#15283)
Browse files Browse the repository at this point in the history
  • Loading branch information
ikonst authored May 31, 2023
1 parent 85e6719 commit 3e03484
Show file tree
Hide file tree
Showing 6 changed files with 274 additions and 37 deletions.
42 changes: 40 additions & 2 deletions mypy/test/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import tempfile
from abc import abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterator, NamedTuple, Pattern, Union
from typing_extensions import Final, TypeAlias as _TypeAlias

Expand Down Expand Up @@ -426,8 +427,16 @@ class TestItem:

id: str
arg: str | None
# Processed, collapsed text data
data: list[str]
# Start line: 1-based, inclusive, relative to testcase
line: int
# End line: 1-based, exclusive, relative to testcase; not same as `line + len(test_item.data)` due to collapsing
end_line: int

@property
def trimmed_newlines(self) -> int:
    """Return how many lines of this item's span are not present in `data`.

    The span covers `end_line - line` lines, while `data` holds the processed
    text; the difference compensates for what strip_list removed.
    """
    span = self.end_line - self.line
    return span - len(self.data)


def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
Expand All @@ -449,7 +458,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
if id:
data = collapse_line_continuation(data)
data = strip_list(data)
ret.append(TestItem(id, arg, strip_list(data), i0 + 1))
ret.append(TestItem(id, arg, data, i0 + 1, i))

i0 = i
id = s[1:-1]
Expand All @@ -470,7 +479,7 @@ def parse_test_data(raw_data: str, name: str) -> list[TestItem]:
if id:
data = collapse_line_continuation(data)
data = strip_list(data)
ret.append(TestItem(id, arg, data, i0 + 1))
ret.append(TestItem(id, arg, data, i0 + 1, i - 1))

return ret

Expand Down Expand Up @@ -693,6 +702,12 @@ def collect(self) -> Iterator[DataFileCollector]:
yield DataFileCollector.from_parent(parent=self, name=data_file)


class DataFileFix(NamedTuple):
    """A replacement of a range of lines in a data file.

    The lines from `lineno` up to (but not including) `end_lineno` are to be
    replaced by `lines`.
    """

    lineno: int  # 1-offset, inclusive
    end_lineno: int  # 1-offset, exclusive
    lines: list[str]


class DataFileCollector(pytest.Collector):
"""Represents a single `.test` data driven test file.
Expand All @@ -701,6 +716,8 @@ class DataFileCollector(pytest.Collector):

parent: DataSuiteCollector

_fixes: list[DataFileFix]

@classmethod # We have to fight with pytest here:
def from_parent(
cls, parent: DataSuiteCollector, *, name: str # type: ignore[override]
Expand All @@ -716,6 +733,27 @@ def collect(self) -> Iterator[DataDrivenTestCase]:
file=os.path.join(self.parent.obj.data_prefix, self.name),
)

def setup(self) -> None:
    """Run the inherited setup, then start with an empty queue of fixes."""
    super().setup()
    self._fixes = []

def teardown(self) -> None:
    """Run the inherited teardown, then apply the queued fixes to the data file."""
    super().teardown()
    self._apply_fixes()

def enqueue_fix(self, fix: DataFileFix) -> None:
    """Queue `fix`; queued fixes are written to the data file by teardown()."""
    self._fixes.append(fix)

def _apply_fixes(self) -> None:
if not self._fixes:
return
data_path = Path(self.parent.obj.data_prefix) / self.name
lines = data_path.read_text().split("\n")
# start from end to prevent line offsets from shifting as we update
for fix in sorted(self._fixes, reverse=True):
lines[fix.lineno - 1 : fix.end_lineno - 1] = fix.lines
data_path.write_text("\n".join(lines))


def add_test_name_suffix(name: str, suffix: str) -> str:
# Find magic suffix of form "-foobar" (used for things like "-skip").
Expand Down
33 changes: 0 additions & 33 deletions mypy/test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,39 +141,6 @@ def assert_target_equivalence(name: str, expected: list[str], actual: list[str])
)


def update_testcase_output(testcase: DataDrivenTestCase, output: list[str]) -> None:
    """Rewrite the test case's text in its data file so it reflects `output`.

    Each expected line in `testcase.output` is paired with the line at the same
    position in `output`. When a pair shares everything before "error:", only
    the text after that prefix is compared. An old text is substituted only if
    it occurs in the test case exactly as many times as it has replacements;
    the replacements are then applied in order of occurrence.
    """
    assert testcase.old_cwd is not None, "test was not properly set up"
    testcase_path = os.path.join(testcase.old_cwd, testcase.file)
    with open(testcase_path, encoding="utf8") as f:
        data_lines = f.read().splitlines()
    first, last = testcase.line, testcase.last_line
    test = "\n".join(data_lines[first:last])

    PREFIX = "error:"
    mapping: dict[str, list[str]] = {}
    for old, new in zip(testcase.output, output):
        ind = old.find(PREFIX)
        if ind != -1 and old[:ind] == new[:ind]:
            cut = ind + len(PREFIX)
            old, new = old[cut:], new[cut:]
        mapping.setdefault(old, []).append(new)

    for old, replacements in mapping.items():
        if test.count(old) != len(replacements):
            continue  # ambiguous: occurrences and replacements don't line up
        betweens = test.split(old)
        # Weave the replacements back in between the pieces that surrounded `old`.
        pieces = [betweens[0]]
        for replacement, between in zip(replacements, betweens[1:]):
            pieces.append(replacement)
            pieces.append(between)
        test = "".join(pieces)

    data_lines[first:last] = [test]
    with open(testcase_path, "w", encoding="utf8") as f:
        print("\n".join(data_lines), file=f)


def show_align_message(s1: str, s2: str) -> None:
"""Align s1 and s2 so that their first difference is highlighted.
Expand Down
5 changes: 3 additions & 2 deletions mypy/test/testcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@
normalize_error_messages,
parse_options,
perform_file_operations,
update_testcase_output,
)
from mypy.test.update_data import update_testcase_output

try:
import lxml # type: ignore[import]
Expand Down Expand Up @@ -192,7 +192,8 @@ def run_case_once(
output = testcase.output2.get(incremental_step, [])

if output != a and testcase.config.getoption("--update-data", False):
update_testcase_output(testcase, a)
update_testcase_output(testcase, a, incremental_step=incremental_step)

assert_string_arrays_equal(output, a, msg.format(testcase.file, testcase.line))

if res:
Expand Down
146 changes: 146 additions & 0 deletions mypy/test/testupdatedata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import shlex
import subprocess
import sys
import textwrap
from pathlib import Path

from mypy.test.config import test_data_prefix
from mypy.test.helpers import Suite


class UpdateDataSuite(Suite):
    """Exercises 'pytest --update-data' against a scratch data file."""

    def _run_pytest_update_data(self, data_suite: str, *, max_attempts: int) -> str:
        """
        Write `data_suite` to a scratch data file and run it through
        'pytest --update-data' until the run passes or `max_attempts` runs have
        been made (more than one run is needed for incremental tests).

        Returns the contents of the data file afterwards; the file is removed
        before returning.
        """
        p = Path(test_data_prefix) / "check-update-data.test"
        assert not p.exists()
        try:
            p.write_text(textwrap.dedent(data_suite).lstrip())

            test_nodeid = f"mypy/test/testcheck.py::TypeCheckSuite::{p.name}"
            args = [sys.executable, "-m", "pytest", "-n", "0", "-s", "--update-data", test_nodeid]
            # `cmd` is only used for the progress message below.
            if sys.version_info >= (3, 8):
                cmd = shlex.join(args)
            else:
                cmd = " ".join(args)
            # `i` counts the attempts still left after the current one.
            for i in reversed(range(max_attempts)):
                res = subprocess.run(args)
                if res.returncode == 0:
                    break
                print(f"`{cmd}` returned {res.returncode}: {i} attempts remaining")

            return p.read_text()
        finally:
            p.unlink()

    def test_update_data(self) -> None:
        """Run several deliberately wrong test cases and compare the rewritten file."""
        # All cases live in one data file (instead of one file per case) so that
        # rewriting several test cases in a single pass is exercised as well.
        data_suite = """
[case testCorrect]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrong]
s: str = 42 # E: wrong error
[case testWrongMultiline]
s: str = 42 # E: foo \
# N: bar
[case testMissingMultiline]
s: str = 42; i: int = 'foo'
[case testExtraneous]
s: str = 'foo' # E: wrong error
[case testExtraneousMultiline]
s: str = 'foo' # E: foo \
# E: bar
[case testExtraneousMultilineNonError]
s: str = 'foo' # W: foo \
# N: bar
[case testOutCorrect]
s: str = 42
[out]
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testOutWrong]
s: str = 42
[out]
main:1: error: foobar
[case testOutWrongIncremental]
s: str = 42
[out]
main:1: error: foobar
[out2]
main:1: error: foobar
[case testWrongMultipleFiles]
import a, b
s: str = 42 # E: foo
[file a.py]
s1: str = 42 # E: bar
[file b.py]
s2: str = 43 # E: baz
[builtins fixtures/list.pyi]
"""
        actual = self._run_pytest_update_data(data_suite, max_attempts=3)

        # What the data file should contain once every case has been updated.
        expected = """
[case testCorrect]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrong]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrongMultiline]
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testMissingMultiline]
s: str = 42; i: int = 'foo' # E: Incompatible types in assignment (expression has type "int", variable has type "str") \\
# E: Incompatible types in assignment (expression has type "str", variable has type "int")
[case testExtraneous]
s: str = 'foo'
[case testExtraneousMultiline]
s: str = 'foo'
[case testExtraneousMultilineNonError]
s: str = 'foo'
[case testOutCorrect]
s: str = 42
[out]
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testOutWrong]
s: str = 42
[out]
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testOutWrongIncremental]
s: str = 42
[out]
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
[out2]
main:1: error: Incompatible types in assignment (expression has type "int", variable has type "str")
[case testWrongMultipleFiles]
import a, b
s: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[file a.py]
s1: str = 42 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[file b.py]
s2: str = 43 # E: Incompatible types in assignment (expression has type "int", variable has type "str")
[builtins fixtures/list.pyi]
"""
        expected_text = textwrap.dedent(expected).lstrip()
        assert actual == expected_text
Empty file removed mypy/test/update.py
Empty file.
85 changes: 85 additions & 0 deletions mypy/test/update_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from __future__ import annotations

import re
from collections import defaultdict
from typing import Iterator

from mypy.test.data import DataDrivenTestCase, DataFileCollector, DataFileFix, parse_test_data


def update_testcase_output(
    testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int
) -> None:
    """Queue the fixes that make `testcase` expect the `actual` output.

    Nothing is written here: every fix produced by `_iter_fixes` is handed to
    the test case's parent, which must be a `DataFileCollector`, through its
    `enqueue_fix` method.
    """
    data_file = testcase.parent
    assert isinstance(data_file, DataFileCollector)
    pending = _iter_fixes(testcase, actual, incremental_step=incremental_step)
    for fix in pending:
        data_file.enqueue_fix(fix)


def _iter_fixes(
    testcase: DataDrivenTestCase, actual: list[str], *, incremental_step: int
) -> Iterator[DataFileFix]:
    """Yield the data-file fixes that make `testcase` expect the `actual` output.

    Lines of `actual` shaped like `<file>:<line>: <severity>: <message>` are
    grouped by file and line number; other lines are ignored for that grouping.

    If the test case has an [out] or [outN] section, only the section for
    `incremental_step` is replaced, with `actual` verbatim. Otherwise the
    `# E:` / `# W:` / `# N:` comments in the [case] and [file] sections are
    regenerated from the grouped reports.
    """
    report_re = re.compile(
        r"^(?P<filename>[^:]+):(?P<lineno>\d+): (?P<severity>error|note|warning): (?P<msg>.+)$"
    )
    reports_by_line: dict[tuple[str, int], list[tuple[str, str]]] = defaultdict(list)
    for error_line in actual:
        report = report_re.match(error_line)
        if report is None:
            continue
        key = (report.group("filename"), int(report.group("lineno")))
        reports_by_line[key].append((report.group("severity"), report.group("msg")))

    test_items = parse_test_data(testcase.data, testcase.name)

    # When [out] and/or [outN] sections exist, those are the only ones updated.
    if any(re.match(r"^out\d*$", test_item.id) for test_item in test_items):
        # Steps below 2 use the plain [out] section, later steps use [outN].
        wanted_id = "out" if incremental_step < 2 else f"out{incremental_step}"
        for test_item in test_items:
            if test_item.id != wanted_id:
                continue
            yield DataFileFix(
                lineno=testcase.line + test_item.line - 1,
                end_lineno=testcase.line + test_item.end_line - 1,
                lines=actual + [""] * test_item.trimmed_newlines,
            )
        return

    # No output sections: rewrite the assertion comments inside the sources.
    comment_re = re.compile(r"(?P<indent>\s+)(?P<comment># [EWN]: .+)$")
    for test_item in test_items:
        if test_item.id == "case":
            file_path = "main"
        elif test_item.id == "file":
            file_path = f"tmp/{test_item.arg}"
        else:
            continue  # any other kind of section is left alone

        fix_lines = []
        for lineno, source_line in enumerate(test_item.data, start=1):
            reports = reports_by_line.get((file_path, lineno))
            old_comment = comment_re.search(source_line)
            if old_comment:
                # drop the existing assertion comment together with its indent
                source_line = source_line[: old_comment.start("indent")]
            if not reports:
                fix_lines.append(source_line)
                continue

            indent = old_comment.group("indent") if old_comment else " "
            # The first report goes on the source line itself; each further
            # report goes on its own line, blank-padded to the same width, and
            # every line but the last ends in a continuation backslash.
            padding = " " * len(source_line)
            last_index = len(reports) - 1
            for j, (severity, msg) in enumerate(reports):
                out_l = source_line if j == 0 else padding
                severity_char = severity[0].upper()
                continuation = "" if j == last_index else " \\"
                fix_lines.append(f"{out_l}{indent}# {severity_char}: {msg}{continuation}")

        yield DataFileFix(
            lineno=testcase.line + test_item.line - 1,
            end_lineno=testcase.line + test_item.end_line - 1,
            lines=fix_lines + [""] * test_item.trimmed_newlines,
        )

0 comments on commit 3e03484

Please sign in to comment.