Mach: introduce Pyrefly for Python type checking, starting with the wpt folder (#37953)

This is the first stage of adopting Pyrefly. It introduces the Python
folder and focuses on fixing issues around it.

Testing: *Describe how this pull request is tested or why it doesn't
require tests*
Fixes: *Link to an issue this pull requests fixes or remove this line if
there is no issue*

---------

Signed-off-by: Jerens Lensun <jerensslensun@gmail.com>
This commit is contained in:
Jerens Lensun 2025-07-11 21:07:36 +08:00 committed by GitHub
parent 2366a8bf9e
commit 55fd7b862f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 303 additions and 154 deletions

View file

@ -24,3 +24,26 @@ ignore = [
# 80 character line length; the standard tidy process will enforce line length # 80 character line length; the standard tidy process will enforce line length
"E501", "E501",
] ]
[tool.pyrefly]
search-path = [
"python",
"tests/wpt/tests",
"tests/wpt/tests/tools",
"tests/wpt/tests/tools/wptrunner",
"tests/wpt/tests/tools/wptserve",
"python/mach",
]
project-includes = [
"python/wpt/**/*.py",
]
project-excludes = [
"**/venv/**",
"**/.venv/**",
"tests/wpt/tests/**",
"**/test.py",
"**/*_tests.py",
"**/tests/**",
"python/mach/**/*.py",
"python/servo/mutation/**/*.py",
]

View file

@ -38,3 +38,6 @@ Mako == 1.2.2
# For devtools tests. # For devtools tests.
geckordp == 1.0.3 geckordp == 1.0.3
# For Python static type checking
pyrefly == 0.23.1

View file

@ -17,7 +17,8 @@ import os
import re import re
import subprocess import subprocess
import sys import sys
from typing import Any, Dict, List from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Tuple
import colorama import colorama
import toml import toml
@ -128,6 +129,10 @@ def is_iter_empty(iterator):
return False, iterator return False, iterator
def normalize_path(path: str) -> str:
return os.path.relpath(os.path.abspath(path), TOPDIR)
def normilize_paths(paths): def normilize_paths(paths):
if isinstance(paths, str): if isinstance(paths, str):
return os.path.join(*paths.split("/")) return os.path.join(*paths.split("/"))
@ -376,6 +381,38 @@ def check_ruff_lints():
) )
@dataclass
class PyreflyDiagnostic:
"""
Represents a single diagnostic error reported by Pyrefly.
"""
line: int
column: int
stop_line: int
stop_column: int
path: str
code: int
name: str
description: str
concise_description: str
def run_python_type_checker() -> Iterator[Tuple[str, int, str]]:
print("\r ➤ Checking type annotations in python files ...")
try:
result = subprocess.run(["pyrefly", "check", "--output-format", "json"], capture_output=True)
parsed_json = json.loads(result.stdout)
errors = parsed_json.get("errors", [])
except subprocess.CalledProcessError as error:
print(f"{colorama.Fore.YELLOW}{error}{colorama.Style.RESET_ALL}")
pass
else:
for error in errors:
diagnostic = PyreflyDiagnostic(**error)
yield normalize_path(diagnostic.path), diagnostic.line, diagnostic.concise_description
def run_cargo_deny_lints(): def run_cargo_deny_lints():
print("\r ➤ Running `cargo-deny` checks...") print("\r ➤ Running `cargo-deny` checks...")
result = subprocess.run( result = subprocess.run(
@ -1003,11 +1040,14 @@ def scan(only_changed_files=False, progress=False, github_annotations=False):
file_errors = collect_errors_for_files(files_to_check, checking_functions, line_checking_functions) file_errors = collect_errors_for_files(files_to_check, checking_functions, line_checking_functions)
python_errors = check_ruff_lints() python_errors = check_ruff_lints()
python_type_check = run_python_type_checker()
cargo_lock_errors = run_cargo_deny_lints() cargo_lock_errors = run_cargo_deny_lints()
wpt_errors = run_wpt_lints(only_changed_files) wpt_errors = run_wpt_lints(only_changed_files)
# chain all the iterators # chain all the iterators
errors = itertools.chain(config_errors, directory_errors, file_errors, python_errors, wpt_errors, cargo_lock_errors) errors = itertools.chain(
config_errors, directory_errors, file_errors, python_errors, python_type_check, wpt_errors, cargo_lock_errors
)
colorama.init() colorama.init()
error = None error = None

View file

@ -7,6 +7,7 @@
# option. This file may not be copied, modified, or distributed # option. This file may not be copied, modified, or distributed
# except according to those terms. # except according to those terms.
from argparse import ArgumentParser
import os import os
import sys import sys
@ -25,7 +26,7 @@ import localpaths # noqa: F401,E402
import wptrunner.wptcommandline # noqa: E402 import wptrunner.wptcommandline # noqa: E402
def create_parser(): def create_parser() -> ArgumentParser:
parser = wptrunner.wptcommandline.create_parser() parser = wptrunner.wptcommandline.create_parser()
parser.add_argument( parser.add_argument(
"--rr-chaos", default=False, action="store_true", help="Run under chaos mode in rr until a failure is captured" "--rr-chaos", default=False, action="store_true", help="Run under chaos mode in rr until a failure is captured"
@ -60,5 +61,5 @@ def create_parser():
return parser return parser
def run_tests(): def run_tests() -> bool:
return test.run_tests() return test.run_tests()

View file

@ -17,7 +17,7 @@ import logging
import os import os
import sys import sys
from exporter import WPTSync from .exporter import WPTSync
def main() -> int: def main() -> int:

View file

@ -21,7 +21,7 @@ import logging
import re import re
import shutil import shutil
import subprocess import subprocess
from dataclasses import field
from typing import Callable, Optional from typing import Callable, Optional
from .common import ( from .common import (
@ -34,7 +34,6 @@ from .common import (
UPSTREAMABLE_PATH, UPSTREAMABLE_PATH,
wpt_branch_name_from_servo_pr_number, wpt_branch_name_from_servo_pr_number,
) )
from .github import GithubRepository, PullRequest from .github import GithubRepository, PullRequest
from .step import ( from .step import (
AsyncValue, AsyncValue,
@ -49,7 +48,7 @@ from .step import (
class LocalGitRepo: class LocalGitRepo:
def __init__(self, path: str, sync: WPTSync): def __init__(self, path: str, sync: WPTSync) -> None:
self.path = path self.path = path
self.sync = sync self.sync = sync
@ -57,7 +56,9 @@ class LocalGitRepo:
# git in advance and run the subprocess by its absolute path. # git in advance and run the subprocess by its absolute path.
self.git_path = shutil.which("git") self.git_path = shutil.which("git")
def run_without_encoding(self, *args, env: dict = {}): def run_without_encoding(self, *args, env: dict = {}) -> bytes:
if self.git_path is None:
raise RuntimeError("Git executable not found in PATH")
command_line = [self.git_path] + list(args) command_line = [self.git_path] + list(args)
logging.info(" → Execution (cwd='%s'): %s", self.path, " ".join(command_line)) logging.info(" → Execution (cwd='%s'): %s", self.path, " ".join(command_line))
@ -74,7 +75,7 @@ class LocalGitRepo:
) )
raise exception raise exception
def run(self, *args, env: dict = {}): def run(self, *args, env: dict = {}) -> str:
return self.run_without_encoding(*args, env=env).decode("utf-8", errors="surrogateescape") return self.run_without_encoding(*args, env=env).decode("utf-8", errors="surrogateescape")
@ -93,11 +94,19 @@ class SyncRun:
servo_pr=self.servo_pr, servo_pr=self.servo_pr,
) )
def add_step(self, step) -> Optional[AsyncValue]: def add_step(
self,
step: ChangePRStep
| CommentStep
| CreateOrUpdateBranchForPRStep
| MergePRStep
| OpenPRStep
| RemoveBranchForPRStep,
) -> Optional[AsyncValue]:
self.steps.append(step) self.steps.append(step)
return step.provides() return step.provides()
def run(self): def run(self) -> None:
# This loop always removes the first step and runs it, because # This loop always removes the first step and runs it, because
# individual steps can modify the list of steps. For instance, if a # individual steps can modify the list of steps. For instance, if a
# step fails, it might clear the remaining steps and replace them with # step fails, it might clear the remaining steps and replace them with
@ -142,7 +151,13 @@ class WPTSync:
github_name: str github_name: str
suppress_force_push: bool = False suppress_force_push: bool = False
def __post_init__(self): servo: GithubRepository = field(init=False)
wpt: GithubRepository = field(init=False)
downstream_wpt: GithubRepository = field(init=False)
local_servo_repo: LocalGitRepo = field(init=False)
local_wpt_repo: LocalGitRepo = field(init=False)
def __post_init__(self) -> None:
self.servo = GithubRepository(self, self.servo_repo, "main") self.servo = GithubRepository(self, self.servo_repo, "main")
self.wpt = GithubRepository(self, self.wpt_repo, "master") self.wpt = GithubRepository(self, self.wpt_repo, "master")
self.downstream_wpt = GithubRepository(self, self.downstream_wpt_repo, "master") self.downstream_wpt = GithubRepository(self, self.downstream_wpt_repo, "master")
@ -194,7 +209,7 @@ class WPTSync:
logging.error(exception, exc_info=True) logging.error(exception, exc_info=True)
return False return False
def handle_new_pull_request_contents(self, run: SyncRun, pull_data: dict): def handle_new_pull_request_contents(self, run: SyncRun, pull_data: dict) -> None:
num_commits = pull_data["commits"] num_commits = pull_data["commits"]
head_sha = pull_data["head"]["sha"] head_sha = pull_data["head"]["sha"]
is_upstreamable = ( is_upstreamable = (
@ -243,13 +258,13 @@ class WPTSync:
# Leave a comment to the new pull request in the original pull request. # Leave a comment to the new pull request in the original pull request.
run.add_step(CommentStep(run.servo_pr, OPENED_NEW_UPSTREAM_PR)) run.add_step(CommentStep(run.servo_pr, OPENED_NEW_UPSTREAM_PR))
def handle_edited_pull_request(self, run: SyncRun, pull_data: dict): def handle_edited_pull_request(self, run: SyncRun, pull_data: dict) -> None:
logging.info("Changing upstream PR title") logging.info("Changing upstream PR title")
if run.upstream_pr.has_value(): if run.upstream_pr.has_value():
run.add_step(ChangePRStep(run.upstream_pr.value(), "open", pull_data["title"], pull_data["body"])) run.add_step(ChangePRStep(run.upstream_pr.value(), "open", pull_data["title"], pull_data["body"]))
run.add_step(CommentStep(run.servo_pr, UPDATED_TITLE_IN_EXISTING_UPSTREAM_PR)) run.add_step(CommentStep(run.servo_pr, UPDATED_TITLE_IN_EXISTING_UPSTREAM_PR))
def handle_closed_pull_request(self, run: SyncRun, pull_data: dict): def handle_closed_pull_request(self, run: SyncRun, pull_data: dict) -> None:
logging.info("Processing closed PR") logging.info("Processing closed PR")
if not run.upstream_pr.has_value(): if not run.upstream_pr.has_value():
# If we don't recognize this PR, it never contained upstreamable changes. # If we don't recognize this PR, it never contained upstreamable changes.

View file

@ -43,5 +43,5 @@ COULD_NOT_MERGE_CHANGES_UPSTREAM_COMMENT = (
) )
def wpt_branch_name_from_servo_pr_number(servo_pr_number): def wpt_branch_name_from_servo_pr_number(servo_pr_number) -> str:
return f"servo_export_{servo_pr_number}" return f"servo_export_{servo_pr_number}"

View file

@ -16,7 +16,7 @@ day be entirely replaced with something like PyGithub."""
from __future__ import annotations from __future__ import annotations
import logging import logging
import urllib import urllib.parse
from typing import Optional, TYPE_CHECKING from typing import Optional, TYPE_CHECKING
@ -29,7 +29,7 @@ USER_AGENT = "Servo web-platform-test sync service"
TIMEOUT = 30 # 30 seconds TIMEOUT = 30 # 30 seconds
def authenticated(sync: WPTSync, method, url, json=None) -> requests.Response: def authenticated(sync: WPTSync, method: str, url: str, json=None) -> requests.Response:
logging.info(" → Request: %s %s", method, url) logging.info(" → Request: %s %s", method, url)
if json: if json:
logging.info(" → Request JSON: %s", json) logging.info(" → Request JSON: %s", json)
@ -51,14 +51,14 @@ class GithubRepository:
This class allows interacting with a single GitHub repository. This class allows interacting with a single GitHub repository.
""" """
def __init__(self, sync: WPTSync, repo: str, default_branch: str): def __init__(self, sync: WPTSync, repo: str, default_branch: str) -> None:
self.sync = sync self.sync = sync
self.repo = repo self.repo = repo
self.default_branch = default_branch self.default_branch = default_branch
self.org = repo.split("/")[0] self.org = repo.split("/")[0]
self.pulls_url = f"repos/{self.repo}/pulls" self.pulls_url = f"repos/{self.repo}/pulls"
def __str__(self): def __str__(self) -> str:
return self.repo return self.repo
def get_pull_request(self, number: int) -> PullRequest: def get_pull_request(self, number: int) -> PullRequest:
@ -94,7 +94,7 @@ class GithubRepository:
return self.get_pull_request(json["items"][0]["number"]) return self.get_pull_request(json["items"][0]["number"])
def open_pull_request(self, branch: GithubBranch, title: str, body: str): def open_pull_request(self, branch: GithubBranch, title: str, body: str) -> PullRequest:
data = { data = {
"title": title, "title": title,
"head": branch.get_pr_head_reference_for_repo(self), "head": branch.get_pr_head_reference_for_repo(self),
@ -107,11 +107,11 @@ class GithubRepository:
class GithubBranch: class GithubBranch:
def __init__(self, repo: GithubRepository, branch_name: str): def __init__(self, repo: GithubRepository, branch_name: str) -> None:
self.repo = repo self.repo = repo
self.name = branch_name self.name = branch_name
def __str__(self): def __str__(self) -> str:
return f"{self.repo}/{self.name}" return f"{self.repo}/{self.name}"
def get_pr_head_reference_for_repo(self, other_repo: GithubRepository) -> str: def get_pr_head_reference_for_repo(self, other_repo: GithubRepository) -> str:
@ -128,20 +128,20 @@ class PullRequest:
This class allows interacting with a single pull request on GitHub. This class allows interacting with a single pull request on GitHub.
""" """
def __init__(self, repo: GithubRepository, number: int): def __init__(self, repo: GithubRepository, number: int) -> None:
self.repo = repo self.repo = repo
self.context = repo.sync self.context = repo.sync
self.number = number self.number = number
self.base_url = f"repos/{self.repo.repo}/pulls/{self.number}" self.base_url = f"repos/{self.repo.repo}/pulls/{self.number}"
self.base_issues_url = f"repos/{self.repo.repo}/issues/{self.number}" self.base_issues_url = f"repos/{self.repo.repo}/issues/{self.number}"
def __str__(self): def __str__(self) -> str:
return f"{self.repo}#{self.number}" return f"{self.repo}#{self.number}"
def api(self, *args, **kwargs) -> requests.Response: def api(self, *args, **kwargs) -> requests.Response:
return authenticated(self.context, *args, **kwargs) return authenticated(self.context, *args, **kwargs)
def leave_comment(self, comment: str): def leave_comment(self, comment: str) -> requests.Response:
return self.api("POST", f"{self.base_issues_url}/comments", json={"body": comment}) return self.api("POST", f"{self.base_issues_url}/comments", json={"body": comment})
def change( def change(
@ -149,7 +149,7 @@ class PullRequest:
state: Optional[str] = None, state: Optional[str] = None,
title: Optional[str] = None, title: Optional[str] = None,
body: Optional[str] = None, body: Optional[str] = None,
): ) -> requests.Response:
data = {} data = {}
if title: if title:
data["title"] = title data["title"] = title
@ -159,11 +159,11 @@ class PullRequest:
data["state"] = state data["state"] = state
return self.api("PATCH", self.base_url, json=data) return self.api("PATCH", self.base_url, json=data)
def remove_label(self, label: str): def remove_label(self, label: str) -> None:
self.api("DELETE", f"{self.base_issues_url}/labels/{label}") self.api("DELETE", f"{self.base_issues_url}/labels/{label}")
def add_labels(self, labels: list[str]): def add_labels(self, labels: list[str]) -> None:
self.api("POST", f"{self.base_issues_url}/labels", json=labels) self.api("POST", f"{self.base_issues_url}/labels", json=labels)
def merge(self): def merge(self) -> None:
self.api("PUT", f"{self.base_url}/merge", json={"merge_method": "rebase"}) self.api("PUT", f"{self.base_url}/merge", json={"merge_method": "rebase"})

View file

@ -36,13 +36,13 @@ PATCH_FILE_NAME = "tmp.patch"
class Step: class Step:
def __init__(self, name): def __init__(self, name) -> None:
self.name = name self.name = name
def provides(self) -> Optional[AsyncValue]: def provides(self) -> Optional[AsyncValue]:
return None return None
def run(self, _: SyncRun): def run(self, run: SyncRun) -> None:
return return
@ -50,31 +50,31 @@ T = TypeVar("T")
class AsyncValue(Generic[T]): class AsyncValue(Generic[T]):
def __init__(self, value: Optional[T] = None): def __init__(self, value: Optional[T] = None) -> None:
self._value = value self._value = value
def resolve(self, value: T): def resolve(self, value: Optional[T]) -> None:
self._value = value self._value = value
def value(self) -> T: def value(self) -> T:
assert self._value is not None assert self._value is not None
return self._value return self._value
def has_value(self): def has_value(self) -> bool:
return self._value is not None return self._value is not None
class CreateOrUpdateBranchForPRStep(Step): class CreateOrUpdateBranchForPRStep(Step):
def __init__(self, pull_data: dict, pull_request: PullRequest): def __init__(self, pull_data: dict, pull_request: PullRequest) -> None:
Step.__init__(self, "CreateOrUpdateBranchForPRStep") Step.__init__(self, "CreateOrUpdateBranchForPRStep")
self.pull_data = pull_data self.pull_data = pull_data
self.pull_request = pull_request self.pull_request = pull_request
self.branch: AsyncValue[GithubBranch] = AsyncValue() self.branch: AsyncValue[GithubBranch] = AsyncValue()
def provides(self): def provides(self) -> AsyncValue[GithubBranch]:
return self.branch return self.branch
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
try: try:
commits = self._get_upstreamable_commits_from_local_servo_repo(run.sync) commits = self._get_upstreamable_commits_from_local_servo_repo(run.sync)
branch_name = self._create_or_update_branch_for_pr(run, commits) branch_name = self._create_or_update_branch_for_pr(run, commits)
@ -128,7 +128,7 @@ class CreateOrUpdateBranchForPRStep(Step):
] ]
return filtered_commits return filtered_commits
def _apply_filtered_servo_commit_to_wpt(self, run: SyncRun, commit: dict): def _apply_filtered_servo_commit_to_wpt(self, run: SyncRun, commit: dict) -> None:
patch_path = os.path.join(run.sync.wpt_path, PATCH_FILE_NAME) patch_path = os.path.join(run.sync.wpt_path, PATCH_FILE_NAME)
strip_count = UPSTREAMABLE_PATH.count("/") + 1 strip_count = UPSTREAMABLE_PATH.count("/") + 1
@ -143,7 +143,7 @@ class CreateOrUpdateBranchForPRStep(Step):
run.sync.local_wpt_repo.run("add", "--all") run.sync.local_wpt_repo.run("add", "--all")
run.sync.local_wpt_repo.run("commit", "--message", commit["message"], "--author", commit["author"]) run.sync.local_wpt_repo.run("commit", "--message", commit["message"], "--author", commit["author"])
def _create_or_update_branch_for_pr(self, run: SyncRun, commits: list[dict], pre_commit_callback=None): def _create_or_update_branch_for_pr(self, run: SyncRun, commits: list[dict], pre_commit_callback=None) -> str:
branch_name = wpt_branch_name_from_servo_pr_number(self.pull_data["number"]) branch_name = wpt_branch_name_from_servo_pr_number(self.pull_data["number"])
try: try:
# Create a new branch with a unique name that is consistent between # Create a new branch with a unique name that is consistent between
@ -169,7 +169,6 @@ class CreateOrUpdateBranchForPRStep(Step):
remote_url = f"https://{user}:{token}@github.com/{repo}.git" remote_url = f"https://{user}:{token}@github.com/{repo}.git"
run.sync.local_wpt_repo.run("push", "-f", remote_url, branch_name) run.sync.local_wpt_repo.run("push", "-f", remote_url, branch_name)
return branch_name
finally: finally:
try: try:
run.sync.local_wpt_repo.run("checkout", "master") run.sync.local_wpt_repo.run("checkout", "master")
@ -177,13 +176,15 @@ class CreateOrUpdateBranchForPRStep(Step):
except Exception: except Exception:
pass pass
return branch_name
class RemoveBranchForPRStep(Step): class RemoveBranchForPRStep(Step):
def __init__(self, pull_request): def __init__(self, pull_request) -> None:
Step.__init__(self, "RemoveBranchForPRStep") Step.__init__(self, "RemoveBranchForPRStep")
self.branch_name = wpt_branch_name_from_servo_pr_number(pull_request["number"]) self.branch_name = wpt_branch_name_from_servo_pr_number(pull_request["number"])
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
self.name += f":{run.sync.downstream_wpt.get_branch(self.branch_name)}" self.name += f":{run.sync.downstream_wpt.get_branch(self.branch_name)}"
logging.info(" -> Removing branch used for upstream PR") logging.info(" -> Removing branch used for upstream PR")
if not run.sync.suppress_force_push: if not run.sync.suppress_force_push:
@ -201,7 +202,7 @@ class ChangePRStep(Step):
state: str, state: str,
title: Optional[str] = None, title: Optional[str] = None,
body: Optional[str] = None, body: Optional[str] = None,
): ) -> None:
name = f"ChangePRStep:{pull_request}:{state}" name = f"ChangePRStep:{pull_request}:{state}"
if title: if title:
name += f":{title}" name += f":{title}"
@ -212,7 +213,7 @@ class ChangePRStep(Step):
self.title = title self.title = title
self.body = body self.body = body
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
body = self.body body = self.body
if body: if body:
body = run.prepare_body_text(body) body = run.prepare_body_text(body)
@ -222,12 +223,12 @@ class ChangePRStep(Step):
class MergePRStep(Step): class MergePRStep(Step):
def __init__(self, pull_request: PullRequest, labels_to_remove: list[str] = []): def __init__(self, pull_request: PullRequest, labels_to_remove: list[str] = []) -> None:
Step.__init__(self, f"MergePRStep:{pull_request}") Step.__init__(self, f"MergePRStep:{pull_request}")
self.pull_request = pull_request self.pull_request = pull_request
self.labels_to_remove = labels_to_remove self.labels_to_remove = labels_to_remove
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
try: try:
for label in self.labels_to_remove: for label in self.labels_to_remove:
self.pull_request.remove_label(label) self.pull_request.remove_label(label)
@ -250,7 +251,7 @@ class OpenPRStep(Step):
title: str, title: str,
body: str, body: str,
labels: list[str], labels: list[str],
): ) -> None:
Step.__init__(self, "OpenPRStep") Step.__init__(self, "OpenPRStep")
self.title = title self.title = title
self.body = body self.body = body
@ -259,10 +260,10 @@ class OpenPRStep(Step):
self.new_pr: AsyncValue[PullRequest] = AsyncValue() self.new_pr: AsyncValue[PullRequest] = AsyncValue()
self.labels = labels self.labels = labels
def provides(self): def provides(self) -> AsyncValue[PullRequest]:
return self.new_pr return self.new_pr
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
pull_request = self.target_repo.open_pull_request( pull_request = self.target_repo.open_pull_request(
self.source_branch.value(), self.title, run.prepare_body_text(self.body) self.source_branch.value(), self.title, run.prepare_body_text(self.body)
) )
@ -276,12 +277,12 @@ class OpenPRStep(Step):
class CommentStep(Step): class CommentStep(Step):
def __init__(self, pull_request: PullRequest, comment_template: str): def __init__(self, pull_request: PullRequest, comment_template: str) -> None:
Step.__init__(self, "CommentStep") Step.__init__(self, "CommentStep")
self.pull_request = pull_request self.pull_request = pull_request
self.comment_template = comment_template self.comment_template = comment_template
def run(self, run: SyncRun): def run(self, run: SyncRun) -> None:
comment = run.make_comment(self.comment_template) comment = run.make_comment(self.comment_template)
self.name += f":{self.pull_request}:{comment}" self.name += f":{self.pull_request}:{comment}"
self.pull_request.leave_comment(comment) self.pull_request.leave_comment(comment)

View file

@ -13,7 +13,7 @@ import mozlog.formatters.base
import mozlog.reader import mozlog.reader
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any from typing import DefaultDict, Dict, Optional, NotRequired, Union, TypedDict, Literal
from six import itervalues from six import itervalues
DEFAULT_MOVE_UP_CODE = "\x1b[A" DEFAULT_MOVE_UP_CODE = "\x1b[A"
@ -34,7 +34,7 @@ class UnexpectedSubtestResult:
@dataclass @dataclass
class UnexpectedResult: class UnexpectedResult:
path: str path: str
subsuite: str subsuite: Optional[str]
actual: str actual: str
expected: str expected: str
message: str message: str
@ -61,7 +61,7 @@ class UnexpectedResult:
# Organize the failures by stack trace so we don't print the same stack trace # Organize the failures by stack trace so we don't print the same stack trace
# more than once. They are really tall and we don't want to flood the screen # more than once. They are really tall and we don't want to flood the screen
# with duplicate information. # with duplicate information.
results_by_stack = collections.defaultdict(list) results_by_stack: DefaultDict[str | None, list[UnexpectedSubtestResult]] = collections.defaultdict(list)
for subtest_result in self.unexpected_subtest_results: for subtest_result in self.unexpected_subtest_results:
results_by_stack[subtest_result.stack].append(subtest_result) results_by_stack[subtest_result.stack].append(subtest_result)
@ -74,7 +74,7 @@ class UnexpectedResult:
return UnexpectedResult.wrap_and_indent_lines(output, " ") return UnexpectedResult.wrap_and_indent_lines(output, " ")
@staticmethod @staticmethod
def wrap_and_indent_lines(lines, indent): def wrap_and_indent_lines(lines, indent: str):
if not lines: if not lines:
return "" return ""
@ -86,7 +86,7 @@ class UnexpectedResult:
return output return output
@staticmethod @staticmethod
def to_lines(result: Any[UnexpectedSubtestResult, UnexpectedResult], print_stack=True): def to_lines(result: Union[UnexpectedSubtestResult, UnexpectedResult], print_stack=True) -> list[str]:
first_line = result.actual first_line = result.actual
if result.expected != result.actual: if result.expected != result.actual:
first_line += f" [expected {result.expected}]" first_line += f" [expected {result.expected}]"
@ -109,11 +109,66 @@ class UnexpectedResult:
return lines return lines
class GlobalTestData(TypedDict):
action: str
time: int
thread: str
pid: int
source: str
Status = Literal["PASS", "FAIL", "PRECONDITION_FAILED", "TIMEOUT", "CRASH", "ASSERT", "SKIP", "OK", "ERROR"]
class SuiteStartData(GlobalTestData):
tests: Dict
name: NotRequired[str]
run_info: NotRequired[Dict]
version_info: NotRequired[Dict]
device_info: NotRequired[Dict]
class TestStartData(GlobalTestData):
test: str
path: NotRequired[str]
known_intermittent: Status
subsuite: NotRequired[str]
group: NotRequired[str]
class TestEndData(GlobalTestData):
test: str
status: Status
expected: Status
known_intermittent: Status
message: NotRequired[str]
stack: NotRequired[str]
extra: NotRequired[str]
subsuite: NotRequired[str]
group: NotRequired[str]
class TestStatusData(TestEndData):
subtest: str
class ServoHandler(mozlog.reader.LogHandler): class ServoHandler(mozlog.reader.LogHandler):
"""LogHandler designed to collect unexpected results for use by """LogHandler designed to collect unexpected results for use by
script or by the ServoFormatter output formatter.""" script or by the ServoFormatter output formatter."""
def __init__(self, detect_flakes=False): number_of_tests: int
completed_tests: int
need_to_erase_last_line: int
running_tests: Dict[str, str]
test_output: DefaultDict[str, str]
subtest_failures: DefaultDict[str, list]
tests_with_failing_subtests: list
unexpected_results: list
expected: Dict[str, int]
unexpected_tests: Dict[str, list]
suite_start_time: int
def __init__(self, detect_flakes=False) -> None:
""" """
Flake detection assumes first suite is actual run Flake detection assumes first suite is actual run
and rest of the suites are retry-unexpected for flakes detection. and rest of the suites are retry-unexpected for flakes detection.
@ -122,18 +177,18 @@ class ServoHandler(mozlog.reader.LogHandler):
self.currently_detecting_flakes = False self.currently_detecting_flakes = False
self.reset_state() self.reset_state()
def reset_state(self): def reset_state(self) -> None:
self.number_of_tests = 0 self.number_of_tests = 0
self.completed_tests = 0 self.completed_tests = 0
self.need_to_erase_last_line = False self.need_to_erase_last_line = False
self.running_tests: Dict[str, str] = {} self.running_tests = {}
if self.currently_detecting_flakes: if self.currently_detecting_flakes:
return return
self.currently_detecting_flakes = False self.currently_detecting_flakes = False
self.test_output = collections.defaultdict(str) self.test_output = collections.defaultdict(str)
self.subtest_failures = collections.defaultdict(list) self.subtest_failures = collections.defaultdict(list)
self.tests_with_failing_subtests = [] self.tests_with_failing_subtests = []
self.unexpected_results: List[UnexpectedResult] = [] self.unexpected_results = []
self.expected = { self.expected = {
"OK": 0, "OK": 0,
@ -159,7 +214,7 @@ class ServoHandler(mozlog.reader.LogHandler):
def any_stable_unexpected(self) -> bool: def any_stable_unexpected(self) -> bool:
return any(not unexpected.flaky for unexpected in self.unexpected_results) return any(not unexpected.flaky for unexpected in self.unexpected_results)
def suite_start(self, data): def suite_start(self, data: SuiteStartData) -> Optional[str]:
# If there were any unexpected results and we are starting another suite, assume # If there were any unexpected results and we are starting another suite, assume
# that this suite has been launched to detect intermittent tests. # that this suite has been launched to detect intermittent tests.
# TODO: Support running more than a single suite at once. # TODO: Support running more than a single suite at once.
@ -170,10 +225,10 @@ class ServoHandler(mozlog.reader.LogHandler):
self.number_of_tests = sum(len(tests) for tests in itervalues(data["tests"])) self.number_of_tests = sum(len(tests) for tests in itervalues(data["tests"]))
self.suite_start_time = data["time"] self.suite_start_time = data["time"]
def suite_end(self, _): def suite_end(self, data) -> Optional[str]:
pass pass
def test_start(self, data): def test_start(self, data: TestStartData) -> Optional[str]:
self.running_tests[data["thread"]] = data["test"] self.running_tests[data["thread"]] = data["test"]
@staticmethod @staticmethod
@ -182,7 +237,7 @@ class ServoHandler(mozlog.reader.LogHandler):
return True return True
return "known_intermittent" in data and data["status"] in data["known_intermittent"] return "known_intermittent" in data and data["status"] in data["known_intermittent"]
def test_end(self, data: dict) -> Optional[UnexpectedResult]: def test_end(self, data: TestEndData) -> Union[UnexpectedResult, str, None]:
self.completed_tests += 1 self.completed_tests += 1
test_status = data["status"] test_status = data["status"]
test_path = data["test"] test_path = data["test"]
@ -249,7 +304,7 @@ class ServoHandler(mozlog.reader.LogHandler):
self.unexpected_results.append(result) self.unexpected_results.append(result)
return result return result
def test_status(self, data: dict): def test_status(self, data: TestStatusData) -> None:
if self.data_was_for_expected_result(data): if self.data_was_for_expected_result(data):
return return
self.subtest_failures[data["test"]].append( self.subtest_failures[data["test"]].append(
@ -264,11 +319,11 @@ class ServoHandler(mozlog.reader.LogHandler):
) )
) )
def process_output(self, data): def process_output(self, data) -> None:
if "test" in data: if "test" in data:
self.test_output[data["test"]] += data["data"] + "\n" self.test_output[data["test"]] += data["data"] + "\n"
def log(self, _): def log(self, data) -> str | None:
pass pass
@ -276,7 +331,13 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
"""Formatter designed to produce unexpected test results grouped """Formatter designed to produce unexpected test results grouped
together in a readable format.""" together in a readable format."""
def __init__(self): current_display: str
interactive: bool
number_skipped: int
move_up: str
clear_eol: str
def __init__(self) -> None:
ServoHandler.__init__(self) ServoHandler.__init__(self)
self.current_display = "" self.current_display = ""
self.interactive = os.isatty(sys.stdout.fileno()) self.interactive = os.isatty(sys.stdout.fileno())
@ -296,12 +357,12 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
except Exception as exception: except Exception as exception:
sys.stderr.write("GroupingFormatter: Could not get terminal control characters: %s\n" % exception) sys.stderr.write("GroupingFormatter: Could not get terminal control characters: %s\n" % exception)
def text_to_erase_display(self): def text_to_erase_display(self) -> str:
if not self.interactive or not self.current_display: if not self.interactive or not self.current_display:
return "" return ""
return (self.move_up + self.clear_eol) * self.current_display.count("\n") return (self.move_up + self.clear_eol) * self.current_display.count("\n")
def generate_output(self, text=None, new_display=None): def generate_output(self, text=None, new_display=None) -> str | None:
if not self.interactive: if not self.interactive:
return text return text
@ -312,13 +373,13 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
self.current_display = new_display self.current_display = new_display
return output + self.current_display return output + self.current_display
def test_counter(self): def test_counter(self) -> str:
if self.number_of_tests == 0: if self.number_of_tests == 0:
return " [%i] " % self.completed_tests return " [%i] " % self.completed_tests
else: else:
return " [%i/%i] " % (self.completed_tests, self.number_of_tests) return " [%i/%i] " % (self.completed_tests, self.number_of_tests)
def build_status_line(self): def build_status_line(self) -> str:
new_display = self.test_counter() new_display = self.test_counter()
if self.running_tests: if self.running_tests:
@ -331,7 +392,7 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
else: else:
return new_display + "No tests running.\n" return new_display + "No tests running.\n"
def suite_start(self, data): def suite_start(self, data) -> str:
ServoHandler.suite_start(self, data) ServoHandler.suite_start(self, data)
maybe_flakes_msg = " to detect flaky tests" if self.currently_detecting_flakes else "" maybe_flakes_msg = " to detect flaky tests" if self.currently_detecting_flakes else ""
if self.number_of_tests == 0: if self.number_of_tests == 0:
@ -339,12 +400,12 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
else: else:
return f"Running {self.number_of_tests} tests in {data['source']}{maybe_flakes_msg}\n\n" return f"Running {self.number_of_tests} tests in {data['source']}{maybe_flakes_msg}\n\n"
def test_start(self, data): def test_start(self, data) -> str | None:
ServoHandler.test_start(self, data) ServoHandler.test_start(self, data)
if self.interactive: if self.interactive:
return self.generate_output(new_display=self.build_status_line()) return self.generate_output(new_display=self.build_status_line())
def test_end(self, data): def test_end(self, data) -> str | None:
unexpected_result = ServoHandler.test_end(self, data) unexpected_result = ServoHandler.test_end(self, data)
if unexpected_result: if unexpected_result:
# Surround test output by newlines so that it is easier to read. # Surround test output by newlines so that it is easier to read.
@ -363,10 +424,10 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
else: else:
return self.generate_output(text="%s%s\n" % (self.test_counter(), data["test"])) return self.generate_output(text="%s%s\n" % (self.test_counter(), data["test"]))
def test_status(self, data): def test_status(self, data) -> None:
ServoHandler.test_status(self, data) ServoHandler.test_status(self, data)
def suite_end(self, data): def suite_end(self, data) -> str | None:
ServoHandler.suite_end(self, data) ServoHandler.suite_end(self, data)
if not self.interactive: if not self.interactive:
output = "\n" output = "\n"
@ -384,7 +445,7 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
if self.number_skipped: if self.number_skipped:
output += f" \u2022 {self.number_skipped} skipped.\n" output += f" \u2022 {self.number_skipped} skipped.\n"
def text_for_unexpected_list(text, section): def text_for_unexpected_list(text: str, section: str) -> str:
tests = self.unexpected_tests[section] tests = self.unexpected_tests[section]
if not tests: if not tests:
return "" return ""
@ -411,10 +472,10 @@ class ServoFormatter(mozlog.formatters.base.BaseFormatter, ServoHandler):
return self.generate_output(text=output, new_display="") return self.generate_output(text=output, new_display="")
def process_output(self, data): def process_output(self, data) -> None:
ServoHandler.process_output(self, data) ServoHandler.process_output(self, data)
def log(self, data): def log(self, data) -> str | None:
ServoHandler.log(self, data) ServoHandler.log(self, data)
# We are logging messages that begin with STDERR, because that is how exceptions # We are logging messages that begin with STDERR, because that is how exceptions

View file

@ -2,7 +2,10 @@
# License, v. 2.0. If a copy of the MPL was not distributed with this # License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at https://mozilla.org/MPL/2.0/. # file, You can obtain one at https://mozilla.org/MPL/2.0/.
from wptrunner.wptcommandline import TestRoot
from typing import Mapping
import argparse import argparse
from argparse import ArgumentParser
import os import os
import sys import sys
import tempfile import tempfile
@ -10,7 +13,7 @@ from collections import defaultdict
from six import iterkeys, iteritems from six import iterkeys, iteritems
from . import SERVO_ROOT, WPT_PATH from . import SERVO_ROOT, WPT_PATH
from mozlog.structured import commandline from mozlog import commandline
# This must happen after importing from "." since it adds WPT # This must happen after importing from "." since it adds WPT
# tools to the Python system path. # tools to the Python system path.
@ -20,7 +23,7 @@ from wptrunner.wptcommandline import get_test_paths, set_from_config
from wptrunner import wptlogging from wptrunner import wptlogging
def create_parser(): def create_parser() -> ArgumentParser:
p = argparse.ArgumentParser() p = argparse.ArgumentParser()
p.add_argument( p.add_argument(
"--check-clean", action="store_true", help="Check that updating the manifest doesn't lead to any changes" "--check-clean", action="store_true", help="Check that updating the manifest doesn't lead to any changes"
@ -31,7 +34,7 @@ def create_parser():
return p return p
def update(check_clean=True, rebuild=False, logger=None, **kwargs): def update(check_clean=True, rebuild=False, logger=None, **kwargs) -> int:
if not logger: if not logger:
logger = wptlogging.setup(kwargs, {"mach": sys.stdout}) logger = wptlogging.setup(kwargs, {"mach": sys.stdout})
kwargs = { kwargs = {
@ -52,7 +55,7 @@ def update(check_clean=True, rebuild=False, logger=None, **kwargs):
return _update(logger, test_paths, rebuild) return _update(logger, test_paths, rebuild)
def _update(logger, test_paths, rebuild): def _update(logger, test_paths: Mapping[str, TestRoot], rebuild) -> int:
for url_base, paths in iteritems(test_paths): for url_base, paths in iteritems(test_paths):
manifest_path = os.path.join(paths.metadata_path, "MANIFEST.json") manifest_path = os.path.join(paths.metadata_path, "MANIFEST.json")
cache_subdir = os.path.relpath(os.path.dirname(manifest_path), os.path.dirname(__file__)) cache_subdir = os.path.relpath(os.path.dirname(manifest_path), os.path.dirname(__file__))
@ -67,7 +70,7 @@ def _update(logger, test_paths, rebuild):
return 0 return 0
def _check_clean(logger, test_paths): def _check_clean(logger, test_paths: Mapping[str, TestRoot]) -> int:
manifests_by_path = {} manifests_by_path = {}
rv = 0 rv = 0
for url_base, paths in iteritems(test_paths): for url_base, paths in iteritems(test_paths):
@ -104,7 +107,7 @@ def _check_clean(logger, test_paths):
return rv return rv
def diff_manifests(logger, manifest_path, old_manifest, new_manifest): def diff_manifests(logger, manifest_path, old_manifest, new_manifest) -> bool:
"""Lint the differences between old and new versions of a """Lint the differences between old and new versions of a
manifest. Differences are considered significant (and so produce manifest. Differences are considered significant (and so produce
lint errors) if they produce a meaningful difference in the actual lint errors) if they produce a meaningful difference in the actual
@ -167,5 +170,5 @@ def diff_manifests(logger, manifest_path, old_manifest, new_manifest):
return clean return clean
def log_error(logger, manifest_path, msg): def log_error(logger, manifest_path, msg: str) -> None:
logger.lint_error(path=manifest_path, message=msg, lineno=0, source="", linter="wpt-manifest") logger.lint_error(path=manifest_path, message=msg, lineno=0, source="", linter="wpt-manifest")

View file

@ -13,7 +13,7 @@ import urllib.error
import urllib.parse import urllib.parse
import urllib.request import urllib.request
from typing import List, NamedTuple, Optional, Union from typing import List, NamedTuple, Optional, Union, cast, Callable
import mozlog import mozlog
import mozlog.formatters import mozlog.formatters
@ -31,12 +31,12 @@ TRACKER_DASHBOARD_SECRET_ENV_VAR = "INTERMITTENT_TRACKER_DASHBOARD_SECRET_PROD"
TRACKER_DASHBOARD_MAXIMUM_OUTPUT_LENGTH = 1024 TRACKER_DASHBOARD_MAXIMUM_OUTPUT_LENGTH = 1024
def set_if_none(args: dict, key: str, value): def set_if_none(args: dict, key: str, value: bool | int | str) -> None:
if key not in args or args[key] is None: if key not in args or args[key] is None:
args[key] = value args[key] = value
def run_tests(default_binary_path: str, **kwargs): def run_tests(default_binary_path: str, **kwargs) -> int:
print(f"Running WPT tests with {default_binary_path}") print(f"Running WPT tests with {default_binary_path}")
# By default, Rayon selects the number of worker threads based on the # By default, Rayon selects the number of worker threads based on the
@ -99,7 +99,7 @@ def run_tests(default_binary_path: str, **kwargs):
wptcommandline.check_args(kwargs) wptcommandline.check_args(kwargs)
mozlog.commandline.log_formatters["servo"] = ( mozlog.commandline.log_formatters["servo"] = (
ServoFormatter, cast(Callable, ServoFormatter),
"Servo's grouping output formatter", "Servo's grouping output formatter",
) )
@ -147,7 +147,7 @@ class GithubContextInformation(NamedTuple):
class TrackerDashboardFilter: class TrackerDashboardFilter:
def __init__(self): def __init__(self) -> None:
base_url = os.environ.get(TRACKER_API_ENV_VAR, TRACKER_API) base_url = os.environ.get(TRACKER_API_ENV_VAR, TRACKER_API)
self.headers = {"Content-Type": "application/json"} self.headers = {"Content-Type": "application/json"}
if TRACKER_DASHBOARD_SECRET_ENV_VAR in os.environ and os.environ[TRACKER_DASHBOARD_SECRET_ENV_VAR]: if TRACKER_DASHBOARD_SECRET_ENV_VAR in os.environ and os.environ[TRACKER_DASHBOARD_SECRET_ENV_VAR]:
@ -202,7 +202,7 @@ class TrackerDashboardFilter:
data["subtest"] = result.subtest data["subtest"] = result.subtest
return data return data
def report_failures(self, unexpected_results: List[UnexpectedResult]): def report_failures(self, unexpected_results: List[UnexpectedResult]) -> None:
attempts = [] attempts = []
for result in unexpected_results: for result in unexpected_results:
attempts.append(self.make_data_from_result(result)) attempts.append(self.make_data_from_result(result))
@ -244,12 +244,12 @@ def filter_intermittents(unexpected_results: List[UnexpectedResult], output_path
print(f"Filtering {len(unexpected_results)} unexpected results for known intermittents via <{dashboard.url}>") print(f"Filtering {len(unexpected_results)} unexpected results for known intermittents via <{dashboard.url}>")
dashboard.report_failures(unexpected_results) dashboard.report_failures(unexpected_results)
def add_result(output, text, results: List[UnexpectedResult], filter_func) -> None: def add_result(output: list[str], text: str, results: List[UnexpectedResult], filter_func) -> None:
filtered = [str(result) for result in filter(filter_func, results)] filtered = [str(result) for result in filter(filter_func, results)]
if filtered: if filtered:
output += [f"{text} ({len(filtered)}): ", *filtered] output += [f"{text} ({len(filtered)}): ", *filtered]
def is_stable_and_unexpected(result): def is_stable_and_unexpected(result: UnexpectedResult) -> bool:
return not result.flaky and not result.issues return not result.flaky and not result.issues
output: List[str] = [] output: List[str] = []
@ -271,7 +271,7 @@ def filter_intermittents(unexpected_results: List[UnexpectedResult], output_path
def write_unexpected_only_raw_log( def write_unexpected_only_raw_log(
unexpected_results: List[UnexpectedResult], raw_log_file: str, filtered_raw_log_file: str unexpected_results: List[UnexpectedResult], raw_log_file: str, filtered_raw_log_file: str
): ) -> None:
tests = [result.path for result in unexpected_results] tests = [result.path for result in unexpected_results]
print(f"Writing unexpected-only raw log to {filtered_raw_log_file}") print(f"Writing unexpected-only raw log to {filtered_raw_log_file}")

View file

@ -39,7 +39,7 @@ import flask
import flask.cli import flask.cli
import requests import requests
from .exporter import SyncRun, WPTSync from .exporter import SyncRun, WPTSync, LocalGitRepo
from .exporter.step import CreateOrUpdateBranchForPRStep from .exporter.step import CreateOrUpdateBranchForPRStep
TESTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tests") TESTS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tests")
@ -56,14 +56,14 @@ class MockPullRequest:
class MockGitHubAPIServer: class MockGitHubAPIServer:
def __init__(self, port: int): def __init__(self, port: int) -> None:
self.port = port self.port = port
self.disable_logging() self.disable_logging()
self.app = flask.Flask(__name__) self.app = flask.Flask(__name__)
self.pulls: list[MockPullRequest] = [] self.pulls: list[MockPullRequest] = []
class NoLoggingHandler(WSGIRequestHandler): class NoLoggingHandler(WSGIRequestHandler):
def log_message(self, *args): def log_message(self, *args) -> None:
pass pass
if logging.getLogger().level == logging.DEBUG: if logging.getLogger().level == logging.DEBUG:
@ -74,12 +74,12 @@ class MockGitHubAPIServer:
self.server = make_server("localhost", self.port, self.app, handler_class=handler) self.server = make_server("localhost", self.port, self.app, handler_class=handler)
self.start_server_thread() self.start_server_thread()
def disable_logging(self): def disable_logging(self) -> None:
flask.cli.show_server_banner = lambda *args: None flask.cli.show_server_banner = lambda *args: None
logging.getLogger("werkzeug").disabled = True logging.getLogger("werkzeug").disabled = True
logging.getLogger("werkzeug").setLevel(logging.CRITICAL) logging.getLogger("werkzeug").setLevel(logging.CRITICAL)
def start(self): def start(self) -> None:
self.thread.start() self.thread.start()
# Wait for the server to be started. # Wait for the server to be started.
@ -92,7 +92,7 @@ class MockGitHubAPIServer:
except Exception: except Exception:
time.sleep(0.1) time.sleep(0.1)
def reset_server_state_with_pull_requests(self, pulls: list[MockPullRequest]): def reset_server_state_with_pull_requests(self, pulls: list[MockPullRequest]) -> None:
response = requests.get( response = requests.get(
f"http://localhost:{self.port}/reset-mock-github", f"http://localhost:{self.port}/reset-mock-github",
json=[dataclasses.asdict(pull_request) for pull_request in pulls], json=[dataclasses.asdict(pull_request) for pull_request in pulls],
@ -101,21 +101,21 @@ class MockGitHubAPIServer:
assert response.status_code == 200 assert response.status_code == 200
assert response.text == "👍" assert response.text == "👍"
def shutdown(self): def shutdown(self) -> None:
self.server.shutdown() self.server.shutdown()
self.thread.join() self.thread.join()
def start_server_thread(self): def start_server_thread(self) -> None:
# pylint: disable=unused-argument # pylint: disable=unused-argument
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start() self.thread.start()
@self.app.route("/ping") @self.app.route("/ping")
def ping(): def ping() -> tuple[str, int]:
return ("pong", 200) return ("pong", 200)
@self.app.route("/reset-mock-github") @self.app.route("/reset-mock-github")
def reset_server(): def reset_server() -> tuple[str, int]:
self.pulls = [ self.pulls = [
MockPullRequest(pull_request["head"], pull_request["number"], pull_request["state"]) MockPullRequest(pull_request["head"], pull_request["number"], pull_request["state"])
for pull_request in flask.request.json for pull_request in flask.request.json
@ -123,7 +123,7 @@ class MockGitHubAPIServer:
return ("👍", 200) return ("👍", 200)
@self.app.route("/repos/<org>/<repo>/pulls/<int:number>/merge", methods=["PUT"]) @self.app.route("/repos/<org>/<repo>/pulls/<int:number>/merge", methods=["PUT"])
def merge_pull_request(org, repo, number): def merge_pull_request(org, repo, number) -> tuple[str, int]:
for pull_request in self.pulls: for pull_request in self.pulls:
if pull_request.number == number: if pull_request.number == number:
pull_request.state = "closed" pull_request.state = "closed"
@ -131,7 +131,7 @@ class MockGitHubAPIServer:
return ("", 404) return ("", 404)
@self.app.route("/search/issues", methods=["GET"]) @self.app.route("/search/issues", methods=["GET"])
def search(): def search() -> str:
params = {} params = {}
param_strings = flask.request.args.get("q", "").split(" ") param_strings = flask.request.args.get("q", "").split(" ")
for string in param_strings: for string in param_strings:
@ -149,13 +149,13 @@ class MockGitHubAPIServer:
return json.dumps({"total_count": 0, "items": []}) return json.dumps({"total_count": 0, "items": []})
@self.app.route("/repos/<org>/<repo>/pulls", methods=["POST"]) @self.app.route("/repos/<org>/<repo>/pulls", methods=["POST"])
def create_pull_request(org, repo): def create_pull_request(org, repo) -> dict[str, int]:
new_pr_number = len(self.pulls) + 1 new_pr_number = len(self.pulls) + 1
self.pulls.append(MockPullRequest(flask.request.json["head"], new_pr_number, "open")) self.pulls.append(MockPullRequest(flask.request.json["head"], new_pr_number, "open"))
return {"number": new_pr_number} return {"number": new_pr_number}
@self.app.route("/repos/<org>/<repo>/pulls/<int:number>", methods=["PATCH"]) @self.app.route("/repos/<org>/<repo>/pulls/<int:number>", methods=["PATCH"])
def update_pull_request(org, repo, number): def update_pull_request(org, repo, number) -> tuple[str, int]:
for pull_request in self.pulls: for pull_request in self.pulls:
if pull_request.number == number: if pull_request.number == number:
if "state" in flask.request.json: if "state" in flask.request.json:
@ -166,7 +166,7 @@ class MockGitHubAPIServer:
@self.app.route("/repos/<org>/<repo>/issues/<number>/labels", methods=["GET", "POST"]) @self.app.route("/repos/<org>/<repo>/issues/<number>/labels", methods=["GET", "POST"])
@self.app.route("/repos/<org>/<repo>/issues/<number>/labels/<label>", methods=["DELETE"]) @self.app.route("/repos/<org>/<repo>/issues/<number>/labels/<label>", methods=["DELETE"])
@self.app.route("/repos/<org>/<repo>/issues/<issue>/comments", methods=["GET", "POST"]) @self.app.route("/repos/<org>/<repo>/issues/<issue>/comments", methods=["GET", "POST"])
def other_requests(*args, **kwargs): def other_requests(*args, **kwargs) -> tuple[str, int]:
return ("", 204) return ("", 204)
@ -174,7 +174,7 @@ class TestCleanUpBodyText(unittest.TestCase):
"""Tests that SyncRun.clean_up_body_text properly prepares the """Tests that SyncRun.clean_up_body_text properly prepares the
body text for an upstream pull request.""" body text for an upstream pull request."""
def test_prepare_body(self): def test_prepare_body(self) -> None:
text = "Simple body text" text = "Simple body text"
self.assertEqual(text, SyncRun.clean_up_body_text(text)) self.assertEqual(text, SyncRun.clean_up_body_text(text))
self.assertEqual( self.assertEqual(
@ -210,7 +210,7 @@ class TestApplyCommitsToWPT(unittest.TestCase):
"""Tests that commits are properly applied to WPT by """Tests that commits are properly applied to WPT by
CreateOrUpdateBranchForPRStep._create_or_update_branch_for_pr.""" CreateOrUpdateBranchForPRStep._create_or_update_branch_for_pr."""
def run_test(self, pr_number: int, commit_data: dict): def run_test(self, pr_number: int, commit_data: dict) -> None:
def make_commit(data): def make_commit(data):
with open(os.path.join(TESTS_DIR, data[2]), "rb") as file: with open(os.path.join(TESTS_DIR, data[2]), "rb") as file:
return {"author": data[0], "message": data[1], "diff": file.read()} return {"author": data[0], "message": data[1], "diff": file.read()}
@ -221,7 +221,7 @@ class TestApplyCommitsToWPT(unittest.TestCase):
pull_request = SYNC.servo.get_pull_request(pr_number) pull_request = SYNC.servo.get_pull_request(pr_number)
step = CreateOrUpdateBranchForPRStep({"number": pr_number}, pull_request) step = CreateOrUpdateBranchForPRStep({"number": pr_number}, pull_request)
def get_applied_commits(num_commits: int, applied_commits: list[Tuple[str, str]]): def get_applied_commits(num_commits: int, applied_commits: list[Tuple[str, str]]) -> None:
assert SYNC is not None assert SYNC is not None
repo = SYNC.local_wpt_repo repo = SYNC.local_wpt_repo
log = ["log", "--oneline", f"-{num_commits}"] log = ["log", "--oneline", f"-{num_commits}"]
@ -240,10 +240,10 @@ class TestApplyCommitsToWPT(unittest.TestCase):
expected_commits = [(commit["author"], commit["message"]) for commit in commits] expected_commits = [(commit["author"], commit["message"]) for commit in commits]
self.assertListEqual(applied_commits, expected_commits) self.assertListEqual(applied_commits, expected_commits)
def test_simple_commit(self): def test_simple_commit(self) -> None:
self.run_test(45, [["test author <test@author>", "test commit message", "18746.diff"]]) self.run_test(45, [["test author <test@author>", "test commit message", "18746.diff"]])
def test_two_commits(self): def test_two_commits(self) -> None:
self.run_test( self.run_test(
100, 100,
[ [
@ -253,7 +253,7 @@ class TestApplyCommitsToWPT(unittest.TestCase):
], ],
) )
def test_non_utf8_commit(self): def test_non_utf8_commit(self) -> None:
self.run_test( self.run_test(
100, 100,
[ [
@ -266,15 +266,15 @@ class TestFullSyncRun(unittest.TestCase):
server: Optional[MockGitHubAPIServer] = None server: Optional[MockGitHubAPIServer] = None
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls) -> None:
cls.server = MockGitHubAPIServer(PORT) cls.server = MockGitHubAPIServer(PORT)
@classmethod @classmethod
def tearDownClass(cls): def tearDownClass(cls) -> None:
assert cls.server is not None assert cls.server is not None
cls.server.shutdown() cls.server.shutdown()
def tearDown(self): def tearDown(self) -> None:
assert SYNC is not None assert SYNC is not None
# Clean up any old files. # Clean up any old files.
@ -282,7 +282,7 @@ class TestFullSyncRun(unittest.TestCase):
SYNC.local_servo_repo.run("reset", "--hard", first_commit_hash) SYNC.local_servo_repo.run("reset", "--hard", first_commit_hash)
SYNC.local_servo_repo.run("clean", "-fxd") SYNC.local_servo_repo.run("clean", "-fxd")
def mock_servo_repository_state(self, diffs: list): def mock_servo_repository_state(self, diffs: list) -> str:
assert SYNC is not None assert SYNC is not None
def make_commit_data(diff): def make_commit_data(diff):
@ -333,7 +333,7 @@ class TestFullSyncRun(unittest.TestCase):
SYNC.run(payload, step_callback=lambda step: actual_steps.append(step.name)) SYNC.run(payload, step_callback=lambda step: actual_steps.append(step.name))
return actual_steps return actual_steps
def test_opened_upstreamable_pr(self): def test_opened_upstreamable_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("opened.json", ["18746.diff"]), self.run_test("opened.json", ["18746.diff"]),
[ [
@ -344,7 +344,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_opened_upstreamable_pr_with_move_into_wpt(self): def test_opened_upstreamable_pr_with_move_into_wpt(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("opened.json", ["move-into-wpt.diff"]), self.run_test("opened.json", ["move-into-wpt.diff"]),
[ [
@ -355,7 +355,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_opened_upstreamble_pr_with_move_into_wpt_and_non_ascii_author(self): def test_opened_upstreamble_pr_with_move_into_wpt_and_non_ascii_author(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"opened.json", "opened.json",
@ -376,7 +376,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_opened_upstreamable_pr_with_move_out_of_wpt(self): def test_opened_upstreamable_pr_with_move_out_of_wpt(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("opened.json", ["move-out-of-wpt.diff"]), self.run_test("opened.json", ["move-out-of-wpt.diff"]),
[ [
@ -387,11 +387,11 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_opened_new_mr_with_no_sync_signal(self): def test_opened_new_mr_with_no_sync_signal(self) -> None:
self.assertListEqual(self.run_test("opened-with-no-sync-signal.json", ["18746.diff"]), []) self.assertListEqual(self.run_test("opened-with-no-sync-signal.json", ["18746.diff"]), [])
self.assertListEqual(self.run_test("opened-with-no-sync-signal.json", ["non-wpt.diff"]), []) self.assertListEqual(self.run_test("opened-with-no-sync-signal.json", ["non-wpt.diff"]), [])
def test_opened_upstreamable_pr_not_applying_cleanly_to_upstream(self): def test_opened_upstreamable_pr_not_applying_cleanly_to_upstream(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("opened.json", ["does-not-apply-cleanly.diff"]), self.run_test("opened.json", ["does-not-apply-cleanly.diff"]),
[ [
@ -401,7 +401,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_open_new_upstreamable_pr_with_preexisting_upstream_pr(self): def test_open_new_upstreamable_pr_with_preexisting_upstream_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"opened.json", "opened.json",
@ -416,7 +416,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_open_new_non_upstreamable_pr_with_preexisting_upstream_pr(self): def test_open_new_non_upstreamable_pr_with_preexisting_upstream_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"opened.json", "opened.json",
@ -433,7 +433,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_opened_upstreamable_pr_with_non_utf8_file_contents(self): def test_opened_upstreamable_pr_with_non_utf8_file_contents(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("opened.json", ["add-non-utf8-file.diff"]), self.run_test("opened.json", ["add-non-utf8-file.diff"]),
[ [
@ -446,7 +446,7 @@ class TestFullSyncRun(unittest.TestCase):
def test_open_new_upstreamable_pr_with_preexisting_upstream_pr_not_apply_cleanly_to_upstream( def test_open_new_upstreamable_pr_with_preexisting_upstream_pr_not_apply_cleanly_to_upstream(
self, self,
): ) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"opened.json", "opened.json",
@ -463,10 +463,10 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_closed_pr_no_upstream_pr(self): def test_closed_pr_no_upstream_pr(self) -> None:
self.assertListEqual(self.run_test("closed.json", ["18746.diff"]), []) self.assertListEqual(self.run_test("closed.json", ["18746.diff"]), [])
def test_closed_pr_with_preexisting_upstream_pr(self): def test_closed_pr_with_preexisting_upstream_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"closed.json", "closed.json",
@ -476,7 +476,7 @@ class TestFullSyncRun(unittest.TestCase):
["ChangePRStep:wpt/wpt#10:closed", "RemoveBranchForPRStep:servo/wpt/servo_export_18746"], ["ChangePRStep:wpt/wpt#10:closed", "RemoveBranchForPRStep:servo/wpt/servo_export_18746"],
) )
def test_synchronize_move_new_changes_to_preexisting_upstream_pr(self): def test_synchronize_move_new_changes_to_preexisting_upstream_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"synchronize.json", "synchronize.json",
@ -491,7 +491,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_synchronize_close_upstream_pr_after_new_changes_do_not_include_wpt(self): def test_synchronize_close_upstream_pr_after_new_changes_do_not_include_wpt(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"synchronize.json", "synchronize.json",
@ -508,7 +508,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_synchronize_open_upstream_pr_after_new_changes_include_wpt(self): def test_synchronize_open_upstream_pr_after_new_changes_include_wpt(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("synchronize.json", ["18746.diff"]), self.run_test("synchronize.json", ["18746.diff"]),
[ [
@ -521,7 +521,7 @@ class TestFullSyncRun(unittest.TestCase):
def test_synchronize_fail_to_update_preexisting_pr_after_new_changes_do_not_apply( def test_synchronize_fail_to_update_preexisting_pr_after_new_changes_do_not_apply(
self, self,
): ) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test( self.run_test(
"synchronize.json", "synchronize.json",
@ -538,7 +538,7 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_edited_with_upstream_pr(self): def test_edited_with_upstream_pr(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("edited.json", ["wpt.diff"], [MockPullRequest("servo:servo_export_19620", 10)]), self.run_test("edited.json", ["wpt.diff"], [MockPullRequest("servo:servo_export_19620", 10)]),
[ [
@ -548,12 +548,12 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_edited_with_no_upstream_pr(self): def test_edited_with_no_upstream_pr(self) -> None:
self.assertListEqual(self.run_test("edited.json", ["wpt.diff"], []), []) self.assertListEqual(self.run_test("edited.json", ["wpt.diff"], []), [])
def test_synchronize_move_new_changes_to_preexisting_upstream_pr_with_multiple_commits( def test_synchronize_move_new_changes_to_preexisting_upstream_pr_with_multiple_commits(
self, self,
): ) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("synchronize-multiple.json", ["18746.diff", "non-wpt.diff", "wpt.diff"]), self.run_test("synchronize-multiple.json", ["18746.diff", "non-wpt.diff", "wpt.diff"]),
[ [
@ -564,23 +564,23 @@ class TestFullSyncRun(unittest.TestCase):
], ],
) )
def test_synchronize_with_non_upstreamable_changes(self): def test_synchronize_with_non_upstreamable_changes(self) -> None:
self.assertListEqual(self.run_test("synchronize.json", ["non-wpt.diff"]), []) self.assertListEqual(self.run_test("synchronize.json", ["non-wpt.diff"]), [])
def test_merge_upstream_pr_after_merge(self): def test_merge_upstream_pr_after_merge(self) -> None:
self.assertListEqual( self.assertListEqual(
self.run_test("merged.json", ["18746.diff"], [MockPullRequest("servo:servo_export_19620", 100)]), self.run_test("merged.json", ["18746.diff"], [MockPullRequest("servo:servo_export_19620", 100)]),
["MergePRStep:wpt/wpt#100", "RemoveBranchForPRStep:servo/wpt/servo_export_19620"], ["MergePRStep:wpt/wpt#100", "RemoveBranchForPRStep:servo/wpt/servo_export_19620"],
) )
def test_pr_merged_no_upstream_pr(self): def test_pr_merged_no_upstream_pr(self) -> None:
self.assertListEqual(self.run_test("merged.json", ["18746.diff"]), []) self.assertListEqual(self.run_test("merged.json", ["18746.diff"]), [])
def test_merge_of_non_upstreamble_pr(self): def test_merge_of_non_upstreamble_pr(self) -> None:
self.assertListEqual(self.run_test("merged.json", ["non-wpt.diff"]), []) self.assertListEqual(self.run_test("merged.json", ["non-wpt.diff"]), [])
def setUpModule(): def setUpModule() -> None:
# pylint: disable=invalid-name # pylint: disable=invalid-name
global TMP_DIR, SYNC global TMP_DIR, SYNC
@ -599,7 +599,7 @@ def setUpModule():
suppress_force_push=True, suppress_force_push=True,
) )
def setup_mock_repo(repo_name, local_repo, default_branch: str): def setup_mock_repo(repo_name: str, local_repo: LocalGitRepo, default_branch: str) -> None:
subprocess.check_output(["cp", "-R", "-p", os.path.join(TESTS_DIR, repo_name), local_repo.path]) subprocess.check_output(["cp", "-R", "-p", os.path.join(TESTS_DIR, repo_name), local_repo.path])
local_repo.run("init", "-b", default_branch) local_repo.run("init", "-b", default_branch)
local_repo.run("add", ".") local_repo.run("add", ".")
@ -612,15 +612,15 @@ def setUpModule():
logging.info("=" * 80) logging.info("=" * 80)
def tearDownModule(): def tearDownModule() -> None:
# pylint: disable=invalid-name # pylint: disable=invalid-name
shutil.rmtree(TMP_DIR) shutil.rmtree(TMP_DIR)
def run_tests(): def run_tests() -> bool:
verbosity = 1 if logging.getLogger().level >= logging.WARN else 2 verbosity = 1 if logging.getLogger().level >= logging.WARN else 2
def run_suite(test_case: Type[unittest.TestCase]): def run_suite(test_case: Type[unittest.TestCase]) -> bool:
return ( return (
unittest.TextTestRunner(verbosity=verbosity) unittest.TextTestRunner(verbosity=verbosity)
.run(unittest.TestLoader().loadTestsFromTestCase(test_case)) .run(unittest.TestLoader().loadTestsFromTestCase(test_case))

View file

@ -12,6 +12,8 @@ from wptrunner.update import setup_logging, WPTUpdate # noqa: F401
from wptrunner.update.base import exit_unclean # noqa: F401 from wptrunner.update.base import exit_unclean # noqa: F401
from wptrunner import wptcommandline # noqa: F401 from wptrunner import wptcommandline # noqa: F401
from argparse import ArgumentParser
from . import WPT_PATH from . import WPT_PATH
from . import manifestupdate from . import manifestupdate
@ -49,7 +51,7 @@ def do_sync(**kwargs) -> int:
return 0 return 0
def remove_unused_metadata(): def remove_unused_metadata() -> None:
print("Removing unused results...") print("Removing unused results...")
unused_files = [] unused_files = []
unused_dirs = [] unused_dirs = []
@ -93,7 +95,7 @@ def remove_unused_metadata():
def update_tests(**kwargs) -> int: def update_tests(**kwargs) -> int:
def set_if_none(args: dict, key: str, value): def set_if_none(args: dict, key: str, value: str) -> None:
if key not in args or args[key] is None: if key not in args or args[key] is None:
args[key] = value args[key] = value
@ -117,5 +119,5 @@ def run_update(**kwargs) -> bool:
return WPTUpdate(logger, **kwargs).run() != exit_unclean return WPTUpdate(logger, **kwargs).run() != exit_unclean
def create_parser(**_kwargs): def create_parser(**_kwargs) -> ArgumentParser:
return wptcommandline.create_parser_update() return wptcommandline.create_parser_update()