Be more conservative about safety of dictionary and union values.

Mark dictionaries containing GC values as must_root, and wrap them in
RootedTraceableBox in automatically-generated APIs. To accommodate
union variants that are now flagged as unsafe, add RootedTraceableBox
to union variants that need to be rooted, rather than wrapping the
entire union value.
This commit is contained in:
Josh Matthews 2017-05-26 12:29:31 -04:00
parent e481e8934a
commit da65698c5c
5 changed files with 47 additions and 17 deletions

View file

@ -723,9 +723,6 @@ def getJSToNativeConversionInfo(type, descriptorProvider, failureCode=None,
if type.nullable(): if type.nullable():
declType = CGWrapper(declType, pre="Option<", post=" >") declType = CGWrapper(declType, pre="Option<", post=" >")
if isMember != "Dictionary" and type_needs_tracing(type):
declType = CGTemplatedType("RootedTraceableBox", declType)
templateBody = ("match FromJSValConvertible::from_jsval(cx, ${val}, ()) {\n" templateBody = ("match FromJSValConvertible::from_jsval(cx, ${val}, ()) {\n"
" Ok(ConversionResult::Success(value)) => value,\n" " Ok(ConversionResult::Success(value)) => value,\n"
" Ok(ConversionResult::Failure(error)) => {\n" " Ok(ConversionResult::Failure(error)) => {\n"
@ -1427,6 +1424,8 @@ def getRetvalDeclarationForType(returnType, descriptorProvider):
nullable = returnType.nullable() nullable = returnType.nullable()
dictName = returnType.inner.name if nullable else returnType.name dictName = returnType.inner.name if nullable else returnType.name
result = CGGeneric(dictName) result = CGGeneric(dictName)
if type_needs_tracing(returnType):
result = CGWrapper(result, pre="RootedTraceableBox<", post=">")
if nullable: if nullable:
result = CGWrapper(result, pre="Option<", post=">") result = CGWrapper(result, pre="Option<", post=">")
return result return result
@ -2262,6 +2261,7 @@ def UnionTypes(descriptors, dictionaries, callbacks, typedefs, config):
'dom::bindings::str::ByteString', 'dom::bindings::str::ByteString',
'dom::bindings::str::DOMString', 'dom::bindings::str::DOMString',
'dom::bindings::str::USVString', 'dom::bindings::str::USVString',
'dom::bindings::trace::RootedTraceableBox',
'dom::types::*', 'dom::types::*',
'js::error::throw_type_error', 'js::error::throw_type_error',
'js::jsapi::HandleValue', 'js::jsapi::HandleValue',
@ -4132,15 +4132,23 @@ class CGUnionStruct(CGThing):
self.type = type self.type = type
self.descriptorProvider = descriptorProvider self.descriptorProvider = descriptorProvider
def membersNeedTracing(self):
for t in self.type.flatMemberTypes:
if type_needs_tracing(t):
return True
return False
def define(self): def define(self):
templateVars = map(lambda t: getUnionTypeTemplateVars(t, self.descriptorProvider), templateVars = map(lambda t: (getUnionTypeTemplateVars(t, self.descriptorProvider),
type_needs_tracing(t)),
self.type.flatMemberTypes) self.type.flatMemberTypes)
enumValues = [ enumValues = [
" %s(%s)," % (v["name"], v["typeName"]) for v in templateVars " %s(%s)," % (v["name"], "RootedTraceableBox<%s>" % v["typeName"] if trace else v["typeName"])
for (v, trace) in templateVars
] ]
enumConversions = [ enumConversions = [
" %s::%s(ref inner) => inner.to_jsval(cx, rval)," " %s::%s(ref inner) => inner.to_jsval(cx, rval),"
% (self.type, v["name"]) for v in templateVars % (self.type, v["name"]) for (v, _) in templateVars
] ]
return ("""\ return ("""\
#[derive(JSTraceable)] #[derive(JSTraceable)]
@ -4167,6 +4175,12 @@ class CGUnionConversionStruct(CGThing):
self.type = type self.type = type
self.descriptorProvider = descriptorProvider self.descriptorProvider = descriptorProvider
def membersNeedTracing(self):
for t in self.type.flatMemberTypes:
if type_needs_tracing(t):
return True
return False
def from_jsval(self): def from_jsval(self):
memberTypes = self.type.flatMemberTypes memberTypes = self.type.flatMemberTypes
names = [] names = []
@ -4310,7 +4324,10 @@ class CGUnionConversionStruct(CGThing):
def try_method(self, t): def try_method(self, t):
templateVars = getUnionTypeTemplateVars(t, self.descriptorProvider) templateVars = getUnionTypeTemplateVars(t, self.descriptorProvider)
returnType = "Result<Option<%s>, ()>" % templateVars["typeName"] actualType = templateVars["typeName"]
if type_needs_tracing(t):
actualType = "RootedTraceableBox<%s>" % actualType
returnType = "Result<Option<%s>, ()>" % actualType
jsConversion = templateVars["jsConversion"] jsConversion = templateVars["jsConversion"]
# Any code to convert to Object is unused, since we're already converting # Any code to convert to Object is unused, since we're already converting
@ -6022,13 +6039,17 @@ class CGDictionary(CGThing):
(self.makeMemberName(m[0].identifier.name), self.getMemberType(m)) (self.makeMemberName(m[0].identifier.name), self.getMemberType(m))
for m in self.memberInfo] for m in self.memberInfo]
mustRoot = "#[must_root]\n" if self.membersNeedTracing() else ""
return (string.Template( return (string.Template(
"#[derive(JSTraceable)]\n" "#[derive(JSTraceable)]\n"
"${mustRoot}" +
"pub struct ${selfName} {\n" + "pub struct ${selfName} {\n" +
"${inheritance}" + "${inheritance}" +
"\n".join(memberDecls) + "\n" + "\n".join(memberDecls) + "\n" +
"}").substitute({"selfName": self.makeClassName(d), "}").substitute({"selfName": self.makeClassName(d),
"inheritance": inheritance})) "inheritance": inheritance,
"mustRoot": mustRoot}))
def impl(self): def impl(self):
d = self.dictionary d = self.dictionary
@ -6120,6 +6141,12 @@ class CGDictionary(CGThing):
"insertMembers": CGIndenter(memberInserts, indentLevel=8).define(), "insertMembers": CGIndenter(memberInserts, indentLevel=8).define(),
}) })
def membersNeedTracing(self):
for member, _ in self.memberInfo:
if type_needs_tracing(member.type):
return True
return False
@staticmethod @staticmethod
def makeDictionaryName(dictionary): def makeDictionaryName(dictionary):
return dictionary.identifier.name return dictionary.identifier.name

View file

@ -12,7 +12,7 @@ use dom::bindings::codegen::Bindings::IterableIteratorBinding::IterableKeyOrValu
use dom::bindings::error::Fallible; use dom::bindings::error::Fallible;
use dom::bindings::js::{JS, Root}; use dom::bindings::js::{JS, Root};
use dom::bindings::reflector::{DomObject, Reflector, reflect_dom_object}; use dom::bindings::reflector::{DomObject, Reflector, reflect_dom_object};
use dom::bindings::trace::JSTraceable; use dom::bindings::trace::{JSTraceable, RootedTraceableBox};
use dom::globalscope::GlobalScope; use dom::globalscope::GlobalScope;
use dom_struct::dom_struct; use dom_struct::dom_struct;
use js::conversions::ToJSValConvertible; use js::conversions::ToJSValConvertible;
@ -115,7 +115,7 @@ fn dict_return(cx: *mut JSContext,
result: MutableHandleObject, result: MutableHandleObject,
done: bool, done: bool,
value: HandleValue) -> Fallible<()> { value: HandleValue) -> Fallible<()> {
let mut dict = unsafe { IterableKeyOrValueResult::empty(cx) }; let mut dict = RootedTraceableBox::new(unsafe { IterableKeyOrValueResult::empty(cx) });
dict.done = done; dict.done = done;
dict.value.set(value.get()); dict.value.set(value.get());
rooted!(in(cx) let mut dict_value = UndefinedValue()); rooted!(in(cx) let mut dict_value = UndefinedValue());
@ -130,7 +130,7 @@ fn key_and_value_return(cx: *mut JSContext,
result: MutableHandleObject, result: MutableHandleObject,
key: HandleValue, key: HandleValue,
value: HandleValue) -> Fallible<()> { value: HandleValue) -> Fallible<()> {
let mut dict = unsafe { IterableKeyAndValueResult::empty(cx) }; let mut dict = RootedTraceableBox::new(unsafe { IterableKeyAndValueResult::empty(cx) });
dict.done = false; dict.done = false;
dict.value = Some(vec![Heap::new(key.get()), Heap::new(value.get())]); dict.value = Some(vec![Heap::new(key.get()), Heap::new(value.get())]);
rooted!(in(cx) let mut dict_value = UndefinedValue()); rooted!(in(cx) let mut dict_value = UndefinedValue());

View file

@ -745,6 +745,7 @@ impl<'a, T: JSTraceable + 'static> Drop for RootedTraceable<'a, T> {
/// If you have GC things like *mut JSObject or JSVal, use rooted!. /// If you have GC things like *mut JSObject or JSVal, use rooted!.
/// If you have an arbitrary number of DomObjects to root, use rooted_vec!. /// If you have an arbitrary number of DomObjects to root, use rooted_vec!.
/// If you know what you're doing, use this. /// If you know what you're doing, use this.
#[allow_unrooted_interior]
pub struct RootedTraceableBox<T: 'static + JSTraceable> { pub struct RootedTraceableBox<T: 'static + JSTraceable> {
ptr: *mut T, ptr: *mut T,
} }

View file

@ -13,6 +13,7 @@ use dom::bindings::js::{MutNullableJS, Root};
use dom::bindings::refcounted::Trusted; use dom::bindings::refcounted::Trusted;
use dom::bindings::reflector::{DomObject, reflect_dom_object}; use dom::bindings::reflector::{DomObject, reflect_dom_object};
use dom::bindings::str::DOMString; use dom::bindings::str::DOMString;
use dom::bindings::trace::RootedTraceableBox;
use dom::blob::Blob; use dom::blob::Blob;
use dom::domexception::{DOMErrorName, DOMException}; use dom::domexception::{DOMErrorName, DOMException};
use dom::event::{Event, EventBubbles, EventCancelable}; use dom::event::{Event, EventBubbles, EventCancelable};
@ -338,7 +339,8 @@ impl FileReaderMethods for FileReader {
FileReaderResult::String(ref string) => FileReaderResult::String(ref string) =>
StringOrObject::String(string.clone()), StringOrObject::String(string.clone()),
FileReaderResult::ArrayBuffer(ref arr_buffer) => { FileReaderResult::ArrayBuffer(ref arr_buffer) => {
StringOrObject::Object(Heap::new((*arr_buffer.ptr.get()).to_object())) StringOrObject::Object(RootedTraceableBox::new(
Heap::new((*arr_buffer.ptr.get()).to_object())))
} }
}) })
} }

View file

@ -338,8 +338,8 @@ impl TestBindingMethods for TestBinding {
Some(ByteStringOrLong::ByteString(ByteString::new(vec!()))) Some(ByteStringOrLong::ByteString(ByteString::new(vec!())))
} }
fn ReceiveNullableSequence(&self) -> Option<Vec<i32>> { Some(vec![1]) } fn ReceiveNullableSequence(&self) -> Option<Vec<i32>> { Some(vec![1]) }
fn ReceiveTestDictionaryWithSuccessOnKeyword(&self) -> TestDictionary { fn ReceiveTestDictionaryWithSuccessOnKeyword(&self) -> RootedTraceableBox<TestDictionary> {
TestDictionary { RootedTraceableBox::new(TestDictionary {
anyValue: Heap::new(NullValue()), anyValue: Heap::new(NullValue()),
booleanValue: None, booleanValue: None,
byteValue: None, byteValue: None,
@ -401,7 +401,7 @@ impl TestBindingMethods for TestBinding {
usvstringValue: None, usvstringValue: None,
nonRequiredNullable: None, nonRequiredNullable: None,
nonRequiredNullable2: Some(None), // null nonRequiredNullable2: Some(None), // null
} })
} }
fn DictMatchesPassedValues(&self, arg: RootedTraceableBox<TestDictionary>) -> bool { fn DictMatchesPassedValues(&self, arg: RootedTraceableBox<TestDictionary>) -> bool {
@ -436,9 +436,9 @@ impl TestBindingMethods for TestBinding {
fn PassUnion6(&self, _: UnsignedLongOrBoolean) {} fn PassUnion6(&self, _: UnsignedLongOrBoolean) {}
fn PassUnion7(&self, _: StringSequenceOrUnsignedLong) {} fn PassUnion7(&self, _: StringSequenceOrUnsignedLong) {}
fn PassUnion8(&self, _: ByteStringSequenceOrLong) {} fn PassUnion8(&self, _: ByteStringSequenceOrLong) {}
fn PassUnion9(&self, _: RootedTraceableBox<UnionTypes::TestDictionaryOrLong>) {} fn PassUnion9(&self, _: UnionTypes::TestDictionaryOrLong) {}
#[allow(unsafe_code)] #[allow(unsafe_code)]
unsafe fn PassUnion10(&self, _: *mut JSContext, _: RootedTraceableBox<UnionTypes::StringOrObject>) {} unsafe fn PassUnion10(&self, _: *mut JSContext, _: UnionTypes::StringOrObject) {}
fn PassUnionWithTypedef(&self, _: DocumentOrTestTypedef) {} fn PassUnionWithTypedef(&self, _: DocumentOrTestTypedef) {}
fn PassUnionWithTypedef2(&self, _: LongSequenceOrTestTypedef) {} fn PassUnionWithTypedef2(&self, _: LongSequenceOrTestTypedef) {}
#[allow(unsafe_code)] #[allow(unsafe_code)]