diff --git a/components/script_bindings/codegen/codegen.py b/components/script_bindings/codegen/codegen.py index 50a89a7cad6..0a7f97db6ea 100644 --- a/components/script_bindings/codegen/codegen.py +++ b/components/script_bindings/codegen/codegen.py @@ -6,17 +6,20 @@ # fmt: off +from __future__ import annotations + from WebIDL import IDLUnionType from WebIDL import IDLSequenceType from collections import defaultdict from itertools import groupby -from typing import Optional -from collections.abc import Generator +from typing import cast, Optional, Any, Generic, TypeVar +from collections.abc import Generator, Callable, Iterator from abc import abstractmethod import operator import os import re +from re import Match import string import textwrap import functools @@ -40,6 +43,12 @@ from WebIDL import ( IDLTypedefType, IDLUndefinedValue, IDLWrapperType, + IDLRecordType, + IDLAttribute, + IDLConst, + IDLInterfaceOrNamespace, + IDLValue, + IDLMethod, ) from configuration import ( @@ -180,11 +189,11 @@ def containsDomInterface(t: IDLObject, logging: bool = False) -> bool: return False -def toStringBool(arg) -> str: +def toStringBool(arg: bool) -> str: return str(not not arg).lower() -def toBindingNamespace(arg) -> str: +def toBindingNamespace(arg: str) -> str: """ Namespaces are *_Bindings @@ -193,7 +202,7 @@ def toBindingNamespace(arg) -> str: return re.sub("((_workers)?$)", "_Binding\\1", MakeNativeName(arg)) -def toBindingModuleFile(arg) -> str: +def toBindingModuleFile(arg: str) -> str: """ Module files are *Bindings @@ -203,8 +212,9 @@ def toBindingModuleFile(arg) -> str: def toBindingModuleFileFromDescriptor(desc: Descriptor) -> str: - if desc.maybeGetSuperModule() is not None: - return toBindingModuleFile(desc.maybeGetSuperModule()) + isSuperModule = desc.maybeGetSuperModule() + if isSuperModule is not None: + return toBindingModuleFile(isSuperModule) else: return toBindingModuleFile(desc.name) @@ -218,16 +228,22 @@ def stripTrailingWhitespace(text: str) -> str: return f"{joined_lines}{tail}" -def innerContainerType(type): +def innerContainerType(type: IDLType) -> IDLType: assert type.isSequence() or type.isRecord() + assert isinstance(type, (IDLSequenceType, IDLRecordType, IDLNullableType)) return type.inner.inner if type.nullable() else type.inner -def wrapInNativeContainerType(type, inner): +def wrapInNativeContainerType(type: IDLType, inner: CGThing | None) -> CGThing: if type.isSequence(): return CGWrapper(inner, pre="Vec<", post=">") elif type.isRecord(): - key = type.inner.keyType if type.nullable() else type.keyType + if type.nullable(): + assert isinstance(type, IDLNullableType) + key = type.inner.keyType + else: + assert isinstance(type, IDLRecordType) + key = type.keyType return CGRecord(key, inner) else: raise TypeError(f"Unexpected container type {type}") @@ -304,7 +320,7 @@ fill_multiline_substitution_re = re.compile(r"( *)\$\*{(\w+)}(\n)?") @functools.cache -def compile_fill_template(template): +def compile_fill_template(template: str) -> tuple[string.Template, list[tuple[str, str, int]]]: """ Helper function for fill(). Given the template string passed to fill(), do the reusable part of template processing and return a pair (t, @@ -319,7 +335,7 @@ def compile_fill_template(template): assert t.endswith("\n") or "\n" not in t argModList = [] - def replace(match): + def replace(match: Match[str]) -> str: """ Replaces a line like ' $*{xyz}\n' with '${xyz_n}', where n is the indent depth, and add a corresponding entry to @@ -345,7 +361,7 @@ def compile_fill_template(template): return (string.Template(t), argModList) -def fill(template, **args): +def fill(template: str, **args: str) -> str: """ Convenience function for filling in a multiline template. @@ -403,12 +419,12 @@ class CGMethodCall(CGThing): signatures and generation of a call to that signature. """ cgRoot: CGThing - def __init__(self, argsPre, nativeMethodName, static, descriptor, method): + def __init__(self, argsPre: list[str], nativeMethodName: str, static: bool, descriptor: Descriptor, method: IDLMethod) -> None: CGThing.__init__(self) methodName = f'\\"{descriptor.interface.identifier.name}.{method.identifier.name}\\"' - def requiredArgCount(signature): + def requiredArgCount(signature: tuple[IDLType, list[IDLArgument]]) -> int: arguments = signature[1] if len(arguments) == 0: return 0 @@ -419,7 +435,7 @@ class CGMethodCall(CGThing): signatures = method.signatures() - def getPerSignatureCall(signature, argConversionStartsAt=0): + def getPerSignatureCall(signature: tuple[IDLType, list[IDLArgument]], argConversionStartsAt: int = 0) -> CGThing: signatureIndex = signatures.index(signature) return CGPerSignatureCall(signature[0], argsPre, signature[1], f"{nativeMethodName}{'_' * signatureIndex}", @@ -484,7 +500,7 @@ class CGMethodCall(CGThing): # Select the right overload from our set. distinguishingArg = f"HandleValue::from_raw(args.get({distinguishingIndex}))" - def pickFirstSignature(condition, filterLambda): + def pickFirstSignature(condition: str | None, filterLambda: Callable[[Any], bool]) -> bool: sigs = list(filter(filterLambda, possibleSignatures)) assert len(sigs) < 2 if len(sigs) > 0: @@ -613,25 +629,30 @@ class CGMethodCall(CGThing): self.cgRoot = CGWrapper(CGList(overloadCGThings, "\n"), pre="\n") - def define(self): + def define(self) -> str: return self.cgRoot.define() -def dictionaryHasSequenceMember(dictionary): +def dictionaryHasSequenceMember(dictionary: IDLDictionary) -> bool: return (any(typeIsSequenceOrHasSequenceMember(m.type) for m in dictionary.members) or (dictionary.parent + # pyrefly: ignore # bad-argument-type and dictionaryHasSequenceMember(dictionary.parent))) -def typeIsSequenceOrHasSequenceMember(type): +def typeIsSequenceOrHasSequenceMember(type: IDLType) -> bool: if type.nullable(): + assert isinstance(type, IDLNullableType) type = type.inner if type.isSequence(): return True if type.isDictionary(): + # pyrefly: ignore # missing-attribute return dictionaryHasSequenceMember(type.inner) if type.isUnion(): + assert isinstance(type, IDLUnionType) + assert type.flatMemberTypes is not None return any(typeIsSequenceOrHasSequenceMember(m.type) for m in type.flatMemberTypes) return False @@ -653,7 +674,7 @@ class JSToNativeConversionInfo(): """ An object representing information about a JS-to-native conversion. """ - def __init__(self, template, default=None, declType=None): + def __init__(self, template: str | CGThing, default: str | None = None, declType: CGThing | None = None) -> None: """ template: A string representing the conversion code. This will have template substitution performed on it as follows: @@ -673,17 +694,16 @@ class JSToNativeConversionInfo(): self.declType = declType -def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, +def getJSToNativeConversionInfo(type: IDLType, descriptorProvider: DescriptorProvider, failureCode: str | None = None, isDefinitelyObject: bool = False, isMember: bool | str = False, - isArgument=False, - isAutoRooted=False, - invalidEnumValueFatal=True, - defaultValue=None, - exceptionCode=None, + isArgument: bool = False, + isAutoRooted: bool = False, + invalidEnumValueFatal: bool = True, + defaultValue: IDLValue | None = None, + exceptionCode: str | None = None, allowTreatNonObjectAsNull: bool = False, - isCallbackReturnValue=False, - sourceDescription="value") -> JSToNativeConversionInfo: + sourceDescription: str = "value") -> JSToNativeConversionInfo: """ Get a template for converting a JS value to a native object based on the given type and descriptor. If failureCode is given, then we're actually @@ -748,13 +768,13 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, else: failOrPropagate = failureCode - def handleOptional(template, declType, default): + def handleOptional(template: str, declType: CGThing | None, default: str | None) -> JSToNativeConversionInfo: assert (defaultValue is None) == (default is None) return JSToNativeConversionInfo(template, default, declType) # Helper functions for dealing with failures due to the JS value being the # wrong type of value. - def onFailureNotAnObject(failureCode): + def onFailureNotAnObject(failureCode: str | None) -> CGThing: return CGWrapper( CGGeneric( failureCode @@ -762,14 +782,14 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, f'{exceptionCode}')), post="\n") - def onFailureNotCallable(failureCode): + def onFailureNotCallable(failureCode: str | None) -> CGThing: return CGGeneric( failureCode or (f'throw_type_error(*cx, \"{firstCap(sourceDescription)} is not callable.\");\n' f'{exceptionCode}')) # A helper function for handling default values. - def handleDefault(nullValue): + def handleDefault(nullValue: str) -> str | None: if defaultValue is None: return None @@ -787,8 +807,8 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, # A helper function for wrapping up the template body for # possibly-nullable objecty stuff - def wrapObjectTemplate(templateBody, nullValue, isDefinitelyObject, type, - failureCode=None): + def wrapObjectTemplate(templateBody: str, nullValue: str, isDefinitelyObject: bool, type: IDLType, + failureCode: str | None = None) -> str: if not isDefinitelyObject: # Handle the non-object cases by wrapping up the whole # thing in an if cascade. @@ -806,7 +826,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, return templateBody # A helper function for types that implement FromJSValConvertible trait - def fromJSValTemplate(config, errorHandler, exceptionCode): + def fromJSValTemplate(config: str, errorHandler: str, exceptionCode: str) -> str: return f"""match FromJSValConvertible::from_jsval(*cx, ${{val}}, {config}) {{ Ok(ConversionResult::Success(value)) => value, Ok(ConversionResult::Failure(error)) => {{ @@ -839,12 +859,16 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, declType = CGGeneric(union_native_type(type)) if type.nullable(): declType = CGWrapper(declType, pre="Option<", post=" >") + assert isinstance(type, (IDLUnionType, IDLNullableType)) templateBody = fromJSValTemplate("()", failOrPropagate, exceptionCode) + flatMemberTypes = type.unroll().flatMemberTypes + assert flatMemberTypes is not None + dictionaries = [ memberType - for memberType in type.unroll().flatMemberTypes + for memberType in flatMemberTypes if memberType.isDictionary() ] if (defaultValue @@ -911,8 +935,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, if type.isGeckoInterface(): assert not isEnforceRange and not isClamp - descriptor = descriptorProvider.getDescriptor( - type.unroll().inner.identifier.name) + descriptor = descriptorProvider.getDescriptor(type.unroll().inner.identifier.name) # pyrefly: ignore # missing-attribute if descriptor.interface.isCallback(): name = descriptor.nativeType @@ -1091,6 +1114,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, if type.nullable(): raise TypeError("We don't support nullable enumerated arguments " "yet") + # pyrefly: ignore # missing-attribute enum = type.inner.identifier.name if invalidEnumValueFatal: handleInvalidEnumValueCode = failureCode or f"throw_type_error(*cx, &error); {exceptionCode}" @@ -1113,6 +1137,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, assert not type.treatNonObjectAsNull() or type.nullable() assert not type.treatNonObjectAsNull() or not type.treatNonCallableAsNull() + # pyrefly: ignore # missing-attribute callback = type.unroll().callback declType = CGGeneric(f"{callback.identifier.name}") finalDeclType = CGTemplatedType("Rc", declType) @@ -1219,6 +1244,7 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, # There are no nullable dictionaries assert not type.nullable() or (isMember and isMember != "Dictionary") + # pyrefly: ignore # missing-attribute typeName = f"{CGDictionary.makeModuleName(type.inner)}::{CGDictionary.makeDictionaryName(type.inner)}" if containsDomInterface(type): typeName += "" @@ -1273,9 +1299,9 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None, return handleOptional(template, declType, defaultStr) -def instantiateJSToNativeConversionTemplate(templateBody, replacements, - declType, declName, - needsAutoRoot=False) -> CGThing: +def instantiateJSToNativeConversionTemplate(templateBody: str, replacements: dict[str, Any], + declType: CGThing | None, declName: str, + needsAutoRoot: bool = False) -> CGThing: """ Take the templateBody and declType as returned by getJSToNativeConversionInfo, a set of replacements as required by the @@ -1309,7 +1335,7 @@ def instantiateJSToNativeConversionTemplate(templateBody, replacements, return result -def convertConstIDLValueToJSVal(value): +def convertConstIDLValueToJSVal(value: IDLValue) -> str | None: if isinstance(value, IDLNullValue): return "ConstantVal::NullVal" tag = value.type.tag() @@ -1335,8 +1361,8 @@ class CGArgumentConverter(CGThing): unwrap the argument to the right native type. """ converter: CGThing - def __init__(self, argument, index, args, argc, descriptorProvider, - invalidEnumValueFatal=True): + def __init__(self, argument: IDLArgument, index: int, args: str, argc: str, descriptorProvider: DescriptorProvider, + invalidEnumValueFatal: bool=True) -> None: CGThing.__init__(self) assert not argument.defaultValue or argument.optional @@ -1407,7 +1433,7 @@ if {argc} > {index} {{ }} }}""") - def define(self): + def define(self) -> str: return self.converter.define() @@ -1427,34 +1453,41 @@ def wrapForType(jsvalRef: str, result: str = 'result', successCode: str = 'true' return wrap -def typeNeedsCx(type, retVal=False): +def typeNeedsCx(type: IDLType, retVal: bool = False) -> bool: if type is None: return False if type.nullable(): + assert isinstance(type, IDLNullableType) type = type.inner if type.isSequence(): + assert isinstance(type, IDLSequenceType) type = type.inner if type.isUnion(): - return any(typeNeedsCx(t) for t in type.unroll().flatMemberTypes) + assert isinstance(type, IDLUnionType) + flatMemberTypes = type.unroll().flatMemberTypes + assert flatMemberTypes is not None + + return any(typeNeedsCx(t) for t in flatMemberTypes) if retVal and type.isSpiderMonkeyInterface(): return True return type.isAny() or type.isObject() -def returnTypeNeedsOutparam(type): +def returnTypeNeedsOutparam(type: IDLType) -> bool: if type.nullable(): + assert isinstance(type, IDLNullableType) type = type.inner return type.isAny() -def outparamTypeFromReturnType(type): +def outparamTypeFromReturnType(type: IDLType) -> str: if type.isAny(): return "MutableHandleValue" raise TypeError(f"Don't know how to handle {type} as an outparam") # Returns a conversion behavior suitable for a type -def getConversionConfigForType(type, isEnforceRange, isClamp, treatNullAs): +def getConversionConfigForType(type: IDLType, isEnforceRange: bool, isClamp: bool, treatNullAs: str) -> str: if type.isSequence() or type.isRecord(): return getConversionConfigForType(innerContainerType(type), isEnforceRange, isClamp, treatNullAs) if type.isDOMString(): @@ -1483,7 +1516,7 @@ def getConversionConfigForType(type, isEnforceRange, isClamp, treatNullAs): return "()" -def builtin_return_type(returnType): +def builtin_return_type(returnType: IDLType) -> CGThing: result = CGGeneric(builtinNames[returnType.tag()]) if returnType.nullable(): result = CGWrapper(result, pre="Option<", post=">") @@ -1491,7 +1524,7 @@ def builtin_return_type(returnType): # Returns a CGThing containing the type of the return value. -def getRetvalDeclarationForType(returnType, descriptorProvider): +def getRetvalDeclarationForType(returnType: IDLType | None, descriptorProvider: DescriptorProvider) -> CGThing: if returnType is None or returnType.isUndefined(): # Nothing to declare return CGGeneric("()") @@ -1515,6 +1548,7 @@ def getRetvalDeclarationForType(returnType, descriptorProvider): result = CGWrapper(result, pre="Option<", post=">") return result if returnType.isEnum(): + # pyrefly: ignore # missing-attribute result = CGGeneric(returnType.unroll().inner.identifier.name) if returnType.nullable(): result = CGWrapper(result, pre="Option<", post=">") @@ -1524,12 +1558,14 @@ def getRetvalDeclarationForType(returnType, descriptorProvider): return CGGeneric("Rc") if returnType.isGeckoInterface(): descriptor = descriptorProvider.getDescriptor( + # pyrefly: ignore # missing-attribute returnType.unroll().inner.identifier.name) result = CGGeneric(descriptor.returnType) if returnType.nullable(): result = CGWrapper(result, pre="Option<", post=">") return result if returnType.isCallback(): + # pyrefly: ignore # missing-attribute callback = returnType.unroll().callback result = CGGeneric(f'Rc<{getModuleFromObject(callback)}::{callback.identifier.name}>') if returnType.nullable(): @@ -1555,6 +1591,7 @@ def getRetvalDeclarationForType(returnType, descriptorProvider): return result if returnType.isDictionary(): nullable = returnType.nullable() + # pyrefly: ignore # missing-attribute dictName = returnType.inner.name if nullable else returnType.name generic = "" if containsDomInterface(returnType) else "" result = CGGeneric(f"{dictName}{generic}") @@ -1567,7 +1604,7 @@ def getRetvalDeclarationForType(returnType, descriptorProvider): raise TypeError(f"Don't know how to declare return value for {returnType}") -def MemberCondition(pref, func, exposed, secure): +def MemberCondition(pref: str | None, func: str | None, exposed: set | None, secure: bool | None) -> list[str]: """ A string representing the condition for a member to actually be exposed. Any of the arguments can be None. If not None, they should have the @@ -1597,8 +1634,9 @@ def MemberCondition(pref, func, exposed, secure): conditions.append("Condition::Satisfied") return conditions +PropertyDefinerElementType = TypeVar('PropertyDefinerElementType') -class PropertyDefiner: +class PropertyDefiner(Generic[PropertyDefinerElementType]): """ A common superclass for defining things on prototype objects. @@ -1607,28 +1645,29 @@ class PropertyDefiner: things exposed to web pages. """ name: str - regular: list - def __init__(self, descriptor, name: str): + regular: list[PropertyDefinerElementType] + + def __init__(self, descriptor: Descriptor, name: str) -> None: self.descriptor = descriptor self.name = name - def variableName(self): + def variableName(self) -> str: return f"s{self.name}" - def length(self): + def length(self) -> int: return len(self.regular) @abstractmethod - def generateArray(self, array, name) -> str: + def generateArray(self, array: list[PropertyDefinerElementType], name: str) -> str: raise NotImplementedError - def __str__(self): + def __str__(self) -> str: # We only need to generate id arrays for things that will end # up used via ResolveProperty or EnumerateProperties. return self.generateArray(self.regular, self.variableName()) @staticmethod - def getStringAttr(member, name): + def getStringAttr(member: IDLInterfaceMember, name: str) -> str | None: attr = member.getExtendedAttribute(name) if attr is None: return None @@ -1638,7 +1677,7 @@ class PropertyDefiner: return attr[0] @staticmethod - def getControllingCondition(interfaceMember, descriptor): + def getControllingCondition(interfaceMember: IDLInterfaceMember, descriptor: Descriptor) -> list[str]: return MemberCondition( PropertyDefiner.getStringAttr(interfaceMember, "Pref"), @@ -1647,8 +1686,16 @@ class PropertyDefiner: interfaceMember.exposureSet, interfaceMember.getExtendedAttribute("SecureContext")) - def generateGuardedArray(self, array, name, specTemplate, specTerminator, - specType, getCondition, getDataTuple): + def generateGuardedArray( + self, + array: list[PropertyDefinerElementType], + name: str, + specTemplate: Callable[[PropertyDefinerElementType], str] | str, + specTerminator: str | None, + specType: str, + getCondition: Callable[[PropertyDefinerElementType, Descriptor], list[str]], + getDataTuple: Callable[[PropertyDefinerElementType], tuple[str, ...]] + ) -> str: """ This method generates our various arrays. @@ -1677,8 +1724,8 @@ class PropertyDefiner: specs = [] prefableSpecs = [] prefableTemplate = ' Guard::new(%s, (%s)[%d])' - origTemplate = specTemplate if isinstance(specTemplate, str): + origTemplate = specTemplate specTemplate = lambda _: origTemplate # noqa for cond, members in groupby(array, lambda m: getCondition(m, self.descriptor)): @@ -1710,8 +1757,16 @@ pub(crate) fn init_{name}_prefs() {{ return f"{specsArray}{initSpecs}{prefArray}{initPrefs}" - def generateUnguardedArray(self, array, name, specTemplate, specTerminator, - specType, getCondition, getDataTuple): + def generateUnguardedArray( + self, + array: list[PropertyDefinerElementType], + name: str, + specTemplate: Callable[[PropertyDefinerElementType], str] | str, + specTerminator: str, + specType: str, + getCondition: Callable[[PropertyDefinerElementType, Descriptor], list[str]], + getDataTuple: Callable[[PropertyDefinerElementType], tuple[str, ...]] + ) -> str: """ Takes the same set of parameters as generateGuardedArray but instead generates a single, flat array of type `&[specType]` that contains all @@ -1723,8 +1778,8 @@ pub(crate) fn init_{name}_prefs() {{ groups = groupby(array, lambda m: getCondition(m, self.descriptor)) assert len(list(groups)) == 1 - origTemplate = specTemplate if isinstance(specTemplate, str): + origTemplate = specTemplate specTemplate = lambda _: origTemplate # noqa specsArray = [specTemplate(m) % getDataTuple(m) for m in array] @@ -1741,7 +1796,7 @@ pub(crate) fn init_{name}() {{ # The length of a method is the minimum of the lengths of the # argument lists of all its overloads. -def methodLength(method): +def methodLength(method: IDLMethod) -> int: signatures = method.signatures() return min( len([arg for arg in arguments if not arg.optional and not arg.variadic]) @@ -1752,7 +1807,7 @@ class MethodDefiner(PropertyDefiner): """ A class for defining methods on a prototype object. """ - def __init__(self, descriptor, name, static, unforgeable, crossorigin=False): + def __init__(self, descriptor: Descriptor, name: str, static: bool, unforgeable: bool, crossorigin: bool = False) -> None: assert not (static and unforgeable) assert not (static and crossorigin) assert not (unforgeable and crossorigin) @@ -1774,7 +1829,7 @@ class MethodDefiner(PropertyDefiner): and (MemberIsLegacyUnforgeable(m, descriptor) == unforgeable or crossorigin)] else: methods = [] - self.regular = [] + self.regular: list[dict[str, Any]] = [] for m in methods: method = self.methodData(m, descriptor, crossorigin) @@ -1789,7 +1844,7 @@ class MethodDefiner(PropertyDefiner): # failing. Also, may be more tiebreak rules to implement once spec bug # is resolved. # https://www.w3.org/Bugs/Public/show_bug.cgi?id=28592 - def hasIterator(methods, regular): + def hasIterator(methods: list[IDLMethod], regular: list[dict[str, Any]]) -> bool: return (any("@@iterator" in m.aliases for m in methods) or any("@@iterator" == r["name"] for r in regular)) @@ -1800,8 +1855,8 @@ class MethodDefiner(PropertyDefiner): if (not static and not unforgeable and not crossorigin - and descriptor.supportsIndexedProperties()): # noqa - if hasIterator(methods, self.regular): # noqa + and descriptor.supportsIndexedProperties()): + if hasIterator(methods, self.regular): raise TypeError("Cannot have indexed getter/attr on " f"interface {self.descriptor.interface.identifier.name} with other members " "that generate @@iterator, such as " @@ -1874,7 +1929,7 @@ class MethodDefiner(PropertyDefiner): self.crossorigin = crossorigin @staticmethod - def methodData(m, descriptor, crossorigin): + def methodData(m: IDLMethod, descriptor: Descriptor, crossorigin: bool) -> dict[str, Any]: return { "name": m.identifier.name, "methodInfo": not m.isStatic(), @@ -1884,14 +1939,14 @@ class MethodDefiner(PropertyDefiner): "returnsPromise": m.returnsPromise() } - def generateArray(self, array, name): + def generateArray(self, array: list[dict[str, Any]], name: str) -> str: if len(array) == 0: return "" - def condition(m, d): + def condition(m: dict[str, Any], d: Descriptor) -> list[str]: return m["condition"] - def specData(m): + def specData(m: dict[str, Any]) -> tuple: flags = m["flags"] if self.unforgeable: flags += " | JSPROP_PERMANENT | JSPROP_READONLY" @@ -1963,7 +2018,7 @@ class MethodDefiner(PropertyDefiner): class AttrDefiner(PropertyDefiner): - def __init__(self, descriptor, name, static, unforgeable, crossorigin=False): + def __init__(self, descriptor: Descriptor, name: str, static: bool, unforgeable: bool, crossorigin: bool = False) -> None: assert not (static and unforgeable) assert not (static and crossorigin) assert not (unforgeable and crossorigin) @@ -1974,7 +2029,7 @@ class AttrDefiner(PropertyDefiner): self.name = name self.descriptor = descriptor - self.regular = [ + self.regular: list[dict[str, Any]] = [ { "name": m.identifier.name, "attr": m, @@ -2002,11 +2057,11 @@ class AttrDefiner(PropertyDefiner): "kind": "JSPropertySpec_Kind::Value", }) - def generateArray(self, array, name): + def generateArray(self, array: list[dict[str, Any]], name: str) -> str: if len(array) == 0: return "" - def getter(attr): + def getter(attr: dict[str, Any]) -> str: attr = attr['attr'] if self.crossorigin and not attr.getExtendedAttribute("CrossOriginReadable"): @@ -2029,7 +2084,7 @@ class AttrDefiner(PropertyDefiner): return f"JSNativeWrapper {{ op: Some({accessor}), info: {jitinfo} }}" - def setter(attr): + def setter(attr: dict[str, Any]) -> str : attr = attr['attr'] if ((self.crossorigin and not attr.getExtendedAttribute("CrossOriginWritable")) @@ -2051,12 +2106,12 @@ class AttrDefiner(PropertyDefiner): return f"JSNativeWrapper {{ op: Some({accessor}), info: {jitinfo} }}" - def condition(m, d): + def condition(m: dict[str, Any], d: Descriptor) -> list[str]: if m["name"] == "@@toStringTag": return MemberCondition(pref=None, func=None, exposed=None, secure=None) return PropertyDefiner.getControllingCondition(m["attr"], d) - def specData(attr): + def specData(attr: dict[str, Any]) -> tuple: if attr["name"] == "@@toStringTag": return (attr["name"][2:], attr["flags"], attr["kind"], str_to_cstr_ptr(self.descriptor.interface.getClassName())) @@ -2067,7 +2122,7 @@ class AttrDefiner(PropertyDefiner): return (str_to_cstr_ptr(attr["attr"].identifier.name), flags, attr["kind"], getter(attr), setter(attr)) - def template(m): + def template(m: dict[str, Any]) -> str: if m["name"] == "@@toStringTag": return """ JSPropertySpec { name: JSPropertySpec_Name { symbol_: SymbolCode::%s as usize + 1 }, @@ -2116,25 +2171,26 @@ class AttrDefiner(PropertyDefiner): condition, specData) -class ConstDefiner(PropertyDefiner): +class ConstDefiner(PropertyDefiner[IDLConst]): """ A class for definining constants on the interface object """ - def __init__(self, descriptor, name): + def __init__(self, descriptor: Descriptor, name: str) -> None: PropertyDefiner.__init__(self, descriptor, name) self.name = name - self.regular = [m for m in descriptor.interface.members if m.isConst()] + self.regular: list[IDLConst] = [m for m in descriptor.interface.members if m.isConst()] - def generateArray(self, array, name): + def generateArray(self, array: list[IDLConst], name: str) -> str: if len(array) == 0: return "" - def specData(const): + def specData(const: IDLConst) -> tuple: return (str_to_cstr(const.identifier.name), convertConstIDLValueToJSVal(const.value)) return self.generateGuardedArray( - array, name, + array, + name, ' ConstantSpec { name: %s, value: %s }', None, 'ConstantSpec', @@ -2380,7 +2436,7 @@ class CGNamespace(CGWrapper): return CGNamespace(namespaces[0], inner, public=public) -def DOMClassTypeId(desc): +def DOMClassTypeId(desc: Descriptor) -> str: protochain = desc.prototypeChain inner = "" if desc.hasDescendants(): @@ -2396,7 +2452,7 @@ def DOMClassTypeId(desc): return f"crate::codegen::InheritTypes::TopTypeId {{ {protochain[0].lower()}: {inner} }}" -def DOMClass(descriptor): +def DOMClass(descriptor: Descriptor) -> str: protoList = [f'PrototypeList::ID::{proto}' for proto in descriptor.prototypeChain] # Pad out the list to the right length with ID::Last so we # guarantee that all the lists are the same length. ID::Last @@ -2540,11 +2596,11 @@ impl {args['selfName']} {{ """ -def str_to_cstr(s): +def str_to_cstr(s: str) -> str: return f'c"{s}"' -def str_to_cstr_ptr(s): +def str_to_cstr_ptr(s: str) -> str: return f'c"{s}".as_ptr()' @@ -2664,7 +2720,7 @@ class CGGeneric(CGThing): class CGCallbackTempRoot(CGGeneric): - def __init__(self, name): + def __init__(self, name: str) -> None: CGGeneric.__init__(self, f"{name.replace('', '::')}::new(cx, ${{val}}.get().to_object())") @@ -2753,7 +2809,13 @@ def UnionTypes( typedefs=[], imports=imports, config=config) -def DomTypes(descriptors, descriptorProvider, dictionaries, callbacks, typedefs, config): +def DomTypes(descriptors: list[Descriptor], + descriptorProvider: DescriptorProvider, + dictionaries: IDLDictionary, + callbacks: IDLCallback, + typedefs: IDLTypedef, + config: Configuration + ) -> CGThing: traits = [ "crate::interfaces::DomHelpers", "js::rust::Trace", @@ -2763,7 +2825,7 @@ def DomTypes(descriptors, descriptorProvider, dictionaries, callbacks, typedefs, joinedTraits = ' + '.join(traits) elements = [CGGeneric(f"pub trait DomTypes: {joinedTraits} where Self: 'static {{\n")] - def fixupInterfaceTypeReferences(typename): + def fixupInterfaceTypeReferences(typename: str) -> str: return typename.replace("D::", "Self::") for descriptor in descriptors: @@ -2864,7 +2926,13 @@ def DomTypes(descriptors, descriptorProvider, dictionaries, callbacks, typedefs, return CGList(imports + elements) -def DomTypeHolder(descriptors, descriptorProvider, dictionaries, callbacks, typedefs, config): +def DomTypeHolder(descriptors: list[Descriptor], + descriptorProvider: DescriptorProvider, + dictionaries: IDLDictionary, + callbacks: IDLCallback, + typedefs: IDLTypedef, + config: Configuration + ) -> CGThing: elements = [ CGGeneric( "#[derive(JSTraceable, MallocSizeOf, PartialEq)]\n" @@ -3075,7 +3143,7 @@ unsafe { return CGList((CGGeneric(cond) for cond in conditions), " &&\n") -def InitLegacyUnforgeablePropertiesOnHolder(descriptor, properties): +def InitLegacyUnforgeablePropertiesOnHolder(descriptor: Descriptor, properties: PropertyArrays) -> CGThing: """ Define the unforgeable properties on the unforgeable holder for the interface represented by descriptor. @@ -3097,7 +3165,7 @@ def InitLegacyUnforgeablePropertiesOnHolder(descriptor, properties): return CGList(unforgeables, "\n") -def CopyLegacyUnforgeablePropertiesToInstance(descriptor): +def CopyLegacyUnforgeablePropertiesToInstance(descriptor: Descriptor) -> str: """ Copy the unforgeable properties from the unforgeable holder for this interface to the instance object we have. @@ -3292,7 +3360,7 @@ DomRoot::from_ref(&*root)\ """) -def toBindingPath(descriptor): +def toBindingPath(descriptor: Descriptor) -> str: module = toBindingModuleFileFromDescriptor(descriptor) namespace = toBindingNamespace(descriptor.interface.identifier.name) return f"{module}::{namespace}" @@ -3958,7 +4026,7 @@ class CGDefineDOMInterfaceMethod(CGAbstractMethod): ) -def needCx(returnType, arguments, considerTypes): +def needCx(returnType: IDLType, arguments, considerTypes: bool) -> bool: return (considerTypes and (typeNeedsCx(returnType, True) or any(typeNeedsCx(a.type) for a in arguments))) @@ -4278,7 +4346,7 @@ let global = D::GlobalScope::from_object(args.callee()); raise NotImplementedError -def GetConstructorNameForReporting(descriptor, ctor): +def GetConstructorNameForReporting(descriptor: Descriptor, ctor: IDLInterfaceOrNamespace) -> str: # Figure out the name of our constructor for reporting purposes. # For unnamed webidl constructors, identifier.name is "constructor" but # the name JS sees is the interface name; for legacy factory functions @@ -5015,7 +5083,7 @@ class CGStaticMethodJitinfo(CGGeneric): ) -def getEnumValueName(value): +def getEnumValueName(value: str) -> str: # Some enum values can be empty strings. Others might have weird # characters in them. Deal with the former by returning "_empty", # deal with possible name collisions from that by throwing if the @@ -5127,7 +5195,7 @@ impl FromJSValConvertible for super::{ident} {{ return self.cgRoot.define() -def convertConstIDLValueToRust(value): +def convertConstIDLValueToRust(value: IDLConst) -> str: tag = value.type.tag() if tag in [IDLType.Tags.int8, IDLType.Tags.uint8, IDLType.Tags.int16, IDLType.Tags.uint16, @@ -5164,11 +5232,13 @@ class CGConstant(CGThing): return f"pub const {name}: {const_type} = {value};\n" -def getUnionTypeTemplateVars(type, descriptorProvider: DescriptorProvider): +def getUnionTypeTemplateVars(type: IDLType, descriptorProvider: DescriptorProvider) -> dict[str, Any]: if type.isGeckoInterface(): + # pyrefly: ignore # missing-attribute name = type.inner.identifier.name typeName = descriptorProvider.getDescriptor(name).returnType elif type.isEnum(): + # pyrefly: ignore # missing-attribute name = type.inner.identifier.name typeName = name elif type.isDictionary(): @@ -5229,7 +5299,7 @@ def getUnionTypeTemplateVars(type, descriptorProvider: DescriptorProvider): } -def traitRequiresManualImpl(name, ty): +def traitRequiresManualImpl(name: str, ty: IDLObject) -> bool: return name == "Clone" and containsDomInterface(ty) @@ -5855,6 +5925,7 @@ class CGProxySpecialOperation(CGPerSignatureCall): Base class for classes for calling an indexed or named special operation (don't use this directly, use the derived classes below). """ + templateValues: dict[str, Any] | None def __init__(self, descriptor, operation): nativeName = MakeNativeName(descriptor.binaryNameFor(operation, False)) operation = descriptor.operations[operation] @@ -5892,11 +5963,9 @@ class CGProxySpecialOperation(CGPerSignatureCall): return args def wrap_return_value(self): - # pyrefly: ignore # missing-attribute if not self.idlNode.isGetter() or self.templateValues is None: return "" - # pyrefly: ignore # missing-argument, bad-unpacking wrap = CGGeneric(wrapForType(**self.templateValues)) wrap = CGIfWrapper("let Some(result) = result", wrap) return f"\n{wrap.define()}" @@ -6583,7 +6652,7 @@ let this = native_from_object_static::<{self.descriptor.concreteType}>(obj).unwr raise NotImplementedError -def finalizeHook(descriptor, hookName, context): +def finalizeHook(descriptor: Descriptor, hookName: str, context: str) -> str: if descriptor.isGlobal(): release = "finalize_global(obj, this);" elif descriptor.weakReferenceable: @@ -6709,7 +6778,7 @@ pub(crate) fn init_proxy_handler_dom_class() {{ class CGInterfaceTrait(CGThing): - def __init__(self, descriptor, descriptorProvider): + def __init__(self, descriptor: Descriptor, descriptorProvider: DescriptorProvider): CGThing.__init__(self) def attribute_arguments(attribute_type, argument=None, inRealm=False, canGc=False, retval=False): @@ -6737,6 +6806,8 @@ class CGInterfaceTrait(CGThing): name = CGSpecializedMethod.makeNativeName(descriptor, m) infallible = 'infallible' in descriptor.getExtendedAttributes(m) for idx, (rettype, arguments) in enumerate(m.signatures()): + rettype = cast(IDLType, rettype) + arguments = cast(list[IDLArgument], arguments) arguments = method_arguments(descriptor, rettype, arguments, inRealm=name in descriptor.inRealmMethods, canGc=name in descriptor.canGcMethods) @@ -7870,7 +7941,7 @@ class CGBindingRoot(CGThing): return stripTrailingWhitespace(self.root.define()) -def type_needs_tracing(t: IDLObject): +def type_needs_tracing(t: IDLObject) -> bool: assert isinstance(t, IDLObject), (t, type(t)) if t.isType(): @@ -7918,15 +7989,16 @@ def type_needs_tracing(t: IDLObject): return False assert False, (t, type(t)) + return False -def is_typed_array(t: IDLType): +def is_typed_array(t: IDLType) -> bool: assert isinstance(t, IDLObject), (t, type(t)) return t.isTypedArray() or t.isArrayBuffer() or t.isArrayBufferView() -def type_needs_auto_root(t: IDLType): +def type_needs_auto_root(t: IDLType) -> bool: """ Certain IDL types, such as `sequence` or `sequence` need to be traced and wrapped via (Custom)AutoRooter @@ -7943,7 +8015,7 @@ def type_needs_auto_root(t: IDLType): return False -def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, variadic=False): +def argument_type(descriptorProvider: DescriptorProvider, ty: IDLType, optional: bool = False, defaultValue=None, variadic: bool = False) -> str: info = getJSToNativeConversionInfo( ty, descriptorProvider, isArgument=True, isAutoRooted=type_needs_auto_root(ty)) @@ -7966,8 +8038,14 @@ def argument_type(descriptorProvider, ty, optional=False, defaultValue=None, var return declType.define() -def method_arguments(descriptorProvider, returnType, arguments, passJSBits=True, trailing=None, - inRealm=False, canGc=False): +def method_arguments(descriptorProvider: DescriptorProvider, + returnType: IDLType, + arguments: list[IDLArgument], + passJSBits: bool = True, + trailing: tuple[str, str] | None = None, + inRealm: bool = False, + canGc: bool = False + ) -> Iterator[tuple[str, str]]: if needCx(returnType, arguments, passJSBits): yield "cx", "SafeJSContext" @@ -7989,7 +8067,7 @@ def method_arguments(descriptorProvider, returnType, arguments, passJSBits=True, yield "rval", outparamTypeFromReturnType(returnType), -def return_type(descriptorProvider, rettype, infallible): +def return_type(descriptorProvider: DescriptorProvider, rettype: IDLType, infallible: bool) -> str: result = getRetvalDeclarationForType(rettype, descriptorProvider) if rettype and returnTypeNeedsOutparam(rettype): result = CGGeneric("()") @@ -8130,11 +8208,11 @@ class CGCallback(CGClass): # We're always fallible -def callbackGetterName(attr, descriptor): +def callbackGetterName(attr: IDLAttribute, descriptor: Descriptor) -> str: return f"Get{MakeNativeName(descriptor.binaryNameFor(attr.identifier.name, attr.isStatic()))}" -def callbackSetterName(attr, descriptor): +def callbackSetterName(attr: IDLAttribute, descriptor: Descriptor) -> str: return f"Set{MakeNativeName(descriptor.binaryNameFor(attr.identifier.name, attr.isStatic()))}" @@ -8185,19 +8263,19 @@ class CGCallbackInterface(CGCallback): class FakeMember(): - def __init__(self): + def __init__(self) -> None: pass - def isStatic(self): + def isStatic(self) -> bool: return False - def isAttr(self): + def isAttr(self) -> bool: return False - def isMethod(self): + def isMethod(self) -> bool: return False - def getExtendedAttribute(self, name): + def getExtendedAttribute(self, name: str) -> None: return None @@ -8275,7 +8353,6 @@ class CallbackMember(CGNativeMember): self.retvalType, self.descriptorProvider, exceptionCode=self.exceptionCode, - isCallbackReturnValue="Callback", # XXXbz we should try to do better here sourceDescription="return value") template = info.template @@ -8628,11 +8705,11 @@ class CGIterableMethodGenerator(CGGeneric): itrMethod=methodName.title())) -def camel_to_upper_snake(s): +def camel_to_upper_snake(s: str) -> str: return "_".join(m.group(0).upper() for m in re.finditer("[A-Z][a-z]*", s)) -def process_arg(expr, arg): +def process_arg(expr: str, arg: IDLArgument) -> str: if arg.type.isGeckoInterface() and not arg.type.unroll().inner.isCallback(): if arg.variadic or arg.type.isSequence(): expr += ".r()"