script_binding: Add type check on servo script bindings (#38161)

Introduce type checking with Pyrefly in `components/script_bindings`

This commit adds Pyrefly-based type checking to the
`components/script_bindings` directory. The primary goal is to catch
type inconsistencies early and reduce the likelihood of unexpected
runtime errors.

This change affects the `webidl` component, as these script bindings are
responsible for connecting WebIDL specifications to the Rust codebase.

Testing: `./mach test-wpt webidl`
Fixes: *Link to an issue this pull requests fixes or remove this line if
there is no issue*

---------

Signed-off-by: Jerens Lensun <jerensslensun@gmail.com>
This commit is contained in:
Jerens Lensun 2025-08-01 12:34:24 +08:00 committed by GitHub
parent 4ce5b17605
commit b05d265de5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 176 additions and 88 deletions

View file

@ -4,9 +4,12 @@
# Common codegen classes.
from WebIDL import IDLUnionType
from WebIDL import IDLSequenceType
from collections import defaultdict
from itertools import groupby
from typing import Generator, Tuple, Optional, List
from typing import Generator, Optional, cast
from abc import abstractmethod
import operator
import os
@ -122,47 +125,53 @@ RUST_KEYWORDS = {
}
def genericsForType(t):
def genericsForType(t: IDLObject) -> tuple[str, str]:
if containsDomInterface(t):
return ("<D: DomTypes>", "<D>")
return ("", "")
def isDomInterface(t, logging=False):
def isDomInterface(t: IDLObject, logging: bool = False) -> bool:
while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType):
t = t.inner
if isinstance(t, IDLInterface):
return True
# pyrefly: ignore # missing-attribute
if t.isCallback() or t.isPromise():
return True
return t.isInterface() and (t.isGeckoInterface() or (t.isSpiderMonkeyInterface() and not t.isBufferSource()))
# pyrefly: ignore # missing-attribute
return t.isInterface() and (t.isSpiderMonkeyInterface() and not t.isBufferSource())
def containsDomInterface(t, logging=False):
def containsDomInterface(t: IDLObject, logging: bool = False) -> bool:
if isinstance(t, IDLArgument):
t = t.type
if isinstance(t, IDLTypedefType):
t = t.innerType
t = t.inner
while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType):
t = t.inner
if t.isEnum():
return False
if t.isUnion():
# pyrefly: ignore # missing-attribute
return any(map(lambda x: containsDomInterface(x), t.flatMemberTypes))
if t.isDictionary():
# pyrefly: ignore # missing-attribute, bad-argument-type
return any(map(lambda x: containsDomInterface(x), t.members)) or (t.parent and containsDomInterface(t.parent))
if isDomInterface(t):
return True
# pyrefly: ignore # missing-attribute
if t.isSequence():
# pyrefly: ignore # missing-attribute
return containsDomInterface(t.inner)
return False
def toStringBool(arg):
def toStringBool(arg) -> str:
return str(not not arg).lower()
def toBindingNamespace(arg):
def toBindingNamespace(arg) -> str:
"""
Namespaces are *_Bindings
@ -171,7 +180,7 @@ def toBindingNamespace(arg):
return re.sub("((_workers)?$)", "_Binding\\1", MakeNativeName(arg))
def toBindingModuleFile(arg):
def toBindingModuleFile(arg) -> str:
"""
Module files are *Bindings
@ -180,14 +189,14 @@ def toBindingModuleFile(arg):
return re.sub("((_workers)?$)", "Binding\\1", MakeNativeName(arg))
def toBindingModuleFileFromDescriptor(desc):
def toBindingModuleFileFromDescriptor(desc: Descriptor) -> str:
if desc.maybeGetSuperModule() is not None:
return toBindingModuleFile(desc.maybeGetSuperModule())
else:
return toBindingModuleFile(desc.name)
def stripTrailingWhitespace(text):
def stripTrailingWhitespace(text: str) -> str:
tail = '\n' if text.endswith('\n') else ''
lines = text.splitlines()
for i in range(len(lines)):
@ -254,7 +263,7 @@ numericTags = [
lineStartDetector = re.compile("^(?=[^\n#])", re.MULTILINE)
def indent(s, indentLevel=2):
def indent(s: str, indentLevel: int = 2) -> str:
"""
Indent C++ code.
@ -266,7 +275,7 @@ def indent(s, indentLevel=2):
return re.sub(lineStartDetector, indentLevel * " ", s)
@functools.cache
def dedent(s):
def dedent(s: str) -> str:
"""
Remove all leading whitespace from s, and remove a blank line
at the beginning.
@ -366,12 +375,13 @@ class CGThing():
"""
Abstract base class for things that spit out code.
"""
def __init__(self):
def __init__(self) -> None:
pass # Nothing for now
def define(self):
@abstractmethod
def define(self) -> str:
"""Produce code for a Rust file."""
raise NotImplementedError # Override me!
raise NotImplementedError
class CGMethodCall(CGThing):
@ -379,6 +389,7 @@ class CGMethodCall(CGThing):
A class to generate selection of a method signature from a set of
signatures and generation of a call to that signature.
"""
cgRoot: CGThing
def __init__(self, argsPre, nativeMethodName, static, descriptor, method):
CGThing.__init__(self)
@ -452,7 +463,7 @@ class CGMethodCall(CGThing):
# Doesn't matter which of the possible signatures we use, since
# they all have the same types up to that point; just use
# possibleSignatures[0]
caseBody = [
caseBody: list[CGThing] = [
CGArgumentConverter(possibleSignatures[0][1][i],
i, "args", "argc", descriptor)
for i in range(0, distinguishingIndex)]
@ -574,7 +585,7 @@ class CGMethodCall(CGThing):
argCountCases.append(CGCase(str(argCount),
CGList(caseBody, "\n")))
overloadCGThings = []
overloadCGThings: list[CGThing] = []
overloadCGThings.append(
CGGeneric(f"let argcount = cmp::min(argc, {maxArgCount});"))
overloadCGThings.append(
@ -613,7 +624,7 @@ def typeIsSequenceOrHasSequenceMember(type):
return False
def union_native_type(t):
def union_native_type(t: IDLType) -> str:
name = t.unroll().name
generic = "<D>" if containsDomInterface(t) else ""
return f'GenericUnionTypes::{name}{generic}'
@ -621,7 +632,7 @@ def union_native_type(t):
# Unfortunately, .capitalize() on a string will lowercase things inside the
# string, which we do not want.
def firstCap(string):
def firstCap(string: str) -> str:
return f"{string[0].upper()}{string[1:]}"
@ -1239,7 +1250,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None,
def instantiateJSToNativeConversionTemplate(templateBody, replacements,
declType, declName,
needsAutoRoot=False):
needsAutoRoot=False) -> CGThing:
"""
Take the templateBody and declType as returned by
getJSToNativeConversionInfo, a set of replacements as required by the
@ -1298,6 +1309,7 @@ class CGArgumentConverter(CGThing):
argument list, and the argv and argc strings and generates code to
unwrap the argument to the right native type.
"""
converter: CGThing
def __init__(self, argument, index, args, argc, descriptorProvider,
invalidEnumValueFatal=True):
CGThing.__init__(self)
@ -1374,7 +1386,7 @@ if {argc} > {index} {{
return self.converter.define()
def wrapForType(jsvalRef, result='result', successCode='true', pre=''):
def wrapForType(jsvalRef: str, result: str = 'result', successCode: str = 'true', pre: str = '') -> str:
"""
Reflect a Rust value into JS.
@ -1413,7 +1425,7 @@ def returnTypeNeedsOutparam(type):
def outparamTypeFromReturnType(type):
if type.isAny():
return "MutableHandleValue"
raise f"Don't know how to handle {type} as an outparam"
raise TypeError(f"Don't know how to handle {type} as an outparam")
# Returns a conversion behavior suitable for a type
@ -1569,7 +1581,9 @@ class PropertyDefiner:
things we're defining. They should also set self.regular to the list of
things exposed to web pages.
"""
def __init__(self, descriptor, name):
name: str
regular: list
def __init__(self, descriptor, name: str):
self.descriptor = descriptor
self.name = name
@ -1579,6 +1593,10 @@ class PropertyDefiner:
def length(self):
return len(self.regular)
@abstractmethod
def generateArray(self, array, name) -> str:
raise NotImplementedError
def __str__(self):
# We only need to generate id arrays for things that will end
# up used via ResolveProperty or EnumerateProperties.
@ -2126,7 +2144,12 @@ class CGWrapper(CGThing):
"""
Generic CGThing that wraps other CGThings with pre and post text.
"""
def __init__(self, child, pre="", post="", reindent=False):
child: CGThing
pre: str
post: str
reindent: bool
def __init__(self, child, pre: str = "", post: str= "", reindent: bool = False):
CGThing.__init__(self)
self.child = child
self.pre = pre
@ -2568,7 +2591,7 @@ class CGList(CGThing):
Generate code for a list of GCThings. Just concatenates them together, with
an optional joiner string. "\n" is a common joiner.
"""
def __init__(self, children, joiner=""):
def __init__(self, children, joiner: str = "") -> None :
CGThing.__init__(self)
# Make a copy of the kids into a list, because if someone passes in a
# generator we won't be able to both declare and define ourselves, or
@ -2585,7 +2608,7 @@ class CGList(CGThing):
def join(self, iterable):
return self.joiner.join(s for s in iterable if len(s) > 0)
def define(self):
def define(self) -> str:
return self.join(child.define() for child in self.children if child is not None)
def __len__(self):
@ -2607,10 +2630,11 @@ class CGGeneric(CGThing):
A class that spits out a fixed string into the codegen. Can spit out a
separate string for the declaration too.
"""
def __init__(self, text):
text: str
def __init__(self, text: str) -> None:
self.text = text
def define(self):
def define(self) -> str:
return self.text
@ -2620,11 +2644,11 @@ class CGCallbackTempRoot(CGGeneric):
def getAllTypes(
descriptors: List[Descriptor],
dictionaries: List[IDLDictionary],
callbacks: List[IDLCallback],
typedefs: List[IDLTypedef]
) -> Generator[Tuple[IDLType, Optional[Descriptor]], None, None]:
descriptors: list[Descriptor],
dictionaries: list[IDLDictionary],
callbacks: list[IDLCallback],
typedefs: list[IDLTypedef]
) -> Generator[tuple[IDLType, Optional[Descriptor]], None, None]:
"""
Generate all the types we're dealing with. For each type, a tuple
containing type, descriptor, dictionary is yielded. The
@ -2652,10 +2676,10 @@ def getAllTypes(
def UnionTypes(
descriptors: List[Descriptor],
dictionaries: List[IDLDictionary],
callbacks: List[IDLCallback],
typedefs: List[IDLTypedef],
descriptors: list[Descriptor],
dictionaries: list[IDLDictionary],
callbacks: list[IDLCallback],
typedefs: list[IDLTypedef],
config: Configuration
):
"""
@ -2677,6 +2701,7 @@ def UnionTypes(
t = t.unroll()
if not t.isUnion():
continue
# pyrefly: ignore # missing-attribute
for memberType in t.flatMemberTypes:
if memberType.isDictionary() or memberType.isEnum() or memberType.isCallback():
memberModule = getModuleFromObject(memberType)
@ -2969,8 +2994,8 @@ class CGAbstractMethod(CGThing):
def definition_epilogue(self):
return "\n}\n"
def definition_body(self):
raise NotImplementedError # Override me!
def definition_body(self) -> CGThing:
raise NotImplementedError
class CGConstructorEnabled(CGAbstractMethod):
@ -3533,7 +3558,7 @@ assert!((*cache)[PrototypeList::Constructor::{name} as usize].is_null());
f"{toBindingNamespace(parentName)}::GetProtoObject::<D>(cx, global, prototype_proto.handle_mut())"
)
code = [CGGeneric(f"""
code: list = [CGGeneric(f"""
rooted!(in(*cx) let mut prototype_proto = ptr::null_mut::<JSObject>());
{getPrototypeProto};
assert!(!prototype_proto.is_null());""")]
@ -3741,7 +3766,7 @@ class CGGetPerInterfaceObject(CGAbstractMethod):
self.id = f"{idPrefix}::{MakeNativeName(self.descriptor.name)}"
self.variant = self.id.split('::')[-2]
def definition_body(self):
def definition_body(self) -> CGThing:
return CGGeneric(
"get_per_interface_object_handle"
f"(cx, global, ProtoOrIfaceIndex::{self.variant}({self.id}), CreateInterfaceObjects::<D>, rval)"
@ -4038,7 +4063,7 @@ class CGPerSignatureCall(CGThing):
self.argsPre = argsPre
self.arguments = arguments
self.argCount = len(arguments)
cgThings = []
cgThings: list[CGThing] = []
cgThings.extend([CGArgumentConverter(arguments[i], i, self.getArgs(),
self.getArgc(), self.descriptor,
invalidEnumValueFatal=not setter) for
@ -4071,7 +4096,7 @@ class CGPerSignatureCall(CGThing):
def getArgs(self):
return "args" if self.argCount > 0 else ""
def getArgc(self):
def getArgc(self) -> str:
return "argc"
def getArguments(self):
@ -4208,7 +4233,7 @@ class CGAbstractStaticBindingMethod(CGAbstractMethod):
CGAbstractMethod.__init__(self, descriptor, name, "bool", args, extern=True, templateArgs=templateArgs)
self.exposureSet = descriptor.interface.exposureSet
def definition_body(self):
def definition_body(self) -> CGThing:
preamble = """\
let args = CallArgs::from_vp(vp, argc);
let global = D::GlobalScope::from_object(args.callee());
@ -4217,8 +4242,8 @@ let global = D::GlobalScope::from_object(args.callee());
preamble += f"let global = DomRoot::downcast::<D::{list(self.exposureSet)[0]}>(global).unwrap();\n"
return CGList([CGGeneric(preamble), self.generate_code()])
def generate_code(self):
raise NotImplementedError # Override me!
def generate_code(self) -> CGThing:
raise NotImplementedError
def GetConstructorNameForReporting(descriptor, ctor):
@ -4246,7 +4271,7 @@ class CGSpecializedMethod(CGAbstractExternMethod):
Argument('*const JSJitMethodCallArgs', 'args')]
CGAbstractExternMethod.__init__(self, descriptor, name, 'bool', args, templateArgs=["D: DomTypes"])
def definition_body(self):
def definition_body(self) -> CGThing:
nativeName = CGSpecializedMethod.makeNativeName(self.descriptor,
self.method)
return CGWrapper(CGMethodCall([], nativeName, self.method.isStatic(),
@ -4468,7 +4493,7 @@ class CGSpecializedSetter(CGAbstractExternMethod):
Argument('JSJitSetterCallArgs', 'args')]
CGAbstractExternMethod.__init__(self, descriptor, name, "bool", args, templateArgs=["D: DomTypes"])
def definition_body(self):
def definition_body(self) -> CGThing:
nativeName = CGSpecializedSetter.makeNativeName(self.descriptor,
self.attr)
return CGWrapper(
@ -4668,6 +4693,7 @@ pub(crate) fn init_{infoName}<D: DomTypes>() {{
if self.member.slotIndices is not None:
assert isAlwaysInSlot or self.member.getExtendedAttribute("Cached")
isLazilyCachedInSlot = not isAlwaysInSlot
# pyrefly: ignore # unknown-name
slotIndex = memberReservedSlot(self.member) # noqa:FIXME: memberReservedSlot is not defined
# We'll statically assert that this is not too big in
# CGUpdateMemberSlotsMethod, in the case when
@ -4689,6 +4715,7 @@ pub(crate) fn init_{infoName}<D: DomTypes>() {{
result += self.defineJitInfo(setterinfo, setter, "Setter",
False, False, "AliasEverything",
False, False, "0",
# pyrefly: ignore # missing-attribute
[BuiltinTypes[IDLBuiltinType.Types.undefined]],
None)
return result
@ -5381,7 +5408,7 @@ class CGUnionConversionStruct(CGThing):
typename = get_name(memberType)
return CGGeneric(get_match(typename))
other = []
other: list[CGThing] = []
stringConversion = list(map(getStringOrPrimitiveConversion, stringTypes))
numericConversion = list(map(getStringOrPrimitiveConversion, numericTypes))
booleanConversion = list(map(getStringOrPrimitiveConversion, booleanTypes))
@ -5468,10 +5495,10 @@ class ClassItem:
self.name = name
self.visibility = visibility
def declare(self, cgClass):
def declare(self, cgClass) -> str | None:
assert False
def define(self, cgClass):
def define(self, cgClass) -> str | None:
assert False
@ -5482,12 +5509,13 @@ class ClassBase(ClassItem):
def declare(self, cgClass):
return f'{self.visibility} {self.name}'
def define(self, cgClass):
def define(self, cgClass) -> str | None:
# Only in the header
return ''
class ClassMethod(ClassItem):
body: str | None
def __init__(self, name, returnType, args, inline=False, static=False,
virtual=False, const=False, bodyInHeader=False,
templateArgs=None, visibility='public', body=None,
@ -5646,7 +5674,7 @@ pub unsafe fn {self.getDecorators(True)}new({args}) -> Rc<{name}>{body}
args = ', '.join([a.define() for a in self.args])
body = f' {self.getBody()}'
body = f' {self.getBody(cgClass)}'
trimmedBody = stripTrailingWhitespace(body.replace('\n', '\n '))
body = f'\n{trimmedBody}'
if len(body) > 0:
@ -5832,9 +5860,11 @@ class CGProxySpecialOperation(CGPerSignatureCall):
return args
def wrap_return_value(self):
# pyrefly: ignore # missing-attribute
if not self.idlNode.isGetter() or self.templateValues is None:
return ""
# pyrefly: ignore # missing-argument, bad-unpacking
wrap = CGGeneric(wrapForType(**self.templateValues))
wrap = CGIfWrapper("let Some(result) = result", wrap)
return f"\n{wrap.define()}"
@ -6517,8 +6547,8 @@ let this = native_from_object_static::<{self.descriptor.concreteType}>(obj).unwr
self.generate_code(),
])
def generate_code(self):
raise NotImplementedError # Override me!
def generate_code(self) -> CGThing:
raise NotImplementedError
def finalizeHook(descriptor, hookName, context):
@ -7561,6 +7591,7 @@ class CGConcreteBindingRoot(CGThing):
type that is used by handwritten code. Re-export all public types from
the generic bindings with type specialization applied.
"""
root: CGThing | None
def __init__(self, config, prefix, webIDLFile):
descriptors = config.getDescriptors(webIDLFile=webIDLFile,
hasInterfaceObject=True)
@ -7689,7 +7720,7 @@ pub(crate) fn GetConstructorObject(
def define(self):
if not self.root:
return None
return ""
return stripTrailingWhitespace(self.root.define())
@ -7698,6 +7729,7 @@ class CGBindingRoot(CGThing):
Root codegen class for binding generation. Instantiate the class, and call
declare or define to generate header or cpp code (respectively).
"""
root: CGThing | None
def __init__(self, config, prefix, webIDLFile):
descriptors = config.getDescriptors(webIDLFile=webIDLFile,
hasInterfaceObject=True)
@ -7724,7 +7756,7 @@ class CGBindingRoot(CGThing):
return
# Do codegen for all the enums.
cgthings = [CGEnum(e, config) for e in enums]
cgthings: list = [CGEnum(e, config) for e in enums]
# Do codegen for all the typedefs
for t in typedefs:
@ -7776,41 +7808,51 @@ class CGBindingRoot(CGThing):
def define(self):
if not self.root:
return None
return ""
return stripTrailingWhitespace(self.root.define())
def type_needs_tracing(t):
def type_needs_tracing(t: IDLObject):
assert isinstance(t, IDLObject), (t, type(t))
if t.isType():
if isinstance(t, IDLWrapperType):
return type_needs_tracing(t.inner)
# pyrefly: ignore # missing-attribute
if t.nullable():
# pyrefly: ignore # missing-attribute
return type_needs_tracing(t.inner)
# pyrefly: ignore # missing-attribute
if t.isAny():
return True
# pyrefly: ignore # missing-attribute
if t.isObject():
return True
if t.isSequence():
# pyrefly: ignore # missing-attribute
if t.isSequence() :
# pyrefly: ignore # missing-attribute
return type_needs_tracing(t.inner)
if t.isUnion():
# pyrefly: ignore # not-iterable
return any(type_needs_tracing(member) for member in t.flatMemberTypes)
# pyrefly: ignore # bad-argument-type
if is_typed_array(t):
return True
return False
if t.isDictionary():
# pyrefly: ignore # missing-attribute, bad-argument-type
if t.parent and type_needs_tracing(t.parent):
return True
# pyrefly: ignore # missing-attribute
if any(type_needs_tracing(member.type) for member in t.members):
return True
@ -7825,13 +7867,13 @@ def type_needs_tracing(t):
assert False, (t, type(t))
def is_typed_array(t):
def is_typed_array(t: IDLType):
assert isinstance(t, IDLObject), (t, type(t))
return t.isTypedArray() or t.isArrayBuffer() or t.isArrayBufferView()
def type_needs_auto_root(t):
def type_needs_auto_root(t: IDLType):
"""
Certain IDL types, such as `sequence<any>` or `sequence<object>` need to be
traced and wrapped via (Custom)AutoRooter
@ -7853,7 +7895,6 @@ def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, var
ty, descriptorProvider, isArgument=True,
isAutoRooted=type_needs_auto_root(ty))
declType = info.declType
if variadic:
if ty.isGeckoInterface():
declType = CGWrapper(declType, pre="&[", post="]")
@ -7868,6 +7909,7 @@ def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, var
if type_needs_auto_root(ty):
declType = CGTemplatedType("CustomAutoRooterGuard", declType)
assert declType is not None
return declType.define()
@ -8144,6 +8186,14 @@ class CallbackMember(CGNativeMember):
self.exceptionCode = "return Err(JSFailed);\n"
self.body = self.getImpl()
@abstractmethod
def getRvalDecl(self) -> str:
raise NotImplementedError
@abstractmethod
def getCall(self) -> str:
raise NotImplementedError
def getImpl(self):
argvDecl = (
"rooted_vec!(let mut argv);\n"
@ -8299,6 +8349,18 @@ class CallbackMethod(CallbackMember):
else:
return "rooted!(in(*cx) let mut rval = UndefinedValue());\n"
@abstractmethod
def getCallableDecl(self) -> str:
raise NotImplementedError
@abstractmethod
def getThisObj(self) -> str:
raise NotImplementedError
@abstractmethod
def getCallGuard(self) -> str:
raise NotImplementedError
def getCall(self):
if self.argCount > 0:
argv = "argv.as_ptr() as *const JSVal"
@ -8746,7 +8808,7 @@ impl {base} {{
if downcast:
hierarchy[descriptor.interface.parent.identifier.name].append(name)
typeIdCode = []
typeIdCode: list = []
topTypeVariants = [
("ID used by abstract interfaces.", "pub abstract_: ()"),
("ID used by interfaces that are not castable.", "pub alone: ()"),