diff --git a/components/shared/base/generic_channel.rs b/components/shared/base/generic_channel.rs index cd6852fdb81..ce336f0d174 100644 --- a/components/shared/base/generic_channel.rs +++ b/components/shared/base/generic_channel.rs @@ -6,17 +6,25 @@ use std::fmt; use std::fmt::Display; +use std::marker::PhantomData; use ipc_channel::ipc::IpcError; use ipc_channel::router::ROUTER; use malloc_size_of::{MallocSizeOf, MallocSizeOfOps}; +use serde::de::VariantAccess; use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use servo_config::opts; -static GENERIC_CHANNEL_USAGE_ERROR_PANIC_MSG: &str = "May not send a crossbeam channel over an IPC channel. \ - Please also convert the ipc-channel you want to send this GenericReceiver over \ - into a GenericChannel."; +/// A GenericSender that sends messages to a [GenericReceiver]. +/// +/// The sender supports sending messages cross-process, if servo is run in multiprocess mode. +pub struct GenericSender(GenericSenderVariants); -pub enum GenericSender { +/// The actual GenericSender variant. +/// +/// This enum is private, so that outside code can't construct a GenericSender itself. +/// This ensures that users can't construct a crossbeam variant in multiprocess mode. +enum GenericSenderVariants { Ipc(ipc_channel::ipc::IpcSender), /// A crossbeam-channel. To keep the API in sync with the Ipc variant when using a Router, /// which propagates the IPC error, the inner type is a Result. @@ -30,20 +38,91 @@ pub enum GenericSender { impl Serialize for GenericSender { fn serialize(&self, s: S) -> Result { - match self { - GenericSender::Ipc(i) => i.serialize(s), - GenericSender::Crossbeam(_) => panic!("{GENERIC_CHANNEL_USAGE_ERROR_PANIC_MSG}"), + match &self.0 { + GenericSenderVariants::Ipc(sender) => { + s.serialize_newtype_variant("GenericSender", 0, "Ipc", sender) + }, + // All GenericSenders will be IPC channels in multi-process mode, so sending a + // GenericChannel over existing IPC channels is no problem and won't fail. + // In single-process mode, we can also send GenericSenders over other GenericSenders + // just fine, since no serialization is required. + // The only reason we need / want serialization is to support sending GenericSenders + // over existing IPC channels **in single process mode**. This allows us to + // incrementally port channels to the GenericChannel, without needing to follow a + // top-to-bottom approach. + // Long-term we can remove this branch in the code again and replace it with + // unreachable, since likely all IPC channels would be GenericChannels. + GenericSenderVariants::Crossbeam(sender) => { + if opts::get().multiprocess { + return Err(serde::ser::Error::custom( + "Crossbeam channel found in multiprocess mode!", + )); + } // We know everything is in one address-space, so we can "serialize" the sender by + // sending a leaked Box pointer. + let sender_clone_addr = Box::leak(Box::new(sender.clone())) as *mut _ as usize; + s.serialize_newtype_variant("GenericSender", 1, "Crossbeam", &sender_clone_addr) + }, } } } -impl<'a, T: Serialize> Deserialize<'a> for GenericSender { +struct GenericSenderVisitor { + marker: PhantomData, +} + +impl<'de, T: Serialize + Deserialize<'de>> serde::de::Visitor<'de> for GenericSenderVisitor { + type Value = GenericSender; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a GenericSender variant") + } + + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + #[derive(Deserialize)] + enum GenericSenderVariantNames { + Ipc, + Crossbeam, + } + + let (variant_name, variant_data): (GenericSenderVariantNames, _) = data.variant()?; + + match variant_name { + GenericSenderVariantNames::Ipc => variant_data + .newtype_variant::>() + .map(|sender| GenericSender(GenericSenderVariants::Ipc(sender))), + GenericSenderVariantNames::Crossbeam => { + if opts::get().multiprocess { + return Err(serde::de::Error::custom( + "Crossbeam channel found in multiprocess mode!", + )); + } + let addr = variant_data.newtype_variant::()?; + let ptr = addr as *mut crossbeam_channel::Sender>; + // SAFETY: We know we are in the same address space as the sender, so we can safely + // reconstruct the Box. + #[allow(unsafe_code)] + let sender = unsafe { Box::from_raw(ptr) }; + Ok(GenericSender(GenericSenderVariants::Crossbeam(*sender))) + }, + } + } +} + +impl<'a, T: Serialize + Deserialize<'a>> Deserialize<'a> for GenericSender { fn deserialize(d: D) -> Result, D::Error> where D: Deserializer<'a>, { - // Only ipc_channel will encounter deserialize scenario. - ipc_channel::ipc::IpcSender::::deserialize(d).map(GenericSender::Ipc) + d.deserialize_enum( + "GenericSender", + &["Ipc", "Crossbeam"], + GenericSenderVisitor { + marker: PhantomData, + }, + ) } } @@ -52,9 +131,13 @@ where T: Serialize, { fn clone(&self) -> Self { - match *self { - GenericSender::Ipc(ref chan) => GenericSender::Ipc(chan.clone()), - GenericSender::Crossbeam(ref chan) => GenericSender::Crossbeam(chan.clone()), + match self.0 { + GenericSenderVariants::Ipc(ref chan) => { + GenericSender(GenericSenderVariants::Ipc(chan.clone())) + }, + GenericSenderVariants::Crossbeam(ref chan) => { + GenericSender(GenericSenderVariants::Crossbeam(chan.clone())) + }, } } } @@ -68,11 +151,11 @@ impl fmt::Debug for GenericSender { impl GenericSender { #[inline] pub fn send(&self, msg: T) -> SendResult { - match *self { - GenericSender::Ipc(ref sender) => sender + match self.0 { + GenericSenderVariants::Ipc(ref sender) => sender .send(msg) .map_err(|e| SendError::SerializationError(format!("{e}"))), - GenericSender::Crossbeam(ref sender) => { + GenericSenderVariants::Crossbeam(ref sender) => { sender.send(Ok(msg)).map_err(|_| SendError::Disconnected) }, } @@ -151,10 +234,15 @@ impl From for TryReceiveError { } } +pub type RoutedReceiver = crossbeam_channel::Receiver>; pub type ReceiveResult = Result; pub type TryReceiveResult = Result; -pub enum GenericReceiver +pub struct GenericReceiver(GenericReceiverVariants) +where + T: for<'de> Deserialize<'de> + Serialize; + +enum GenericReceiverVariants where T: for<'de> Deserialize<'de> + Serialize, { @@ -168,9 +256,9 @@ where { #[inline] pub fn recv(&self) -> ReceiveResult { - match *self { - GenericReceiver::Ipc(ref receiver) => Ok(receiver.recv()?), - GenericReceiver::Crossbeam(ref receiver) => { + match self.0 { + GenericReceiverVariants::Ipc(ref receiver) => Ok(receiver.recv()?), + GenericReceiverVariants::Crossbeam(ref receiver) => { // `recv()` returns an error if the channel is disconnected let msg = receiver.recv()?; // `msg` must be `ok` because the corresponding [`GenericSender::Crossbeam`] will @@ -182,9 +270,9 @@ where #[inline] pub fn try_recv(&self) -> TryReceiveResult { - match *self { - GenericReceiver::Ipc(ref receiver) => Ok(receiver.try_recv()?), - GenericReceiver::Crossbeam(ref receiver) => { + match self.0 { + GenericReceiverVariants::Ipc(ref receiver) => Ok(receiver.try_recv()?), + GenericReceiverVariants::Crossbeam(ref receiver) => { let msg = receiver.try_recv()?; Ok(msg.expect("Infallible")) }, @@ -200,8 +288,8 @@ where where T: Send + 'static, { - match self { - GenericReceiver::Ipc(ipc_receiver) => { + match self.0 { + GenericReceiverVariants::Ipc(ipc_receiver) => { let (crossbeam_sender, crossbeam_receiver) = crossbeam_channel::unbounded(); let crossbeam_sender_clone = crossbeam_sender.clone(); ROUTER.add_typed_route( @@ -212,7 +300,7 @@ where ); crossbeam_receiver }, - GenericReceiver::Crossbeam(receiver) => receiver, + GenericReceiverVariants::Crossbeam(receiver) => receiver, } } } @@ -222,9 +310,69 @@ where T: for<'de> Deserialize<'de> + Serialize, { fn serialize(&self, s: S) -> Result { - match self { - GenericReceiver::Ipc(receiver) => receiver.serialize(s), - GenericReceiver::Crossbeam(_) => panic!("{GENERIC_CHANNEL_USAGE_ERROR_PANIC_MSG}"), + match &self.0 { + GenericReceiverVariants::Ipc(receiver) => { + s.serialize_newtype_variant("GenericReceiver", 0, "Ipc", receiver) + }, + GenericReceiverVariants::Crossbeam(receiver) => { + if opts::get().multiprocess { + return Err(serde::ser::Error::custom( + "Crossbeam channel found in multiprocess mode!", + )); + } // We know everything is in one address-space, so we can "serialize" the receiver by + // sending a leaked Box pointer. + let receiver_clone_addr = Box::leak(Box::new(receiver.clone())) as *mut _ as usize; + s.serialize_newtype_variant("GenericReceiver", 1, "Crossbeam", &receiver_clone_addr) + }, + } + } +} + +struct GenericReceiverVisitor { + marker: PhantomData, +} +impl<'de, T> serde::de::Visitor<'de> for GenericReceiverVisitor +where + T: for<'a> Deserialize<'a> + Serialize, +{ + type Value = GenericReceiver; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a GenericReceiver variant") + } + + fn visit_enum(self, data: A) -> Result + where + A: serde::de::EnumAccess<'de>, + { + #[derive(Deserialize)] + enum GenericReceiverVariantNames { + Ipc, + Crossbeam, + } + + let (variant_name, variant_data): (GenericReceiverVariantNames, _) = data.variant()?; + + match variant_name { + GenericReceiverVariantNames::Ipc => variant_data + .newtype_variant::>() + .map(|receiver| GenericReceiver(GenericReceiverVariants::Ipc(receiver))), + GenericReceiverVariantNames::Crossbeam => { + if opts::get().multiprocess { + return Err(serde::de::Error::custom( + "Crossbeam channel found in multiprocess mode!", + )); + } + let addr = variant_data.newtype_variant::()?; + let ptr = addr as *mut RoutedReceiver; + // SAFETY: We know we are in the same address space as the sender, so we can safely + // reconstruct the Box. + #[allow(unsafe_code)] + let receiver = unsafe { Box::from_raw(ptr) }; + Ok(GenericReceiver(GenericReceiverVariants::Crossbeam( + *receiver, + ))) + }, } } } @@ -237,11 +385,42 @@ where where D: Deserializer<'a>, { - // Only ipc_channel will encounter deserialize scenario. - ipc_channel::ipc::IpcReceiver::::deserialize(d).map(GenericReceiver::Ipc) + d.deserialize_enum( + "GenericReceiver", + &["Ipc", "Crossbeam"], + GenericReceiverVisitor { + marker: PhantomData, + }, + ) } } +/// Private helper function to create a crossbeam based channel. +/// +/// Do NOT make this function public! +fn new_generic_channel_crossbeam() -> (GenericSender, GenericReceiver) +where + T: Serialize + for<'de> serde::Deserialize<'de>, +{ + let (tx, rx) = crossbeam_channel::unbounded(); + ( + GenericSender(GenericSenderVariants::Crossbeam(tx)), + GenericReceiver(GenericReceiverVariants::Crossbeam(rx)), + ) +} + +fn new_generic_channel_ipc() -> Result<(GenericSender, GenericReceiver), std::io::Error> +where + T: Serialize + for<'de> serde::Deserialize<'de>, +{ + ipc_channel::ipc::channel().map(|(tx, rx)| { + ( + GenericSender(GenericSenderVariants::Ipc(tx)), + GenericReceiver(GenericReceiverVariants::Ipc(rx)), + ) + }) +} + /// Creates a Servo channel that can select different channel implementations based on multiprocess /// mode or not. If the scenario doesn't require message to pass process boundary, a simple /// crossbeam channel is preferred. @@ -250,12 +429,110 @@ where T: for<'de> Deserialize<'de> + Serialize, { if servo_config::opts::get().multiprocess || servo_config::opts::get().force_ipc { - ipc_channel::ipc::channel() - .map(|(tx, rx)| (GenericSender::Ipc(tx), GenericReceiver::Ipc(rx))) - .ok() + new_generic_channel_ipc().ok() } else { - let (tx, rx) = crossbeam_channel::unbounded(); - Some((GenericSender::Crossbeam(tx), GenericReceiver::Crossbeam(rx))) + Some(new_generic_channel_crossbeam()) + } +} + +#[cfg(test)] +mod single_process_channel_tests { + //! These unit-tests test that ipc_channel and crossbeam_channel Senders and Receivers + //! can be sent over each other without problems in single-process mode. + //! In multiprocess mode we exclusively use `ipc_channel` anyway, which is ensured due + //! to `channel()` being the only way to construct `GenericSender` and Receiver pairs. + use crate::generic_channel::{new_generic_channel_crossbeam, new_generic_channel_ipc}; + + #[test] + fn generic_crossbeam_can_send() { + let (tx, rx) = new_generic_channel_crossbeam(); + tx.send(5).expect("Send failed"); + let val = rx.recv().expect("Receive failed"); + assert_eq!(val, 5); + } + + #[test] + fn generic_crossbeam_ping_pong() { + let (tx, rx) = new_generic_channel_crossbeam(); + let (tx2, rx2) = new_generic_channel_crossbeam(); + + tx.send(tx2).expect("Send failed"); + + std::thread::scope(|s| { + s.spawn(move || { + let reply_sender = rx.recv().expect("Receive failed"); + reply_sender.send(42).expect("Sending reply failed"); + }); + }); + let res = rx2.recv().expect("Receive of reply failed"); + assert_eq!(res, 42); + } + + #[test] + fn generic_ipc_ping_pong() { + let (tx, rx) = new_generic_channel_ipc().unwrap(); + let (tx2, rx2) = new_generic_channel_ipc().unwrap(); + + tx.send(tx2).expect("Send failed"); + + std::thread::scope(|s| { + s.spawn(move || { + let reply_sender = rx.recv().expect("Receive failed"); + reply_sender.send(42).expect("Sending reply failed"); + }); + }); + let res = rx2.recv().expect("Receive of reply failed"); + assert_eq!(res, 42); + } + + #[test] + fn send_crossbeam_sender_over_ipc_channel() { + let (tx, rx) = new_generic_channel_ipc().unwrap(); + let (tx2, rx2) = new_generic_channel_crossbeam(); + + tx.send(tx2).expect("Send failed"); + + std::thread::scope(|s| { + s.spawn(move || { + let reply_sender = rx.recv().expect("Receive failed"); + reply_sender.send(42).expect("Sending reply failed"); + }); + }); + let res = rx2.recv().expect("Receive of reply failed"); + assert_eq!(res, 42); + } + + #[test] + fn send_generic_ipc_channel_over_crossbeam() { + let (tx, rx) = new_generic_channel_crossbeam(); + let (tx2, rx2) = new_generic_channel_ipc().unwrap(); + + tx.send(tx2).expect("Send failed"); + + std::thread::scope(|s| { + s.spawn(move || { + let reply_sender = rx.recv().expect("Receive failed"); + reply_sender.send(42).expect("Sending reply failed"); + }); + }); + let res = rx2.recv().expect("Receive of reply failed"); + assert_eq!(res, 42); + } + + #[test] + fn send_crossbeam_receiver_over_ipc_channel() { + let (tx, rx) = new_generic_channel_ipc().unwrap(); + let (tx2, rx2) = new_generic_channel_crossbeam(); + + tx.send(rx2).expect("Send failed"); + tx2.send(42).expect("Send failed"); + + std::thread::scope(|s| { + s.spawn(move || { + let another_receiver = rx.recv().expect("Receive failed"); + let res = another_receiver.recv().expect("Receive failed"); + assert_eq!(res, 42); + }); + }); } } -pub type RoutedReceiver = crossbeam_channel::Receiver>;