script_binding: Add type check on servo script bindings (#38161)

Introduce type checking with Pyrefly in `components/script_bindings`

This commit adds Pyrefly-based type checking to the
`components/script_bindings` directory. The primary goal is to catch
type inconsistencies early and reduce the likelihood of unexpected
runtime errors.

This change affects the `webidl` component, as these script bindings are
responsible for connecting WebIDL specifications to the Rust codebase.

Testing: `./mach test-wpt webidl`
Fixes: *Link to an issue this pull requests fixes or remove this line if
there is no issue*

---------

Signed-off-by: Jerens Lensun <jerensslensun@gmail.com>
This commit is contained in:
Jerens Lensun 2025-08-01 12:34:24 +08:00 committed by GitHub
parent 4ce5b17605
commit b05d265de5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 176 additions and 88 deletions

View file

@ -4,9 +4,12 @@
# Common codegen classes. # Common codegen classes.
from WebIDL import IDLUnionType
from WebIDL import IDLSequenceType
from collections import defaultdict from collections import defaultdict
from itertools import groupby from itertools import groupby
from typing import Generator, Tuple, Optional, List from typing import Generator, Optional, cast
from abc import abstractmethod
import operator import operator
import os import os
@ -122,47 +125,53 @@ RUST_KEYWORDS = {
} }
def genericsForType(t): def genericsForType(t: IDLObject) -> tuple[str, str]:
if containsDomInterface(t): if containsDomInterface(t):
return ("<D: DomTypes>", "<D>") return ("<D: DomTypes>", "<D>")
return ("", "") return ("", "")
def isDomInterface(t, logging=False): def isDomInterface(t: IDLObject, logging: bool = False) -> bool:
while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType): while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType):
t = t.inner t = t.inner
if isinstance(t, IDLInterface): if isinstance(t, IDLInterface):
return True return True
# pyrefly: ignore # missing-attribute
if t.isCallback() or t.isPromise(): if t.isCallback() or t.isPromise():
return True return True
return t.isInterface() and (t.isGeckoInterface() or (t.isSpiderMonkeyInterface() and not t.isBufferSource())) # pyrefly: ignore # missing-attribute
return t.isInterface() and (t.isSpiderMonkeyInterface() and not t.isBufferSource())
def containsDomInterface(t, logging=False): def containsDomInterface(t: IDLObject, logging: bool = False) -> bool:
if isinstance(t, IDLArgument): if isinstance(t, IDLArgument):
t = t.type t = t.type
if isinstance(t, IDLTypedefType): if isinstance(t, IDLTypedefType):
t = t.innerType t = t.inner
while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType): while isinstance(t, IDLNullableType) or isinstance(t, IDLWrapperType):
t = t.inner t = t.inner
if t.isEnum(): if t.isEnum():
return False return False
if t.isUnion(): if t.isUnion():
# pyrefly: ignore # missing-attribute
return any(map(lambda x: containsDomInterface(x), t.flatMemberTypes)) return any(map(lambda x: containsDomInterface(x), t.flatMemberTypes))
if t.isDictionary(): if t.isDictionary():
# pyrefly: ignore # missing-attribute, bad-argument-type
return any(map(lambda x: containsDomInterface(x), t.members)) or (t.parent and containsDomInterface(t.parent)) return any(map(lambda x: containsDomInterface(x), t.members)) or (t.parent and containsDomInterface(t.parent))
if isDomInterface(t): if isDomInterface(t):
return True return True
# pyrefly: ignore # missing-attribute
if t.isSequence(): if t.isSequence():
# pyrefly: ignore # missing-attribute
return containsDomInterface(t.inner) return containsDomInterface(t.inner)
return False return False
def toStringBool(arg): def toStringBool(arg) -> str:
return str(not not arg).lower() return str(not not arg).lower()
def toBindingNamespace(arg): def toBindingNamespace(arg) -> str:
""" """
Namespaces are *_Bindings Namespaces are *_Bindings
@ -171,7 +180,7 @@ def toBindingNamespace(arg):
return re.sub("((_workers)?$)", "_Binding\\1", MakeNativeName(arg)) return re.sub("((_workers)?$)", "_Binding\\1", MakeNativeName(arg))
def toBindingModuleFile(arg): def toBindingModuleFile(arg) -> str:
""" """
Module files are *Bindings Module files are *Bindings
@ -180,14 +189,14 @@ def toBindingModuleFile(arg):
return re.sub("((_workers)?$)", "Binding\\1", MakeNativeName(arg)) return re.sub("((_workers)?$)", "Binding\\1", MakeNativeName(arg))
def toBindingModuleFileFromDescriptor(desc): def toBindingModuleFileFromDescriptor(desc: Descriptor) -> str:
if desc.maybeGetSuperModule() is not None: if desc.maybeGetSuperModule() is not None:
return toBindingModuleFile(desc.maybeGetSuperModule()) return toBindingModuleFile(desc.maybeGetSuperModule())
else: else:
return toBindingModuleFile(desc.name) return toBindingModuleFile(desc.name)
def stripTrailingWhitespace(text): def stripTrailingWhitespace(text: str) -> str:
tail = '\n' if text.endswith('\n') else '' tail = '\n' if text.endswith('\n') else ''
lines = text.splitlines() lines = text.splitlines()
for i in range(len(lines)): for i in range(len(lines)):
@ -254,7 +263,7 @@ numericTags = [
lineStartDetector = re.compile("^(?=[^\n#])", re.MULTILINE) lineStartDetector = re.compile("^(?=[^\n#])", re.MULTILINE)
def indent(s, indentLevel=2): def indent(s: str, indentLevel: int = 2) -> str:
""" """
Indent C++ code. Indent C++ code.
@ -266,7 +275,7 @@ def indent(s, indentLevel=2):
return re.sub(lineStartDetector, indentLevel * " ", s) return re.sub(lineStartDetector, indentLevel * " ", s)
@functools.cache @functools.cache
def dedent(s): def dedent(s: str) -> str:
""" """
Remove all leading whitespace from s, and remove a blank line Remove all leading whitespace from s, and remove a blank line
at the beginning. at the beginning.
@ -366,12 +375,13 @@ class CGThing():
""" """
Abstract base class for things that spit out code. Abstract base class for things that spit out code.
""" """
def __init__(self): def __init__(self) -> None:
pass # Nothing for now pass # Nothing for now
def define(self): @abstractmethod
def define(self) -> str:
"""Produce code for a Rust file.""" """Produce code for a Rust file."""
raise NotImplementedError # Override me! raise NotImplementedError
class CGMethodCall(CGThing): class CGMethodCall(CGThing):
@ -379,6 +389,7 @@ class CGMethodCall(CGThing):
A class to generate selection of a method signature from a set of A class to generate selection of a method signature from a set of
signatures and generation of a call to that signature. signatures and generation of a call to that signature.
""" """
cgRoot: CGThing
def __init__(self, argsPre, nativeMethodName, static, descriptor, method): def __init__(self, argsPre, nativeMethodName, static, descriptor, method):
CGThing.__init__(self) CGThing.__init__(self)
@ -452,7 +463,7 @@ class CGMethodCall(CGThing):
# Doesn't matter which of the possible signatures we use, since # Doesn't matter which of the possible signatures we use, since
# they all have the same types up to that point; just use # they all have the same types up to that point; just use
# possibleSignatures[0] # possibleSignatures[0]
caseBody = [ caseBody: list[CGThing] = [
CGArgumentConverter(possibleSignatures[0][1][i], CGArgumentConverter(possibleSignatures[0][1][i],
i, "args", "argc", descriptor) i, "args", "argc", descriptor)
for i in range(0, distinguishingIndex)] for i in range(0, distinguishingIndex)]
@ -574,7 +585,7 @@ class CGMethodCall(CGThing):
argCountCases.append(CGCase(str(argCount), argCountCases.append(CGCase(str(argCount),
CGList(caseBody, "\n"))) CGList(caseBody, "\n")))
overloadCGThings = [] overloadCGThings: list[CGThing] = []
overloadCGThings.append( overloadCGThings.append(
CGGeneric(f"let argcount = cmp::min(argc, {maxArgCount});")) CGGeneric(f"let argcount = cmp::min(argc, {maxArgCount});"))
overloadCGThings.append( overloadCGThings.append(
@ -613,7 +624,7 @@ def typeIsSequenceOrHasSequenceMember(type):
return False return False
def union_native_type(t): def union_native_type(t: IDLType) -> str:
name = t.unroll().name name = t.unroll().name
generic = "<D>" if containsDomInterface(t) else "" generic = "<D>" if containsDomInterface(t) else ""
return f'GenericUnionTypes::{name}{generic}' return f'GenericUnionTypes::{name}{generic}'
@ -621,7 +632,7 @@ def union_native_type(t):
# Unfortunately, .capitalize() on a string will lowercase things inside the # Unfortunately, .capitalize() on a string will lowercase things inside the
# string, which we do not want. # string, which we do not want.
def firstCap(string): def firstCap(string: str) -> str:
return f"{string[0].upper()}{string[1:]}" return f"{string[0].upper()}{string[1:]}"
@ -1239,7 +1250,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None,
def instantiateJSToNativeConversionTemplate(templateBody, replacements, def instantiateJSToNativeConversionTemplate(templateBody, replacements,
declType, declName, declType, declName,
needsAutoRoot=False): needsAutoRoot=False) -> CGThing:
""" """
Take the templateBody and declType as returned by Take the templateBody and declType as returned by
getJSToNativeConversionInfo, a set of replacements as required by the getJSToNativeConversionInfo, a set of replacements as required by the
@ -1298,6 +1309,7 @@ class CGArgumentConverter(CGThing):
argument list, and the argv and argc strings and generates code to argument list, and the argv and argc strings and generates code to
unwrap the argument to the right native type. unwrap the argument to the right native type.
""" """
converter: CGThing
def __init__(self, argument, index, args, argc, descriptorProvider, def __init__(self, argument, index, args, argc, descriptorProvider,
invalidEnumValueFatal=True): invalidEnumValueFatal=True):
CGThing.__init__(self) CGThing.__init__(self)
@ -1374,7 +1386,7 @@ if {argc} > {index} {{
return self.converter.define() return self.converter.define()
def wrapForType(jsvalRef, result='result', successCode='true', pre=''): def wrapForType(jsvalRef: str, result: str = 'result', successCode: str = 'true', pre: str = '') -> str:
""" """
Reflect a Rust value into JS. Reflect a Rust value into JS.
@ -1413,7 +1425,7 @@ def returnTypeNeedsOutparam(type):
def outparamTypeFromReturnType(type): def outparamTypeFromReturnType(type):
if type.isAny(): if type.isAny():
return "MutableHandleValue" return "MutableHandleValue"
raise f"Don't know how to handle {type} as an outparam" raise TypeError(f"Don't know how to handle {type} as an outparam")
# Returns a conversion behavior suitable for a type # Returns a conversion behavior suitable for a type
@ -1569,7 +1581,9 @@ class PropertyDefiner:
things we're defining. They should also set self.regular to the list of things we're defining. They should also set self.regular to the list of
things exposed to web pages. things exposed to web pages.
""" """
def __init__(self, descriptor, name): name: str
regular: list
def __init__(self, descriptor, name: str):
self.descriptor = descriptor self.descriptor = descriptor
self.name = name self.name = name
@ -1579,6 +1593,10 @@ class PropertyDefiner:
def length(self): def length(self):
return len(self.regular) return len(self.regular)
@abstractmethod
def generateArray(self, array, name) -> str:
raise NotImplementedError
def __str__(self): def __str__(self):
# We only need to generate id arrays for things that will end # We only need to generate id arrays for things that will end
# up used via ResolveProperty or EnumerateProperties. # up used via ResolveProperty or EnumerateProperties.
@ -2126,7 +2144,12 @@ class CGWrapper(CGThing):
""" """
Generic CGThing that wraps other CGThings with pre and post text. Generic CGThing that wraps other CGThings with pre and post text.
""" """
def __init__(self, child, pre="", post="", reindent=False): child: CGThing
pre: str
post: str
reindent: bool
def __init__(self, child, pre: str = "", post: str= "", reindent: bool = False):
CGThing.__init__(self) CGThing.__init__(self)
self.child = child self.child = child
self.pre = pre self.pre = pre
@ -2568,7 +2591,7 @@ class CGList(CGThing):
Generate code for a list of GCThings. Just concatenates them together, with Generate code for a list of GCThings. Just concatenates them together, with
an optional joiner string. "\n" is a common joiner. an optional joiner string. "\n" is a common joiner.
""" """
def __init__(self, children, joiner=""): def __init__(self, children, joiner: str = "") -> None :
CGThing.__init__(self) CGThing.__init__(self)
# Make a copy of the kids into a list, because if someone passes in a # Make a copy of the kids into a list, because if someone passes in a
# generator we won't be able to both declare and define ourselves, or # generator we won't be able to both declare and define ourselves, or
@ -2585,7 +2608,7 @@ class CGList(CGThing):
def join(self, iterable): def join(self, iterable):
return self.joiner.join(s for s in iterable if len(s) > 0) return self.joiner.join(s for s in iterable if len(s) > 0)
def define(self): def define(self) -> str:
return self.join(child.define() for child in self.children if child is not None) return self.join(child.define() for child in self.children if child is not None)
def __len__(self): def __len__(self):
@ -2607,10 +2630,11 @@ class CGGeneric(CGThing):
A class that spits out a fixed string into the codegen. Can spit out a A class that spits out a fixed string into the codegen. Can spit out a
separate string for the declaration too. separate string for the declaration too.
""" """
def __init__(self, text): text: str
def __init__(self, text: str) -> None:
self.text = text self.text = text
def define(self): def define(self) -> str:
return self.text return self.text
@ -2620,11 +2644,11 @@ class CGCallbackTempRoot(CGGeneric):
def getAllTypes( def getAllTypes(
descriptors: List[Descriptor], descriptors: list[Descriptor],
dictionaries: List[IDLDictionary], dictionaries: list[IDLDictionary],
callbacks: List[IDLCallback], callbacks: list[IDLCallback],
typedefs: List[IDLTypedef] typedefs: list[IDLTypedef]
) -> Generator[Tuple[IDLType, Optional[Descriptor]], None, None]: ) -> Generator[tuple[IDLType, Optional[Descriptor]], None, None]:
""" """
Generate all the types we're dealing with. For each type, a tuple Generate all the types we're dealing with. For each type, a tuple
containing type, descriptor, dictionary is yielded. The containing type, descriptor, dictionary is yielded. The
@ -2652,10 +2676,10 @@ def getAllTypes(
def UnionTypes( def UnionTypes(
descriptors: List[Descriptor], descriptors: list[Descriptor],
dictionaries: List[IDLDictionary], dictionaries: list[IDLDictionary],
callbacks: List[IDLCallback], callbacks: list[IDLCallback],
typedefs: List[IDLTypedef], typedefs: list[IDLTypedef],
config: Configuration config: Configuration
): ):
""" """
@ -2677,6 +2701,7 @@ def UnionTypes(
t = t.unroll() t = t.unroll()
if not t.isUnion(): if not t.isUnion():
continue continue
# pyrefly: ignore # missing-attribute
for memberType in t.flatMemberTypes: for memberType in t.flatMemberTypes:
if memberType.isDictionary() or memberType.isEnum() or memberType.isCallback(): if memberType.isDictionary() or memberType.isEnum() or memberType.isCallback():
memberModule = getModuleFromObject(memberType) memberModule = getModuleFromObject(memberType)
@ -2969,8 +2994,8 @@ class CGAbstractMethod(CGThing):
def definition_epilogue(self): def definition_epilogue(self):
return "\n}\n" return "\n}\n"
def definition_body(self): def definition_body(self) -> CGThing:
raise NotImplementedError # Override me! raise NotImplementedError
class CGConstructorEnabled(CGAbstractMethod): class CGConstructorEnabled(CGAbstractMethod):
@ -3533,7 +3558,7 @@ assert!((*cache)[PrototypeList::Constructor::{name} as usize].is_null());
f"{toBindingNamespace(parentName)}::GetProtoObject::<D>(cx, global, prototype_proto.handle_mut())" f"{toBindingNamespace(parentName)}::GetProtoObject::<D>(cx, global, prototype_proto.handle_mut())"
) )
code = [CGGeneric(f""" code: list = [CGGeneric(f"""
rooted!(in(*cx) let mut prototype_proto = ptr::null_mut::<JSObject>()); rooted!(in(*cx) let mut prototype_proto = ptr::null_mut::<JSObject>());
{getPrototypeProto}; {getPrototypeProto};
assert!(!prototype_proto.is_null());""")] assert!(!prototype_proto.is_null());""")]
@ -3741,7 +3766,7 @@ class CGGetPerInterfaceObject(CGAbstractMethod):
self.id = f"{idPrefix}::{MakeNativeName(self.descriptor.name)}" self.id = f"{idPrefix}::{MakeNativeName(self.descriptor.name)}"
self.variant = self.id.split('::')[-2] self.variant = self.id.split('::')[-2]
def definition_body(self): def definition_body(self) -> CGThing:
return CGGeneric( return CGGeneric(
"get_per_interface_object_handle" "get_per_interface_object_handle"
f"(cx, global, ProtoOrIfaceIndex::{self.variant}({self.id}), CreateInterfaceObjects::<D>, rval)" f"(cx, global, ProtoOrIfaceIndex::{self.variant}({self.id}), CreateInterfaceObjects::<D>, rval)"
@ -4038,7 +4063,7 @@ class CGPerSignatureCall(CGThing):
self.argsPre = argsPre self.argsPre = argsPre
self.arguments = arguments self.arguments = arguments
self.argCount = len(arguments) self.argCount = len(arguments)
cgThings = [] cgThings: list[CGThing] = []
cgThings.extend([CGArgumentConverter(arguments[i], i, self.getArgs(), cgThings.extend([CGArgumentConverter(arguments[i], i, self.getArgs(),
self.getArgc(), self.descriptor, self.getArgc(), self.descriptor,
invalidEnumValueFatal=not setter) for invalidEnumValueFatal=not setter) for
@ -4071,7 +4096,7 @@ class CGPerSignatureCall(CGThing):
def getArgs(self): def getArgs(self):
return "args" if self.argCount > 0 else "" return "args" if self.argCount > 0 else ""
def getArgc(self): def getArgc(self) -> str:
return "argc" return "argc"
def getArguments(self): def getArguments(self):
@ -4208,7 +4233,7 @@ class CGAbstractStaticBindingMethod(CGAbstractMethod):
CGAbstractMethod.__init__(self, descriptor, name, "bool", args, extern=True, templateArgs=templateArgs) CGAbstractMethod.__init__(self, descriptor, name, "bool", args, extern=True, templateArgs=templateArgs)
self.exposureSet = descriptor.interface.exposureSet self.exposureSet = descriptor.interface.exposureSet
def definition_body(self): def definition_body(self) -> CGThing:
preamble = """\ preamble = """\
let args = CallArgs::from_vp(vp, argc); let args = CallArgs::from_vp(vp, argc);
let global = D::GlobalScope::from_object(args.callee()); let global = D::GlobalScope::from_object(args.callee());
@ -4217,8 +4242,8 @@ let global = D::GlobalScope::from_object(args.callee());
preamble += f"let global = DomRoot::downcast::<D::{list(self.exposureSet)[0]}>(global).unwrap();\n" preamble += f"let global = DomRoot::downcast::<D::{list(self.exposureSet)[0]}>(global).unwrap();\n"
return CGList([CGGeneric(preamble), self.generate_code()]) return CGList([CGGeneric(preamble), self.generate_code()])
def generate_code(self): def generate_code(self) -> CGThing:
raise NotImplementedError # Override me! raise NotImplementedError
def GetConstructorNameForReporting(descriptor, ctor): def GetConstructorNameForReporting(descriptor, ctor):
@ -4246,7 +4271,7 @@ class CGSpecializedMethod(CGAbstractExternMethod):
Argument('*const JSJitMethodCallArgs', 'args')] Argument('*const JSJitMethodCallArgs', 'args')]
CGAbstractExternMethod.__init__(self, descriptor, name, 'bool', args, templateArgs=["D: DomTypes"]) CGAbstractExternMethod.__init__(self, descriptor, name, 'bool', args, templateArgs=["D: DomTypes"])
def definition_body(self): def definition_body(self) -> CGThing:
nativeName = CGSpecializedMethod.makeNativeName(self.descriptor, nativeName = CGSpecializedMethod.makeNativeName(self.descriptor,
self.method) self.method)
return CGWrapper(CGMethodCall([], nativeName, self.method.isStatic(), return CGWrapper(CGMethodCall([], nativeName, self.method.isStatic(),
@ -4468,7 +4493,7 @@ class CGSpecializedSetter(CGAbstractExternMethod):
Argument('JSJitSetterCallArgs', 'args')] Argument('JSJitSetterCallArgs', 'args')]
CGAbstractExternMethod.__init__(self, descriptor, name, "bool", args, templateArgs=["D: DomTypes"]) CGAbstractExternMethod.__init__(self, descriptor, name, "bool", args, templateArgs=["D: DomTypes"])
def definition_body(self): def definition_body(self) -> CGThing:
nativeName = CGSpecializedSetter.makeNativeName(self.descriptor, nativeName = CGSpecializedSetter.makeNativeName(self.descriptor,
self.attr) self.attr)
return CGWrapper( return CGWrapper(
@ -4668,6 +4693,7 @@ pub(crate) fn init_{infoName}<D: DomTypes>() {{
if self.member.slotIndices is not None: if self.member.slotIndices is not None:
assert isAlwaysInSlot or self.member.getExtendedAttribute("Cached") assert isAlwaysInSlot or self.member.getExtendedAttribute("Cached")
isLazilyCachedInSlot = not isAlwaysInSlot isLazilyCachedInSlot = not isAlwaysInSlot
# pyrefly: ignore # unknown-name
slotIndex = memberReservedSlot(self.member) # noqa:FIXME: memberReservedSlot is not defined slotIndex = memberReservedSlot(self.member) # noqa:FIXME: memberReservedSlot is not defined
# We'll statically assert that this is not too big in # We'll statically assert that this is not too big in
# CGUpdateMemberSlotsMethod, in the case when # CGUpdateMemberSlotsMethod, in the case when
@ -4689,6 +4715,7 @@ pub(crate) fn init_{infoName}<D: DomTypes>() {{
result += self.defineJitInfo(setterinfo, setter, "Setter", result += self.defineJitInfo(setterinfo, setter, "Setter",
False, False, "AliasEverything", False, False, "AliasEverything",
False, False, "0", False, False, "0",
# pyrefly: ignore # missing-attribute
[BuiltinTypes[IDLBuiltinType.Types.undefined]], [BuiltinTypes[IDLBuiltinType.Types.undefined]],
None) None)
return result return result
@ -5381,7 +5408,7 @@ class CGUnionConversionStruct(CGThing):
typename = get_name(memberType) typename = get_name(memberType)
return CGGeneric(get_match(typename)) return CGGeneric(get_match(typename))
other = [] other: list[CGThing] = []
stringConversion = list(map(getStringOrPrimitiveConversion, stringTypes)) stringConversion = list(map(getStringOrPrimitiveConversion, stringTypes))
numericConversion = list(map(getStringOrPrimitiveConversion, numericTypes)) numericConversion = list(map(getStringOrPrimitiveConversion, numericTypes))
booleanConversion = list(map(getStringOrPrimitiveConversion, booleanTypes)) booleanConversion = list(map(getStringOrPrimitiveConversion, booleanTypes))
@ -5468,10 +5495,10 @@ class ClassItem:
self.name = name self.name = name
self.visibility = visibility self.visibility = visibility
def declare(self, cgClass): def declare(self, cgClass) -> str | None:
assert False assert False
def define(self, cgClass): def define(self, cgClass) -> str | None:
assert False assert False
@ -5482,12 +5509,13 @@ class ClassBase(ClassItem):
def declare(self, cgClass): def declare(self, cgClass):
return f'{self.visibility} {self.name}' return f'{self.visibility} {self.name}'
def define(self, cgClass): def define(self, cgClass) -> str | None:
# Only in the header # Only in the header
return '' return ''
class ClassMethod(ClassItem): class ClassMethod(ClassItem):
body: str | None
def __init__(self, name, returnType, args, inline=False, static=False, def __init__(self, name, returnType, args, inline=False, static=False,
virtual=False, const=False, bodyInHeader=False, virtual=False, const=False, bodyInHeader=False,
templateArgs=None, visibility='public', body=None, templateArgs=None, visibility='public', body=None,
@ -5646,7 +5674,7 @@ pub unsafe fn {self.getDecorators(True)}new({args}) -> Rc<{name}>{body}
args = ', '.join([a.define() for a in self.args]) args = ', '.join([a.define() for a in self.args])
body = f' {self.getBody()}' body = f' {self.getBody(cgClass)}'
trimmedBody = stripTrailingWhitespace(body.replace('\n', '\n ')) trimmedBody = stripTrailingWhitespace(body.replace('\n', '\n '))
body = f'\n{trimmedBody}' body = f'\n{trimmedBody}'
if len(body) > 0: if len(body) > 0:
@ -5832,9 +5860,11 @@ class CGProxySpecialOperation(CGPerSignatureCall):
return args return args
def wrap_return_value(self): def wrap_return_value(self):
# pyrefly: ignore # missing-attribute
if not self.idlNode.isGetter() or self.templateValues is None: if not self.idlNode.isGetter() or self.templateValues is None:
return "" return ""
# pyrefly: ignore # missing-argument, bad-unpacking
wrap = CGGeneric(wrapForType(**self.templateValues)) wrap = CGGeneric(wrapForType(**self.templateValues))
wrap = CGIfWrapper("let Some(result) = result", wrap) wrap = CGIfWrapper("let Some(result) = result", wrap)
return f"\n{wrap.define()}" return f"\n{wrap.define()}"
@ -6517,8 +6547,8 @@ let this = native_from_object_static::<{self.descriptor.concreteType}>(obj).unwr
self.generate_code(), self.generate_code(),
]) ])
def generate_code(self): def generate_code(self) -> CGThing:
raise NotImplementedError # Override me! raise NotImplementedError
def finalizeHook(descriptor, hookName, context): def finalizeHook(descriptor, hookName, context):
@ -7561,6 +7591,7 @@ class CGConcreteBindingRoot(CGThing):
type that is used by handwritten code. Re-export all public types from type that is used by handwritten code. Re-export all public types from
the generic bindings with type specialization applied. the generic bindings with type specialization applied.
""" """
root: CGThing | None
def __init__(self, config, prefix, webIDLFile): def __init__(self, config, prefix, webIDLFile):
descriptors = config.getDescriptors(webIDLFile=webIDLFile, descriptors = config.getDescriptors(webIDLFile=webIDLFile,
hasInterfaceObject=True) hasInterfaceObject=True)
@ -7689,7 +7720,7 @@ pub(crate) fn GetConstructorObject(
def define(self): def define(self):
if not self.root: if not self.root:
return None return ""
return stripTrailingWhitespace(self.root.define()) return stripTrailingWhitespace(self.root.define())
@ -7698,6 +7729,7 @@ class CGBindingRoot(CGThing):
Root codegen class for binding generation. Instantiate the class, and call Root codegen class for binding generation. Instantiate the class, and call
declare or define to generate header or cpp code (respectively). declare or define to generate header or cpp code (respectively).
""" """
root: CGThing | None
def __init__(self, config, prefix, webIDLFile): def __init__(self, config, prefix, webIDLFile):
descriptors = config.getDescriptors(webIDLFile=webIDLFile, descriptors = config.getDescriptors(webIDLFile=webIDLFile,
hasInterfaceObject=True) hasInterfaceObject=True)
@ -7724,7 +7756,7 @@ class CGBindingRoot(CGThing):
return return
# Do codegen for all the enums. # Do codegen for all the enums.
cgthings = [CGEnum(e, config) for e in enums] cgthings: list = [CGEnum(e, config) for e in enums]
# Do codegen for all the typedefs # Do codegen for all the typedefs
for t in typedefs: for t in typedefs:
@ -7776,41 +7808,51 @@ class CGBindingRoot(CGThing):
def define(self): def define(self):
if not self.root: if not self.root:
return None return ""
return stripTrailingWhitespace(self.root.define()) return stripTrailingWhitespace(self.root.define())
def type_needs_tracing(t): def type_needs_tracing(t: IDLObject):
assert isinstance(t, IDLObject), (t, type(t)) assert isinstance(t, IDLObject), (t, type(t))
if t.isType(): if t.isType():
if isinstance(t, IDLWrapperType): if isinstance(t, IDLWrapperType):
return type_needs_tracing(t.inner) return type_needs_tracing(t.inner)
# pyrefly: ignore # missing-attribute
if t.nullable(): if t.nullable():
# pyrefly: ignore # missing-attribute
return type_needs_tracing(t.inner) return type_needs_tracing(t.inner)
# pyrefly: ignore # missing-attribute
if t.isAny(): if t.isAny():
return True return True
# pyrefly: ignore # missing-attribute
if t.isObject(): if t.isObject():
return True return True
if t.isSequence(): # pyrefly: ignore # missing-attribute
if t.isSequence() :
# pyrefly: ignore # missing-attribute
return type_needs_tracing(t.inner) return type_needs_tracing(t.inner)
if t.isUnion(): if t.isUnion():
# pyrefly: ignore # not-iterable
return any(type_needs_tracing(member) for member in t.flatMemberTypes) return any(type_needs_tracing(member) for member in t.flatMemberTypes)
# pyrefly: ignore # bad-argument-type
if is_typed_array(t): if is_typed_array(t):
return True return True
return False return False
if t.isDictionary(): if t.isDictionary():
# pyrefly: ignore # missing-attribute, bad-argument-type
if t.parent and type_needs_tracing(t.parent): if t.parent and type_needs_tracing(t.parent):
return True return True
# pyrefly: ignore # missing-attribute
if any(type_needs_tracing(member.type) for member in t.members): if any(type_needs_tracing(member.type) for member in t.members):
return True return True
@ -7825,13 +7867,13 @@ def type_needs_tracing(t):
assert False, (t, type(t)) assert False, (t, type(t))
def is_typed_array(t): def is_typed_array(t: IDLType):
assert isinstance(t, IDLObject), (t, type(t)) assert isinstance(t, IDLObject), (t, type(t))
return t.isTypedArray() or t.isArrayBuffer() or t.isArrayBufferView() return t.isTypedArray() or t.isArrayBuffer() or t.isArrayBufferView()
def type_needs_auto_root(t): def type_needs_auto_root(t: IDLType):
""" """
Certain IDL types, such as `sequence<any>` or `sequence<object>` need to be Certain IDL types, such as `sequence<any>` or `sequence<object>` need to be
traced and wrapped via (Custom)AutoRooter traced and wrapped via (Custom)AutoRooter
@ -7853,7 +7895,6 @@ def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, var
ty, descriptorProvider, isArgument=True, ty, descriptorProvider, isArgument=True,
isAutoRooted=type_needs_auto_root(ty)) isAutoRooted=type_needs_auto_root(ty))
declType = info.declType declType = info.declType
if variadic: if variadic:
if ty.isGeckoInterface(): if ty.isGeckoInterface():
declType = CGWrapper(declType, pre="&[", post="]") declType = CGWrapper(declType, pre="&[", post="]")
@ -7868,6 +7909,7 @@ def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, var
if type_needs_auto_root(ty): if type_needs_auto_root(ty):
declType = CGTemplatedType("CustomAutoRooterGuard", declType) declType = CGTemplatedType("CustomAutoRooterGuard", declType)
assert declType is not None
return declType.define() return declType.define()
@ -8144,6 +8186,14 @@ class CallbackMember(CGNativeMember):
self.exceptionCode = "return Err(JSFailed);\n" self.exceptionCode = "return Err(JSFailed);\n"
self.body = self.getImpl() self.body = self.getImpl()
@abstractmethod
def getRvalDecl(self) -> str:
raise NotImplementedError
@abstractmethod
def getCall(self) -> str:
raise NotImplementedError
def getImpl(self): def getImpl(self):
argvDecl = ( argvDecl = (
"rooted_vec!(let mut argv);\n" "rooted_vec!(let mut argv);\n"
@ -8299,6 +8349,18 @@ class CallbackMethod(CallbackMember):
else: else:
return "rooted!(in(*cx) let mut rval = UndefinedValue());\n" return "rooted!(in(*cx) let mut rval = UndefinedValue());\n"
@abstractmethod
def getCallableDecl(self) -> str:
raise NotImplementedError
@abstractmethod
def getThisObj(self) -> str:
raise NotImplementedError
@abstractmethod
def getCallGuard(self) -> str:
raise NotImplementedError
def getCall(self): def getCall(self):
if self.argCount > 0: if self.argCount > 0:
argv = "argv.as_ptr() as *const JSVal" argv = "argv.as_ptr() as *const JSVal"
@ -8746,7 +8808,7 @@ impl {base} {{
if downcast: if downcast:
hierarchy[descriptor.interface.parent.identifier.name].append(name) hierarchy[descriptor.interface.parent.identifier.name].append(name)
typeIdCode = [] typeIdCode: list = []
topTypeVariants = [ topTypeVariants = [
("ID used by abstract interfaces.", "pub abstract_: ()"), ("ID used by abstract interfaces.", "pub abstract_: ()"),
("ID used by interfaces that are not castable.", "pub alone: ()"), ("ID used by interfaces that are not castable.", "pub alone: ()"),

View file

@ -13,7 +13,7 @@ class Configuration:
Represents global configuration state based on IDL parse data and Represents global configuration state based on IDL parse data and
the configuration file. the configuration file.
""" """
def __init__(self, filename, parseData): def __init__(self, filename, parseData) -> None:
# Read the configuration file. # Read the configuration file.
glbl = {} glbl = {}
exec(compile(open(filename).read(), filename, 'exec'), glbl) exec(compile(open(filename).read(), filename, 'exec'), glbl)
@ -96,7 +96,7 @@ class Configuration:
def getter(x): def getter(x):
return x.isGlobal() return x.isGlobal()
elif key == 'isInline': elif key == 'isInline':
def getter(x): def getter(x) -> bool:
return x.interface.getExtendedAttribute('Inline') is not None return x.interface.getExtendedAttribute('Inline') is not None
elif key == 'isExposedConditionally': elif key == 'isExposedConditionally':
def getter(x): def getter(x):
@ -160,7 +160,7 @@ class Configuration:
class NoSuchDescriptorError(TypeError): class NoSuchDescriptorError(TypeError):
def __init__(self, str): def __init__(self, str) -> None:
TypeError.__init__(self, str) TypeError.__init__(self, str)
@ -168,7 +168,7 @@ class DescriptorProvider:
""" """
A way of getting descriptors for interface names A way of getting descriptors for interface names
""" """
def __init__(self, config): def __init__(self, config) -> None:
self.config = config self.config = config
def getDescriptor(self, interfaceName): def getDescriptor(self, interfaceName):
@ -190,7 +190,7 @@ class Descriptor(DescriptorProvider):
""" """
Represents a single descriptor for an interface. See Bindings.conf. Represents a single descriptor for an interface. See Bindings.conf.
""" """
def __init__(self, config, interface, desc): def __init__(self, config, interface, desc) -> None:
DescriptorProvider.__init__(self, config) DescriptorProvider.__init__(self, config)
self.interface = interface self.interface = interface
@ -277,7 +277,7 @@ class Descriptor(DescriptorProvider):
self.hasDefaultToJSON = False self.hasDefaultToJSON = False
def addOperation(operation, m): def addOperation(operation: str, m) -> None:
if not self.operations[operation]: if not self.operations[operation]:
self.operations[operation] = m self.operations[operation] = m
@ -296,7 +296,7 @@ class Descriptor(DescriptorProvider):
if not m.isMethod(): if not m.isMethod():
continue continue
def addIndexedOrNamedOperation(operation, m): def addIndexedOrNamedOperation(operation: str, m) -> None:
if not self.isGlobal(): if not self.isGlobal():
self.proxy = True self.proxy = True
if m.isIndexed(): if m.isIndexed():
@ -333,8 +333,8 @@ class Descriptor(DescriptorProvider):
# array of extended attributes. # array of extended attributes.
self.extendedAttributes = {'all': {}, 'getterOnly': {}, 'setterOnly': {}} self.extendedAttributes = {'all': {}, 'getterOnly': {}, 'setterOnly': {}}
def addExtendedAttribute(attribute, config): def addExtendedAttribute(attribute, config) -> None:
def add(key, members, attribute): def add(key: str, members, attribute) -> None:
for member in members: for member in members:
self.extendedAttributes[key].setdefault(member, []).append(attribute) self.extendedAttributes[key].setdefault(member, []).append(attribute)
@ -398,17 +398,17 @@ class Descriptor(DescriptorProvider):
def internalNameFor(self, name): def internalNameFor(self, name):
return self._internalNames.get(name, name) return self._internalNames.get(name, name)
def hasNamedPropertiesObject(self): def hasNamedPropertiesObject(self) -> bool:
if self.interface.isExternal(): if self.interface.isExternal():
return False return False
return self.isGlobal() and self.supportsNamedProperties() return self.isGlobal() and self.supportsNamedProperties()
def supportsNamedProperties(self): def supportsNamedProperties(self) -> bool:
return self.operations['NamedGetter'] is not None return self.operations['NamedGetter'] is not None
def getExtendedAttributes(self, member, getter=False, setter=False): def getExtendedAttributes(self, member, getter=False, setter=False):
def maybeAppendInfallibleToAttrs(attrs, throws): def maybeAppendInfallibleToAttrs(attrs, throws) -> None:
if throws is None: if throws is None:
attrs.append("infallible") attrs.append("infallible")
elif throws is True: elif throws is True:
@ -442,7 +442,7 @@ class Descriptor(DescriptorProvider):
parent = parent.parent parent = parent.parent
return None return None
def supportsIndexedProperties(self): def supportsIndexedProperties(self) -> bool:
return self.operations['IndexedGetter'] is not None return self.operations['IndexedGetter'] is not None
def isMaybeCrossOriginObject(self): def isMaybeCrossOriginObject(self):
@ -471,7 +471,7 @@ class Descriptor(DescriptorProvider):
def isExposedConditionally(self): def isExposedConditionally(self):
return self.interface.isExposedConditionally() return self.interface.isExposedConditionally()
def isGlobal(self): def isGlobal(self) -> bool:
""" """
Returns true if this is the primary interface for a global object Returns true if this is the primary interface for a global object
of some sort. of some sort.

View file

@ -13,7 +13,7 @@ SERVO_ROOT = os.path.abspath(os.path.join(SCRIPT_PATH, "..", "..", ".."))
FILTER_PATTERN = re.compile("// skip-unless ([A-Z_]+)\n") FILTER_PATTERN = re.compile("// skip-unless ([A-Z_]+)\n")
def main(): def main() -> None:
os.chdir(os.path.join(os.path.dirname(__file__))) os.chdir(os.path.join(os.path.dirname(__file__)))
sys.path.insert(0, os.path.join(SERVO_ROOT, "third_party", "WebIDL")) sys.path.insert(0, os.path.join(SERVO_ROOT, "third_party", "WebIDL"))
sys.path.insert(0, os.path.join(SERVO_ROOT, "third_party", "ply")) sys.path.insert(0, os.path.join(SERVO_ROOT, "third_party", "ply"))
@ -84,13 +84,13 @@ def main():
f.write(module.encode("utf-8")) f.write(module.encode("utf-8"))
def make_dir(path): def make_dir(path: str):
if not os.path.exists(path): if not os.path.exists(path):
os.makedirs(path) os.makedirs(path)
return path return path
def generate(config, name, filename): def generate(config, name: str, filename: str) -> None:
from codegen import GlobalGenRoots from codegen import GlobalGenRoots
root = getattr(GlobalGenRoots, name)(config) root = getattr(GlobalGenRoots, name)(config)
code = root.define() code = root.define()
@ -98,8 +98,8 @@ def generate(config, name, filename):
f.write(code.encode("utf-8")) f.write(code.encode("utf-8"))
def add_css_properties_attributes(css_properties_json, parser): def add_css_properties_attributes(css_properties_json: str, parser) -> None:
def map_preference_name(preference_name: str): def map_preference_name(preference_name: str) -> str:
"""Map between Stylo preference names and Servo preference names as the """Map between Stylo preference names and Servo preference names as the
`css-properties.json` file is generated by Stylo. This should be kept in sync with the `css-properties.json` file is generated by Stylo. This should be kept in sync with the
preference mapping done in `components/servo_config/prefs.rs`, which handles the runtime version of preference mapping done in `components/servo_config/prefs.rs`, which handles the runtime version of
@ -132,7 +132,7 @@ def add_css_properties_attributes(css_properties_json, parser):
parser.parse(idl, "CSSStyleDeclaration_generated.webidl") parser.parse(idl, "CSSStyleDeclaration_generated.webidl")
def attribute_names(property_name): def attribute_names(property_name: str):
# https://drafts.csswg.org/cssom/#dom-cssstyledeclaration-dashed-attribute # https://drafts.csswg.org/cssom/#dom-cssstyledeclaration-dashed-attribute
if property_name != "float": if property_name != "float":
yield property_name yield property_name
@ -145,11 +145,11 @@ def attribute_names(property_name):
# https://drafts.csswg.org/cssom/#dom-cssstyledeclaration-webkit-cased-attribute # https://drafts.csswg.org/cssom/#dom-cssstyledeclaration-webkit-cased-attribute
if property_name.startswith("-webkit-"): if property_name.startswith("-webkit-"):
yield "".join(camel_case(property_name), True) yield "".join(camel_case(property_name, True))
# https://drafts.csswg.org/cssom/#css-property-to-idl-attribute # https://drafts.csswg.org/cssom/#css-property-to-idl-attribute
def camel_case(chars, webkit_prefixed=False): def camel_case(chars: str, webkit_prefixed: bool = False):
if webkit_prefixed: if webkit_prefixed:
chars = chars[1:] chars = chars[1:]
next_is_uppercase = False next_is_uppercase = False

View file

@ -34,9 +34,12 @@ search-path = [
"tests/wpt/tests/tools/wptserve", "tests/wpt/tests/tools/wptserve",
"python/mach", "python/mach",
"python/wpt", "python/wpt",
"third_party/WebIDL",
"components/script_bindings/codegen",
] ]
project-includes = [ project-includes = [
"python/**/*.py", "python/**/*.py",
"components/script_bindings",
] ]
project-excludes = [ project-excludes = [
"**/venv/**", "**/venv/**",

View file

@ -12,6 +12,7 @@ import string
import traceback import traceback
from collections import OrderedDict, defaultdict from collections import OrderedDict, defaultdict
from itertools import chain from itertools import chain
from typing import Any
from ply import lex, yacc from ply import lex, yacc
@ -2522,7 +2523,7 @@ class IDLEnum(IDLObjectWithIdentifier):
class IDLType(IDLObject): class IDLType(IDLObject):
Tags = enum( Tags: Any = enum(
# The integer types # The integer types
"int8", "int8",
"uint8", "uint8",

View file

@ -0,0 +1,21 @@
diff --git a/third_party/WebIDL/WebIDL.py b/third_party/WebIDL/WebIDL.py
index b742a06bddd..babad83322f 100644
--- a/third_party/WebIDL/WebIDL.py
+++ b/third_party/WebIDL/WebIDL.py
@@ -12,6 +12,7 @@ import string
import traceback
from collections import OrderedDict, defaultdict
from itertools import chain
+from typing import Any
from ply import lex, yacc
@@ -2527,7 +2528,7 @@ class IDLEnum(IDLObjectWithIdentifier):
class IDLType(IDLObject):
- Tags = enum(
+ Tags: Any = enum(
# The integer types
"int8",
"uint8",

View file

@ -8,6 +8,7 @@ patch < like-as-iterable.patch
patch < builtin-array.patch patch < builtin-array.patch
patch < array-type.patch patch < array-type.patch
patch < transferable.patch patch < transferable.patch
patch < idltype-tags-type-hint.patch
wget https://hg.mozilla.org/mozilla-central/archive/tip.zip/dom/bindings/parser/tests/ -O tests.zip wget https://hg.mozilla.org/mozilla-central/archive/tip.zip/dom/bindings/parser/tests/ -O tests.zip
rm -r tests rm -r tests