Mach: add type checking for the `python/tidy` folder (#38043)

Add the `python/tidy` folder to the set of directories checked by the pyrefly type checker.

Testing: manual testing via the `./mach test-tidy` command
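
For context, a minimal sketch of the integration this commit extends, assuming pyrefly is on `PATH`. The diagnostic field names (`path`, `line`, `concise_description`) match the `PyreflyDiagnostic` dataclass in the diff below; the `"errors"` envelope key and the exact JSON shape are assumptions, not confirmed by the patch:

```python
# Hedged sketch: consume `pyrefly check --output-format json`.
# Only path/line/concise_description are confirmed by the patch;
# the "errors" key and the extra-field filtering are assumptions.
import json
import subprocess
from dataclasses import dataclass


@dataclass
class PyreflyDiagnostic:
    path: str
    line: int
    concise_description: str


def pyrefly_diagnostics() -> list[PyreflyDiagnostic]:
    result = subprocess.run(
        ["pyrefly", "check", "--output-format", "json"],
        capture_output=True,
    )
    payload = json.loads(result.stdout)
    wanted = ("path", "line", "concise_description")
    return [
        PyreflyDiagnostic(**{key: error[key] for key in wanted})
        for error in payload.get("errors", [])
    ]
```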

---------

Signed-off-by: Jerens Lensun <jerensslensun@gmail.com>
Signed-off-by: Mukilan Thiyagarajan <mukilan@igalia.com>
Co-authored-by: Mukilan Thiyagarajan <mukilan@igalia.com>

@@ -18,7 +18,9 @@ import re
import subprocess
import sys
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Tuple
from typing import Any, Dict, List, TypedDict, LiteralString
from collections.abc import Iterator, Callable
import types
import colorama
import toml
@@ -42,8 +44,32 @@ ERROR_RAW_URL_IN_RUSTDOC = "Found raw link in rustdoc. Please escape it with ang
sys.path.append(os.path.join(WPT_PATH, "tests"))
sys.path.append(os.path.join(WPT_PATH, "tests", "tools", "wptrunner"))
# Default configs
config = {
CheckingFunction = Callable[[str, bytes], Iterator[tuple[int, str]]]
LineCheckingFunction = Callable[[str, list[bytes]], Iterator[tuple[int, str]]]
IgnoreConfig = TypedDict(
"IgnoreConfig",
{
"files": list[str],
"directories": list[str],
"packages": list[str],
},
)
Config = TypedDict(
"Config",
{
"skip-check-length": bool,
"skip-check-licenses": bool,
"check-alphabetical-order": bool,
"lint-scripts": list,
"blocked-packages": dict[str, Any],
"ignore": IgnoreConfig,
"check_ext": dict[str, Any],
},
)
config: Config = {
"skip-check-length": False,
"skip-check-licenses": False,
"check-alphabetical-order": True,
@@ -121,7 +147,7 @@ WEBIDL_STANDARDS = [
]
def is_iter_empty(iterator):
def is_iter_empty(iterator: Iterator[str]) -> tuple[bool, Iterator[str]]:
try:
obj = next(iterator)
return True, itertools.chain((obj,), iterator)
@@ -129,11 +155,11 @@ def is_iter_empty(iterator):
return False, iterator
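
A usage sketch for `is_iter_empty` (despite its name, it returns `True` when the iterator has at least one element): peeking with `next()` consumes an item, so the function chains it back onto the front, and callers must keep using the returned iterator. The `StopIteration` handler between the two hunks is elided above; this sketch assumes the conventional `try`/`except` shape:

```python
# Illustrative values; the peeked element is re-chained, not lost.
has_element, files = is_iter_empty(iter(["a.py", "b.py"]))
assert has_element
assert list(files) == ["a.py", "b.py"]
```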
def normalize_path(path: str) -> str:
def relative_path(path: str) -> str:
return os.path.relpath(os.path.abspath(path), TOPDIR)
def normilize_paths(paths):
def normalize_paths(paths: list[str] | str) -> list[str] | str:
if isinstance(paths, str):
return os.path.join(*paths.split("/"))
else:
@@ -142,7 +168,7 @@ def normilize_paths(paths):
# A simple wrapper for iterators to show progress
# (Note that it's inefficient for giant iterators, since it iterates once to get the upper bound)
def progress_wrapper(iterator):
def progress_wrapper(iterator: Iterator[str]) -> Iterator[str]:
list_of_stuff = list(iterator)
total_files, progress = len(list_of_stuff), 0
for idx, thing in enumerate(list_of_stuff):
@@ -156,16 +182,20 @@ def git_changes_since_last_merge(path):
args = ["git", "log", "-n1", "--committer", "noreply@github.com", "--format=%H"]
last_merge = subprocess.check_output(args, universal_newlines=True).strip()
if not last_merge:
return
return []
args = ["git", "diff", "--name-only", last_merge, path]
file_list = normilize_paths(subprocess.check_output(args, universal_newlines=True).splitlines())
file_list = normalize_paths(subprocess.check_output(args, universal_newlines=True).splitlines())
return file_list
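
The switch from a bare `return` to `return []` is a typing fix: callers iterate over the result, and an implicit `None` would widen the inferred return type to `list[str] | None`, forcing `None` checks at every call site. A compressed illustration (names are stand-ins for the git plumbing above):

```python
def changed_files(last_merge: str) -> list[str]:
    # A bare `return` here would make the inferred return type
    # `list[str] | None` and break callers that iterate the result.
    if not last_merge:
        return []
    return ["python/tidy/tidy.py"]  # stand-in for `git diff` output
```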
class FileList(object):
def __init__(self, directory, only_changed_files=False, exclude_dirs=[], progress=True):
directory: str
excluded: list[str]
generator: Iterator[str]
def __init__(self, directory, only_changed_files=False, exclude_dirs=[], progress=True) -> None:
self.directory = directory
self.excluded = exclude_dirs
self.generator = self._filter_excluded() if exclude_dirs else self._default_walk()
@@ -174,12 +204,12 @@ class FileList(object):
if progress:
self.generator = progress_wrapper(self.generator)
def _default_walk(self):
def _default_walk(self) -> Iterator[str]:
for root, _, files in os.walk(self.directory):
for f in files:
yield os.path.join(root, f)
def _git_changed_files(self):
def _git_changed_files(self) -> Iterator[str]:
file_list = git_changes_since_last_merge(self.directory)
if not file_list:
return
@@ -187,21 +217,21 @@ class FileList(object):
if not any(os.path.join(".", os.path.dirname(f)).startswith(path) for path in self.excluded):
yield os.path.join(".", f)
def _filter_excluded(self):
def _filter_excluded(self) -> Iterator[str]:
for root, dirs, files in os.walk(self.directory, topdown=True):
# modify 'dirs' in-place so that we don't do unnecessary traversals in excluded directories
dirs[:] = [d for d in dirs if not any(os.path.join(root, d).startswith(name) for name in self.excluded)]
for rel_path in files:
yield os.path.join(root, rel_path)
def __iter__(self):
def __iter__(self) -> Iterator[str]:
return self.generator
def next(self):
def next(self) -> str:
return next(self.generator)
def filter_file(file_name):
def filter_file(file_name: str) -> bool:
if any(file_name.startswith(ignored_file) for ignored_file in config["ignore"]["files"]):
return False
base_name = os.path.basename(file_name)
@@ -210,7 +240,7 @@ def filter_file(file_name):
return True
def filter_files(start_dir, only_changed_files, progress):
def filter_files(start_dir: str, only_changed_files: bool, progress: bool) -> Iterator[str]:
file_iter = FileList(
start_dir,
only_changed_files=only_changed_files,
@@ -227,23 +257,26 @@ def filter_files(start_dir, only_changed_files, progress):
yield file_name
def uncomment(line):
def uncomment(line: bytes) -> bytes:
for c in COMMENTS:
if line.startswith(c):
if line.endswith(b"*/"):
return line[len(c) : (len(line) - 3)].strip()
return line[len(c) :].strip()
return line
def is_apache_licensed(header):
def is_apache_licensed(header: str) -> bool:
if "SPDX-License-Identifier: Apache-2.0 OR MIT" in header:
return True
if APACHE in header:
return any(c in header for c in COPYRIGHT)
return False
def check_license(file_name, lines):
def check_license(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
if any(file_name.endswith(ext) for ext in (".toml", ".lock", ".json", ".html")) or config["skip-check-licenses"]:
return
@@ -272,7 +305,7 @@ def check_license(file_name, lines):
yield (1, "incorrect license")
def check_modeline(file_name, lines):
def check_modeline(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
for idx, line in enumerate(lines[:5]):
if re.search(b"^.*[ \t](vi:|vim:|ex:)[ \t]", line):
yield (idx + 1, "vi modeline present")
@@ -280,7 +313,7 @@ def check_modeline(file_name, lines):
yield (idx + 1, "emacs file variables present")
def check_length(file_name, idx, line):
def check_length(file_name: str, idx: int, line: bytes) -> Iterator[tuple[int, str]]:
if any(file_name.endswith(ext) for ext in (".lock", ".json", ".html", ".toml")) or config["skip-check-length"]:
return
@@ -290,29 +323,29 @@ def check_length(file_name, idx, line):
yield (idx + 1, "Line is longer than %d characters" % max_length)
def contains_url(line):
def contains_url(line: bytes) -> bool:
return bool(URL_REGEX.search(line))
def is_unsplittable(file_name, line):
def is_unsplittable(file_name: str, line: bytes):
return contains_url(line) or file_name.endswith(".rs") and line.startswith(b"use ") and b"{" not in line
def check_whatwg_specific_url(idx, line):
def check_whatwg_specific_url(idx: int, line: bytes) -> Iterator[tuple[int, str]]:
match = re.search(rb"https://html\.spec\.whatwg\.org/multipage/[\w-]+\.html#([\w\'\:-]+)", line)
if match is not None:
preferred_link = "https://html.spec.whatwg.org/multipage/#{}".format(match.group(1).decode("utf-8"))
yield (idx + 1, "link to WHATWG may break in the future, use this format instead: {}".format(preferred_link))
def check_whatwg_single_page_url(idx, line):
def check_whatwg_single_page_url(idx: int, line: bytes) -> Iterator[tuple[int, str]]:
match = re.search(rb"https://html\.spec\.whatwg\.org/#([\w\'\:-]+)", line)
if match is not None:
preferred_link = "https://html.spec.whatwg.org/multipage/#{}".format(match.group(1).decode("utf-8"))
yield (idx + 1, "links to WHATWG single-page url, change to multi page: {}".format(preferred_link))
def check_whitespace(idx, line):
def check_whitespace(idx: int, line: bytes) -> Iterator[tuple[int, str]]:
if line.endswith(b"\n"):
line = line[:-1]
else:
@@ -328,7 +361,7 @@ def check_whitespace(idx, line):
yield (idx + 1, "CR on line")
def check_for_raw_urls_in_rustdoc(file_name: str, idx: int, line: bytes):
def check_for_raw_urls_in_rustdoc(file_name: str, idx: int, line: bytes) -> Iterator[tuple[int, str]]:
"""Check that rustdoc comments in Rust source code do not have raw URLs. These appear
as warnings when rustdoc is run. rustdoc warnings could be made fatal, but adding this
check as part of tidy catches this common problem without having to run rustdoc for all
@@ -354,7 +387,7 @@ def check_for_raw_urls_in_rustdoc(file_name: str, idx: int, line: bytes):
yield (idx + 1, ERROR_RAW_URL_IN_RUSTDOC)
def check_by_line(file_name: str, lines: list[bytes]):
def check_by_line(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
for idx, line in enumerate(lines):
errors = itertools.chain(
check_length(file_name, idx, line),
@@ -368,7 +401,7 @@ def check_by_line(file_name: str, lines: list[bytes]):
yield error
def check_ruff_lints():
def check_ruff_lints() -> Iterator[tuple[str, int, str]]:
try:
args = ["ruff", "check", "--output-format", "json"]
subprocess.check_output(args, universal_newlines=True)
@@ -398,7 +431,7 @@ class PyreflyDiagnostic:
concise_description: str
def run_python_type_checker() -> Iterator[Tuple[str, int, str]]:
def run_python_type_checker() -> Iterator[tuple[str, int, str]]:
print("\r ➤ Checking type annotations in python files ...")
try:
result = subprocess.run(["pyrefly", "check", "--output-format", "json"], capture_output=True)
@@ -410,7 +443,7 @@ def run_python_type_checker() -> Iterator[Tuple[str, int, str]]:
else:
for error in errors:
diagnostic = PyreflyDiagnostic(**error)
yield normalize_path(diagnostic.path), diagnostic.line, diagnostic.concise_description
yield relative_path(diagnostic.path), diagnostic.line, diagnostic.concise_description
def run_cargo_deny_lints():
@@ -422,7 +455,7 @@ def run_cargo_deny_lints():
errors = []
for line in result.stderr.splitlines():
error_fields = json.loads(line)["fields"]
error_fields = json.loads(str(line))["fields"]
error_code = error_fields.get("code", "unknown")
error_severity = error_fields.get("severity", "unknown")
message = error_fields.get("message", "")
@@ -459,7 +492,7 @@ def run_cargo_deny_lints():
yield error
def check_toml(file_name, lines):
def check_toml(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
if not file_name.endswith("Cargo.toml"):
return
ok_licensed = False
@@ -477,7 +510,7 @@ def check_toml(file_name, lines):
yield (0, ".toml file should contain a valid license.")
def check_shell(file_name, lines):
def check_shell(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
if not file_name.endswith(".sh"):
return
@@ -524,7 +557,7 @@ def check_shell(file_name, lines):
yield (idx + 1, 'variable substitutions should use the full "${VAR}" form')
def check_rust(file_name, lines):
def check_rust(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
if (
not file_name.endswith(".rs")
or file_name.endswith(".mako.rs")
@@ -548,11 +581,11 @@ def check_rust(file_name, lines):
os.path.join("*", "ports", "servoshell", "embedder.rs"),
os.path.join("*", "rust_tidy.rs"), # This is for the tests.
]
is_panic_not_allowed_rs_file = any([glob.fnmatch.fnmatch(file_name, path) for path in PANIC_NOT_ALLOWED_PATHS])
is_panic_not_allowed_rs_file = any([fnmatch.fnmatch(file_name, path) for path in PANIC_NOT_ALLOWED_PATHS])
prev_open_brace = False
multi_line_string = False
prev_mod = {}
prev_mod: dict[int, str] = {}
prev_feature_name = ""
indent = 0
@@ -620,7 +653,7 @@ def check_rust(file_name, lines):
# flag this line if it matches one of the following regular expressions
# tuple format: (pattern, format_message, filter_function(match, line))
def no_filter(match, line):
def no_filter(match, line) -> bool:
return True
regex_rules = [
@@ -735,7 +768,7 @@ def is_associated_type(match, line):
return generic_open and generic_close
def check_webidl_spec(file_name, contents):
def check_webidl_spec(file_name: str, contents: bytes) -> Iterator[tuple[int, str]]:
# Sorted by this function (in pseudo-Rust). The idea is to group the same
# organization together.
# fn sort_standards(a: &Url, b: &Url) -> Ordering {
@@ -764,7 +797,7 @@ def check_webidl_spec(file_name, contents):
yield (0, "No specification link found.")
def check_that_manifests_exist():
def check_that_manifests_exist() -> Iterator[tuple[str, int, str]]:
# Determine the metadata and test directories from the configuration file.
metadata_dirs = []
config = configparser.ConfigParser()
@@ -776,10 +809,10 @@ def check_that_manifests_exist():
for directory in metadata_dirs:
manifest_path = os.path.join(TOPDIR, directory, "MANIFEST.json")
if not os.path.isfile(manifest_path):
yield (WPT_CONFIG_INI_PATH, "", f"Path in config was not found: {manifest_path}")
yield (WPT_CONFIG_INI_PATH, 0, f"Path in config was not found: {manifest_path}")
def check_that_manifests_are_clean():
def check_that_manifests_are_clean() -> Iterator[tuple[str, int, str]]:
from wptrunner import wptlogging
print("\r ➤ Checking WPT manifests for cleanliness...")
@@ -789,16 +822,21 @@ def check_that_manifests_are_clean():
for line in output_stream.getvalue().splitlines():
if "ERROR" in line:
yield (WPT_CONFIG_INI_PATH, 0, line)
yield (WPT_CONFIG_INI_PATH, "", "WPT manifest is dirty. Run `./mach update-manifest`.")
yield (WPT_CONFIG_INI_PATH, 0, "WPT manifest is dirty. Run `./mach update-manifest`.")
def lint_wpt_test_files():
def lint_wpt_test_files() -> Iterator[tuple[str, int, str]]:
from tools.lint import lint
# Override the logging function so that we can collect errors from
# the lint script, which doesn't allow configuration of the output.
messages: List[str] = []
lint.logger.error = lambda message: messages.append(message)
assert lint.logger is not None
def collect_messages(_, message):
messages.append(message)
lint.logger.error = types.MethodType(collect_messages, lint.logger)
# We do not lint all WPT-like tests because they do not all currently have
# lint.ignore files.
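
The `types.MethodType` change above replaces the previous bare-lambda assignment: binding `collect_messages` to the logger instance preserves the bound-method calling convention, so `lint.logger.error("...")` still passes the logger as the implicit first argument and the assignment stays compatible with the method's declared signature. A standalone sketch with a hypothetical `Logger` class:

```python
import types


class Logger:
    def error(self, message: str) -> None:
        print(message)


messages: list[str] = []


def collect_messages(self: Logger, message: str) -> None:
    messages.append(message)


log = Logger()
# Bind the function to this instance so `log.error("...")` still
# supplies `log` as the implicit first argument.
log.error = types.MethodType(collect_messages, log)
log.error("manifest is dirty")
assert messages == ["manifest is dirty"]
```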
@@ -816,10 +854,10 @@ def lint_wpt_test_files():
if lint.lint(suite_directory, tests_changed, output_format="normal"):
for message in messages:
(filename, message) = message.split(":", maxsplit=1)
yield (filename, "", message)
yield (filename, 0, message)
def run_wpt_lints(only_changed_files: bool):
def run_wpt_lints(only_changed_files: bool) -> Iterator[tuple[str, int, str]]:
if not os.path.exists(WPT_CONFIG_INI_PATH):
yield (WPT_CONFIG_INI_PATH, 0, f"{WPT_CONFIG_INI_PATH} is required but was not found")
return
@@ -837,7 +875,7 @@ def run_wpt_lints(only_changed_files: bool):
yield from lint_wpt_test_files()
def check_spec(file_name, lines):
def check_spec(file_name: str, lines: list[bytes]) -> Iterator[tuple[int, str]]:
if SPEC_BASE_PATH not in file_name:
return
file_name = os.path.relpath(os.path.splitext(file_name)[0], SPEC_BASE_PATH)
@@ -879,7 +917,7 @@ def check_spec(file_name, lines):
break
def check_config_file(config_file, print_text=True):
def check_config_file(config_file: LiteralString, print_text: bool = True) -> Iterator[tuple[str, int, str]]:
# Check if config file exists
if not os.path.exists(config_file):
print("%s config file is required but was not found" % config_file)
@@ -955,13 +993,13 @@ def check_config_file(config_file, print_text=True):
parse_config(config_content)
def parse_config(config_file):
def parse_config(config_file: dict[str, Any]) -> None:
exclude = config_file.get("ignore", {})
# Add list of ignored directories to config
ignored_directories = [d for p in exclude.get("directories", []) for d in (glob.glob(p) or [p])]
config["ignore"]["directories"] += normilize_paths(ignored_directories)
config["ignore"]["directories"] += normalize_paths(ignored_directories)
# Add list of ignored files to config
config["ignore"]["files"] += normilize_paths(exclude.get("files", []))
config["ignore"]["files"] += normalize_paths(exclude.get("files", []))
# Add list of ignored packages to config
config["ignore"]["packages"] = exclude.get("packages", [])
@@ -969,19 +1007,26 @@ def parse_config(config_file):
dirs_to_check = config_file.get("check_ext", {})
# Fix the paths (OS-dependent)
for path, exts in dirs_to_check.items():
config["check_ext"][normilize_paths(path)] = exts
# FIXME: Temporarily ignoring this since the type signature for
# `normalize_paths` must use a constrained type variable for this to
# typecheck but Pyrefly doesn't handle that correctly (but mypy does).
# pyrefly: ignore[bad-argument-type]
config["check_ext"][normalize_paths(path)] = exts
# Add list of blocked packages
config["blocked-packages"] = config_file.get("blocked-packages", {})
# Override default configs
user_configs = config_file.get("configs", [])
for pref in user_configs:
if pref in config:
# FIXME: Temporarily ignoring this since only Pyrefly raises an issue about the dynamic key
# pyrefly: ignore[missing-attribute]
config[pref] = user_configs[pref]
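
On the first FIXME above: the constrained-type-variable signature it alludes to would look roughly like the sketch below. mypy resolves `PathLike` to exactly one of its constraints at each call site (a `str` in, a `str` out), which is what the `check_ext` key lookup needs; per the FIXME, pyrefly does not yet handle this correctly:

```python
import os
from typing import TypeVar

# Constrained TypeVar: each call resolves to exactly str or list[str],
# unlike a plain union return type, which loses the in/out pairing.
PathLike = TypeVar("PathLike", str, list[str])


def normalize_paths(paths: PathLike) -> PathLike:
    if isinstance(paths, str):
        return os.path.join(*paths.split("/"))
    return [os.path.join(*p.split("/")) for p in paths]
```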
def check_directory_files(directories, print_text=True):
def check_directory_files(directories: dict[str, Any], print_text: bool = True) -> Iterator[tuple[str, int, str]]:
if print_text:
print("\r ➤ Checking directories for correct file extensions...")
for directory, file_extensions in directories.items():
@@ -994,7 +1039,12 @@ We only expect files with {ext} extensions in {dir_name}""".format(**details)
yield (filename, 1, message)
def collect_errors_for_files(files_to_check, checking_functions, line_checking_functions, print_text=True):
def collect_errors_for_files(
files_to_check: Iterator[str],
checking_functions: tuple[CheckingFunction, ...],
line_checking_functions: tuple[LineCheckingFunction, ...],
print_text: bool = True,
) -> Iterator[tuple[str, int, str]]:
(has_element, files_to_check) = is_iter_empty(files_to_check)
if not has_element:
return
@@ -1005,7 +1055,7 @@ def collect_errors_for_files(files_to_check, checking_functions, line_checking_f
if not os.path.exists(filename):
continue
with open(filename, "rb") as f:
contents = f.read()
contents: bytes = f.read()
if not contents.strip():
yield filename, 0, "file is empty"
continue
@@ -1013,13 +1063,13 @@ def collect_errors_for_files(files_to_check, checking_functions, line_checking_f
for error in check(filename, contents):
# the result will be: `(filename, line, message)`
yield (filename,) + error
lines = contents.splitlines(True)
lines: list[bytes] = contents.splitlines(True)
for check in line_checking_functions:
for error in check(filename, lines):
yield (filename,) + error
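
An illustrative checker (hypothetical, not part of this patch) that satisfies the `CheckingFunction` alias introduced at the top of the file: any callable taking a file name and raw bytes and yielding `(line, message)` pairs can be slotted into `checking_functions`:

```python
from collections.abc import Iterator


def check_no_tabs(file_name: str, contents: bytes) -> Iterator[tuple[int, str]]:
    # Matches CheckingFunction = Callable[[str, bytes], Iterator[tuple[int, str]]]
    for idx, line in enumerate(contents.splitlines()):
        if b"\t" in line:
            yield (idx + 1, "tab character found")
```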
def scan(only_changed_files=False, progress=False, github_annotations=False):
def scan(only_changed_files=False, progress=False, github_annotations=False) -> int:
github_annotation_manager = GitHubAnnotationManager("test-tidy")
# check config file for errors
config_errors = check_config_file(CONFIG_FILE_PATH)
@@ -1027,8 +1077,8 @@ def scan(only_changed_files=False, progress=False, github_annotations=False):
directory_errors = check_directory_files(config["check_ext"])
# standard checks
files_to_check = filter_files(".", only_changed_files, progress)
checking_functions = (check_webidl_spec,)
line_checking_functions = (
checking_functions: tuple[CheckingFunction, ...] = (check_webidl_spec,)
line_checking_functions: tuple[LineCheckingFunction, ...] = (
check_license,
check_by_line,
check_toml,
@@ -1066,11 +1116,11 @@ class CargoDenyKrate:
class CargoDenyKrate:
def __init__(self, data: Dict[Any, Any]):
def __init__(self, data: Dict[Any, Any]) -> None:
crate = data["Krate"]
self.name = crate["name"]
self.version = crate["version"]
self.parents = [CargoDenyKrate(parent) for parent in data.get("parents", [])]
def __str__(self):
def __str__(self) -> str:
return f"{self.name}@{self.version}"