forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor to extract label_utils from export_pytorch_labels (pytorch#9…
…4179) Part of fixing pytorch#88098 ## Context This is 1/3 PRs to address issue 88098 (move label check failure logic from `check_labels.py` workflow to `trymerge.py` mergebot. Due to the messy cross-script imports and potential circular dependencies, it requires some refactoring to the scripts before, the functional PR can be cleanly implemented. ## What Changed 1. Extract extracts label utils fcns to a `label_utils.py` module from the `export_pytorch_labels.py` script. 2. Small improvements to naming, interface and test coverage ## Note to Reviewers This series of PRs is to replace the original PR pytorch#92682 to make the changes more modular and easier to review. * 1st PR: this one * 2nd PR: Goldspear#2 * 3rd PR: Goldspear#3 Pull Request resolved: pytorch#94179 Approved by: https://github.com/ZainRizvi
- Loading branch information
1 parent
4f691d2
commit 527b646
Showing
5 changed files
with
106 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
"""GitHub Label Utilities.""" | ||
|
||
import json | ||
|
||
from functools import lru_cache | ||
from typing import List, Any, Tuple | ||
from urllib.request import urlopen, Request | ||
|
||
# Modified from https://github.com/pytorch/pytorch/blob/b00206d4737d1f1e7a442c9f8a1cadccd272a386/torch/hub.py#L129 | ||
def _read_url(url: Request) -> Tuple[Any, Any]: | ||
with urlopen(url) as r: | ||
return r.headers, r.read().decode(r.headers.get_content_charset('utf-8')) | ||
|
||
|
||
def request_for_labels(url: str) -> Tuple[Any, Any]: | ||
headers = {'Accept': 'application/vnd.github.v3+json'} | ||
return _read_url(Request(url, headers=headers)) | ||
|
||
|
||
def update_labels(labels: List[str], info: str) -> None: | ||
labels_json = json.loads(info) | ||
labels.extend([x["name"] for x in labels_json]) | ||
|
||
|
||
def get_last_page_num_from_header(header: Any) -> int: | ||
# Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>; | ||
# rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last" | ||
link_info = header['link'] | ||
prefix = "&page=" | ||
suffix = ">;" | ||
return int(link_info[link_info.rindex(prefix) + len(prefix):link_info.rindex(suffix)]) | ||
|
||
|
||
@lru_cache() | ||
def gh_get_labels(org: str, repo: str) -> List[str]: | ||
prefix = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100" | ||
header, info = request_for_labels(prefix + "&page=1") | ||
labels: List[str] = [] | ||
update_labels(labels, info) | ||
|
||
last_page = get_last_page_num_from_header(header) | ||
assert last_page > 0, "Error reading header info to determine total number of pages of labels" | ||
for page_number in range(2, last_page + 1): # skip page 1 | ||
_, info = request_for_labels(prefix + f"&page={page_number}") | ||
update_labels(labels, info) | ||
|
||
return labels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
from typing import Any | ||
|
||
from unittest import TestCase, mock, main | ||
from label_utils import ( | ||
get_last_page_num_from_header, | ||
gh_get_labels, | ||
) | ||
|
||
|
||
class TestLabelUtils(TestCase): | ||
MOCK_HEADER_LINKS_TO_PAGE_NUMS = { | ||
1: {"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"}, | ||
2: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2>;"}, | ||
3: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2&page=3>;"}, | ||
} | ||
|
||
def test_get_last_page_num_from_header(self) -> None: | ||
for expected_page_num, mock_header in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items(): | ||
self.assertEqual(get_last_page_num_from_header(mock_header), expected_page_num) | ||
|
||
MOCK_LABEL_INFO = '[{"name": "foo"}]' | ||
|
||
@mock.patch("label_utils.get_last_page_num_from_header", return_value=3) | ||
@mock.patch("label_utils.request_for_labels", return_value=(None, MOCK_LABEL_INFO)) | ||
def test_gh_get_labels( | ||
self, | ||
mock_request_for_labels: Any, | ||
mock_get_last_page_num_from_header: Any, | ||
) -> None: | ||
res = gh_get_labels("mock_org", "mock_repo") | ||
mock_get_last_page_num_from_header.assert_called_once() | ||
self.assertEqual(res, ["foo"] * 3) | ||
|
||
@mock.patch("label_utils.get_last_page_num_from_header", return_value=0) | ||
@mock.patch("label_utils.request_for_labels", return_value=(None, MOCK_LABEL_INFO)) | ||
def test_gh_get_labels_raises_with_no_pages( | ||
self, | ||
mock_request_for_labels: Any, | ||
get_last_page_num_from_header: Any, | ||
) -> None: | ||
with self.assertRaises(AssertionError) as err: | ||
gh_get_labels("foo", "bar") | ||
self.assertIn("number of pages of labels", str(err.exception)) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |