Skip to content

Commit

Permalink
Refactor to extract label_utils from export_pytorch_labels (pytorch#9…
Browse files Browse the repository at this point in the history
…4179)

Part of fixing pytorch#88098

## Context

This is PR 1 of 3 addressing issue 88098 (moving the label check failure logic from the `check_labels.py` workflow to the `trymerge.py` mergebot). Due to the messy cross-script imports and potential circular dependencies, some refactoring of the scripts is required before the functional PR can be cleanly implemented.

## What Changed
1. Extract the label utility functions from the `export_pytorch_labels.py` script into a new `label_utils.py` module.
2. Small improvements to naming, interface and test coverage

## Note to Reviewers
This series of PRs is to replace the original PR pytorch#92682 to make the changes more modular and easier to review.

* 1st PR: this one
* 2nd PR: Goldspear#2
* 3rd PR: Goldspear#3

Pull Request resolved: pytorch#94179
Approved by: https://github.com/ZainRizvi
  • Loading branch information
Goldspear authored and pytorchmergebot committed Feb 9, 2023
1 parent 4f691d2 commit 527b646
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 53 deletions.
3 changes: 2 additions & 1 deletion .github/scripts/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -61,5 +61,6 @@ New runner types can be added by committing changes to `.github/scale-config.yml

In order to test changes to the builder scripts:

1. Specify your builder PR's branch and repo as `builder_repo` and `builder_branch` in [`.github/templates/common.yml.j2`](https://github.com/pytorch/pytorch/blob/32356aaee6a77e0ae424435a7e9da3d99e7a4ca5/.github/templates/common.yml.j2#LL10C26-L10C32). 2. Regenerate workflow files with `.github/regenerate.sh` (see above).
1. Specify your builder PR's branch and repo as `builder_repo` and `builder_branch` in [`.github/templates/common.yml.j2`](https://github.com/pytorch/pytorch/blob/32356aaee6a77e0ae424435a7e9da3d99e7a4ca5/.github/templates/common.yml.j2#LL10C26-L10C32).
2. Regenerate workflow files with `.github/regenerate.sh` (see above).
3. Submit fake PR to PyTorch. If changing binaries build, add an appropriate label like `ciflow/binaries` to trigger the builds.
11 changes: 7 additions & 4 deletions .github/scripts/check_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from typing import Any, List

from export_pytorch_labels import get_pytorch_labels
from label_utils import gh_get_labels
from gitutils import (
get_git_remote_name,
get_git_repo_dir,
Expand All @@ -27,8 +27,8 @@
)


def get_release_notes_labels() -> List[str]:
return [label for label in get_pytorch_labels() if label.lstrip().startswith("release notes:")]
def get_release_notes_labels(org: str, repo: str) -> List[str]:
    """Return the labels of ``org/repo`` that mark a release-notes category.

    A label qualifies when, ignoring leading whitespace, it starts with
    ``"release notes:"``. Labels are returned unmodified and in the order
    produced by ``gh_get_labels``.
    """
    release_notes_labels: List[str] = []
    for candidate in gh_get_labels(org, repo):
        if candidate.lstrip().startswith("release notes:"):
            release_notes_labels.append(candidate)
    return release_notes_labels


def delete_comment(comment_id: int) -> None:
Expand All @@ -40,7 +40,10 @@ def has_required_labels(pr: GitHubPR) -> bool:
pr_labels = pr.get_labels()
# Check if PR is not user facing
is_not_user_facing_pr = any(label.strip() == "topic: not user facing" for label in pr_labels)
return is_not_user_facing_pr or any(label.strip() in get_release_notes_labels() for label in pr_labels)
return (
is_not_user_facing_pr or
any(label.strip() in get_release_notes_labels(pr.org, pr.project) for label in pr_labels)
)


def delete_comments(pr: GitHubPR) -> None:
Expand Down
51 changes: 3 additions & 48 deletions .github/scripts/export_pytorch_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,14 @@

import boto3 # type: ignore[import]
import json
from functools import lru_cache
from typing import List, Any
from urllib.request import urlopen, Request

# Modified from https://github.com/pytorch/pytorch/blob/b00206d4737d1f1e7a442c9f8a1cadccd272a386/torch/hub.py#L129
def _read_url(url: Any) -> Any:
with urlopen(url) as r:
return r.headers, r.read().decode(r.headers.get_content_charset('utf-8'))
from label_utils import gh_get_labels


def request_for_labels(url: str) -> Any:
headers = {'Accept': 'application/vnd.github.v3+json'}
return _read_url(Request(url, headers=headers))


def get_last_page(header: Any) -> int:
# Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>;
# rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
link_info = header['link']
prefix = "&page="
suffix = ">;"
return int(link_info[link_info.rindex(prefix) + len(prefix):link_info.rindex(suffix)])


def update_labels(labels: List[str], info: str) -> None:
labels_json = json.loads(info)
labels.extend([x["name"] for x in labels_json])


@lru_cache()
def get_pytorch_labels() -> List[str]:
prefix = "https://api.github.com/repos/pytorch/pytorch/labels?per_page=100"
header, info = request_for_labels(prefix + "&page=1")
labels: List[str] = []
update_labels(labels, info)

last_page = get_last_page(header)
assert last_page > 0, "Error reading header info to determine total number of pages of labels"
for page_number in range(2, last_page + 1): # skip page 1
_, info = request_for_labels(prefix + f"&page={page_number}")
update_labels(labels, info)

return labels


def send_labels_to_S3(labels: List[str]) -> None:
def main() -> None:
labels_file_name = "pytorch_labels.json"
obj = boto3.resource('s3').Object('ossci-metrics', labels_file_name)
obj.put(Body=json.dumps(labels).encode())


def main() -> None:
send_labels_to_S3(get_pytorch_labels())
# gh_get_labels(org, repo) has no default arguments; pass the pytorch/pytorch
# repository explicitly, which is what this export script has always published.
obj.put(Body=json.dumps(gh_get_labels("pytorch", "pytorch")).encode())


if __name__ == '__main__':
Expand Down
47 changes: 47 additions & 0 deletions .github/scripts/label_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""GitHub Label Utilities."""

import json

from functools import lru_cache
from typing import List, Any, Tuple
from urllib.request import urlopen, Request

# Modified from https://github.com/pytorch/pytorch/blob/b00206d4737d1f1e7a442c9f8a1cadccd272a386/torch/hub.py#L129
def _read_url(url: Request) -> Tuple[Any, Any]:
    """Open ``url`` and return ``(headers, body)`` for the response.

    The body is decoded with the charset declared in the response headers,
    falling back to UTF-8 when the headers do not declare one.
    """
    with urlopen(url) as response:
        response_headers = response.headers
        raw_body = response.read()
        charset = response_headers.get_content_charset('utf-8')
        return response_headers, raw_body.decode(charset)


def request_for_labels(url: str) -> Tuple[Any, Any]:
    """Fetch ``url`` with the GitHub v3 JSON ``Accept`` header.

    Returns whatever ``_read_url`` returns for the request: a
    ``(headers, decoded_body)`` pair.
    """
    labels_request = Request(url, headers={'Accept': 'application/vnd.github.v3+json'})
    return _read_url(labels_request)


def update_labels(labels: List[str], info: str) -> None:
    """Append to ``labels`` the ``"name"`` of every entry in the JSON array ``info``.

    ``labels`` is mutated in place. All names are collected before the list
    is touched, so a malformed entry leaves ``labels`` unchanged.
    """
    parsed_entries = json.loads(info)
    new_names = []
    for entry in parsed_entries:
        new_names.append(entry["name"])
    labels.extend(new_names)


def get_last_page_num_from_header(header: Any) -> int:
    """Return the number of the last result page advertised by ``header``.

    ``header`` is the response-header mapping of a paginated GitHub request.
    The page number is taken from the final ``&page=<n>`` occurrence in the
    ``link`` value, i.e. the entry that precedes the last ``>;``.

    When the response carries no ``link`` header at all (the whole result
    fits on one page), there is exactly one page, so ``1`` is returned
    instead of failing.

    Raises:
        ValueError: if a ``link`` value is present but does not contain the
            expected ``&page=<n>>;`` pattern.
    """
    # Link info looks like: <https://api.github.com/repositories/65600975/labels?per_page=100&page=2>;
    # rel="next", <https://api.github.com/repositories/65600975/labels?per_page=100&page=3>; rel="last"
    try:
        link_info = header['link']
    except KeyError:
        # Plain dicts raise on a missing key; message-style header objects return None instead.
        link_info = None
    if not link_info:
        return 1
    prefix = "&page="
    suffix = ">;"
    return int(link_info[link_info.rindex(prefix) + len(prefix):link_info.rindex(suffix)])


@lru_cache()
def gh_get_labels(org: str, repo: str) -> List[str]:
    """Return the names of all labels defined on the GitHub repository ``org/repo``.

    Labels are requested 100 per page. The first response's headers determine
    how many pages exist, and the remaining pages are then fetched in order.
    Results are memoized per ``(org, repo)`` by ``lru_cache``, so repeated
    calls return the same list object without new requests.

    Raises:
        AssertionError: if the reported last page number is not positive.
    """
    base_url = f"https://api.github.com/repos/{org}/{repo}/labels?per_page=100"
    first_header, first_body = request_for_labels(base_url + "&page=1")
    collected: List[str] = []
    update_labels(collected, first_body)

    total_pages = get_last_page_num_from_header(first_header)
    assert total_pages > 0, "Error reading header info to determine total number of pages of labels"

    # Page 1 has already been consumed above; continue from page 2.
    next_page = 2
    while next_page <= total_pages:
        _unused_header, page_body = request_for_labels(base_url + f"&page={next_page}")
        update_labels(collected, page_body)
        next_page += 1

    return collected
47 changes: 47 additions & 0 deletions .github/scripts/test_label_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any

from unittest import TestCase, mock, main
from label_utils import (
get_last_page_num_from_header,
gh_get_labels,
)


class TestLabelUtils(TestCase):
    """Unit tests for the pagination and label-fetching helpers in ``label_utils``."""

    # Maps the expected last-page number to a mock header whose "link" value
    # must parse to that number.
    MOCK_HEADER_LINKS_TO_PAGE_NUMS = {
        1: {"link": "<https://api.github.com/dummy/labels?per_page=10&page=1>; rel='last'"},
        2: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2>;"},
        3: {"link": "<https://api.github.com/dummy/labels?per_page=1&page=2&page=3>;"},
    }

    # JSON body returned by the mocked label request: one label named "foo".
    MOCK_LABEL_INFO = '[{"name": "foo"}]'

    def setUp(self) -> None:
        # gh_get_labels is wrapped in lru_cache. Without clearing it, a cached
        # result from an earlier run would bypass the mocks and break the
        # call-count assertions below.
        gh_get_labels.cache_clear()

    def test_get_last_page_num_from_header(self) -> None:
        """Each mock header parses to the page number it is keyed by."""
        for expected_page_num, mock_header in self.MOCK_HEADER_LINKS_TO_PAGE_NUMS.items():
            self.assertEqual(get_last_page_num_from_header(mock_header), expected_page_num)

    @mock.patch("label_utils.get_last_page_num_from_header", return_value=3)
    @mock.patch("label_utils.request_for_labels", return_value=(None, MOCK_LABEL_INFO))
    def test_gh_get_labels(
        self,
        mock_request_for_labels: Any,
        mock_get_last_page_num_from_header: Any,
    ) -> None:
        """With three pages reported, the single mocked label is collected three times."""
        res = gh_get_labels("mock_org", "mock_repo")
        mock_get_last_page_num_from_header.assert_called_once()
        self.assertEqual(res, ["foo"] * 3)

    @mock.patch("label_utils.get_last_page_num_from_header", return_value=0)
    @mock.patch("label_utils.request_for_labels", return_value=(None, MOCK_LABEL_INFO))
    def test_gh_get_labels_raises_with_no_pages(
        self,
        mock_request_for_labels: Any,
        mock_get_last_page_num_from_header: Any,
    ) -> None:
        """A reported last page of 0 trips the page-count assertion."""
        with self.assertRaises(AssertionError) as err:
            gh_get_labels("foo", "bar")
        self.assertIn("number of pages of labels", str(err.exception))


if __name__ == "__main__":
main()

0 comments on commit 527b646

Please sign in to comment.