diff --git a/python/tidy.py b/python/tidy.py index 238dec69124..a474abaa0c8 100644 --- a/python/tidy.py +++ b/python/tidy.py @@ -7,10 +7,12 @@ # option. This file may not be copied, modified, or distributed # except according to those terms. +import contextlib import os import fnmatch import itertools import re +import StringIO import sys from licenseck import licenses @@ -77,7 +79,9 @@ EMACS_HEADER = "/* -*- Mode:" VIM_HEADER = "/* vim:" -def check_license(contents): +def check_license(file_name, contents): + if file_name.endswith(".toml"): + raise StopIteration while contents.startswith(EMACS_HEADER) or contents.startswith(VIM_HEADER): _, _, contents = contents.partition("\n") valid_license = any(contents.startswith(license) for license in licenses) @@ -114,7 +118,7 @@ def check_whitespace(idx, line): yield (idx + 1, "CR on line") -def check_by_line(contents): +def check_by_line(file_name, contents): lines = contents.splitlines(True) for idx, line in enumerate(lines): errors = itertools.chain( @@ -126,33 +130,43 @@ def check_by_line(contents): yield error -def check_flake8(file_paths): - from flake8.main import check_file +def check_flake8(file_name, contents): + from flake8.main import check_code + + if not file_name.endswith(".py"): + raise StopIteration + + @contextlib.contextmanager + def stdout_redirect(where): + sys.stdout = where + try: + yield where + finally: + sys.stdout = sys.__stdout__ ignore = { "W291", # trailing whitespace; the standard tidy process will enforce no trailing whitespace "E501", # 80 character line length; the standard tidy process will enforce line length } - num_errors = 0 - - for file_path in file_paths: - if os.path.splitext(file_path)[-1].lower() != ".py": - continue - - num_errors += check_file(file_path, ignore=ignore) - - return num_errors + output = StringIO.StringIO() + with stdout_redirect(output): + check_code(contents, ignore=ignore) + for error in output.getvalue().splitlines(): + _, line_num, _, message = error.split(":") + yield line_num, message.strip() -def check_toml(contents): +def check_toml(file_name, contents): + if not file_name.endswith(".toml"): + raise StopIteration contents = contents.splitlines(True) for idx, line in enumerate(contents): if line.find("*") != -1: yield (idx + 1, "found asterisk instead of minimum version number") -def check_webidl_spec(contents): +def check_webidl_spec(file_name, contents): # Sorted by this function (in pseudo-Rust). The idea is to group the same # organization together. # fn sort_standards(a: &Url, b: &Url) -> Ordering { @@ -171,6 +185,8 @@ def check_webidl_spec(contents): # } # a_domain.path().cmp(b_domain.path()) # } + if not file_name.endswith(".webidl"): + raise StopIteration standards = [ "//www.khronos.org/registry/webgl/specs", "//developer.mozilla.org/en-US/docs/Web/API", @@ -192,25 +208,18 @@ def check_webidl_spec(contents): ] for i in standards: if contents.find(i) != -1: - return True - return False + raise StopIteration + yield 0, "No specification link found." def collect_errors_for_files(files_to_check, checking_functions): for file_name in files_to_check: with open(file_name, "r") as fp: contents = fp.read() - if file_name.endswith(".toml"): - for error in check_toml(contents): + for check in checking_functions: + for error in check(file_name, contents): + # filename, line, message yield (file_name, error[0], error[1]) - elif file_name.endswith(".webidl"): - if not check_webidl_spec(contents): - yield (file_name, 0, "No specification link found.") - else: - for check in checking_functions: - for error in check(contents): - # filename, line, message - yield (file_name, error[0], error[1]) def check_reftest_order(files_to_check): @@ -241,9 +250,7 @@ def scan(): all_files = collect_file_names() files_to_check = filter(should_check, all_files) - num_flake8_errors = check_flake8(files_to_check) - - checking_functions = [check_license, check_by_line] + checking_functions = [check_license, check_by_line, check_flake8, check_toml, check_webidl_spec] errors = collect_errors_for_files(files_to_check, checking_functions) reftest_files = collect_file_names(reftest_directories) @@ -252,7 +259,7 @@ def scan(): errors = list(itertools.chain(errors, r_errors)) - if errors or num_flake8_errors: + if errors: for error in errors: print("{}:{}: {}".format(*error)) return 1