webgpu: Implement device lost (#32354)

* device lost promise should be init at creation of device object

* device lost impl

* lock for device poll

workaround for wgpu deadlocks

* expect

* Less lost reason reasoning in script
This commit is contained in:
Samson 2024-06-17 14:47:25 +02:00 committed by GitHub
parent 3381f2a704
commit cbc9304c20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 224 additions and 5628 deletions

View file

@ -7,7 +7,7 @@
//! This is roughly based on <https://github.com/LucentFlux/wgpu-async/blob/1322c7e3fcdfc1865a472c7bbbf0e2e06dcf4da8/src/wgpu_future.rs>
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::sync::{Arc, Mutex, MutexGuard};
use std::thread::JoinHandle;
use log::warn;
@ -40,14 +40,25 @@ pub(crate) struct Poller {
is_done: Arc<AtomicBool>,
/// Handle to the WGPU poller thread (to be used for unparking the thread)
handle: Option<JoinHandle<()>>,
/// Lock for device maintain calls (in poll_all_devices and queue_submit)
///
/// This is a workaround for wgpu deadlocks: <https://github.com/gfx-rs/wgpu/issues/5572>
lock: Arc<Mutex<()>>,
}
#[inline]
fn poll_all_devices(global: &Arc<Global>, more_work: &mut bool, force_wait: bool) {
/// Polls every device on `global` once while holding the maintain lock.
///
/// On success `more_work` is set to `true` when at least one queue still has
/// pending work; on failure a warning is logged and `more_work` is left
/// untouched. `force_wait` is forwarded unchanged to `Global::poll_all_devices`.
fn poll_all_devices(
    global: &Arc<Global>,
    more_work: &mut bool,
    force_wait: bool,
    lock: &Mutex<()>,
) {
    // Held for the whole poll; released when `_guard` goes out of scope on return.
    let _guard = lock.lock().unwrap();
    let all_queue_empty = match global.poll_all_devices(force_wait) {
        Ok(empty) => empty,
        Err(e) => {
            warn!("Poller thread got `{e}` on poll_all_devices.");
            return;
        },
    };
    *more_work = !all_queue_empty;
}
impl Poller {
@ -56,9 +67,11 @@ impl Poller {
let is_done = Arc::new(AtomicBool::new(false));
let work = work_count.clone();
let done = is_done.clone();
let lock = Arc::new(Mutex::new(()));
Self {
work_count,
is_done,
lock: Arc::clone(&lock),
handle: Some(
std::thread::Builder::new()
.name("WGPU poller".into())
@ -69,9 +82,9 @@ impl Poller {
// so every `wake` (even spurious) will do at least one poll.
// this is mostly useful for stuff that is deferred
// to maintain calls in wgpu (device resource destruction)
poll_all_devices(&global, &mut more_work, false);
poll_all_devices(&global, &mut more_work, false, &lock);
while more_work || work.load(Ordering::Acquire) != 0 {
poll_all_devices(&global, &mut more_work, true);
poll_all_devices(&global, &mut more_work, true, &lock);
}
std::thread::park(); //TODO: should we use timeout here
}
@ -101,6 +114,11 @@ impl Poller {
.thread()
.unpark();
}
/// Acquires the lock that serializes device maintain calls
/// (used in `poll_all_devices` and around queue operations such as `queue_submit`).
///
/// The lock is released when the returned guard is dropped.
///
/// # Panics
///
/// Panics if the mutex is poisoned, i.e. another thread panicked while holding it.
pub(crate) fn lock(&self) -> MutexGuard<'_, ()> {
    // `'_` makes the borrow of `self` visible in the signature instead of hiding it.
    self.lock.lock().unwrap()
}
}
impl Drop for Poller {

View file

@ -15,6 +15,13 @@ use crate::wgc::id::{
ShaderModuleId, StagingBufferId, SurfaceId, TextureId, TextureViewId,
};
/// Reason a WebGPU device was lost, carried in `WebGPUMsg::DeviceLost`.
///
/// <https://gpuweb.github.io/gpuweb/#enumdef-gpudevicelostreason>
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
pub enum DeviceLostReason {
/// The device was lost for a reason that is not otherwise classified.
Unknown,
/// The device was lost because it was destroyed.
Destroyed,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum WebGPUMsg {
FreeAdapter(AdapterId),
@ -34,14 +41,16 @@ pub enum WebGPUMsg {
FreeRenderBundle(RenderBundleId),
FreeStagingBuffer(StagingBufferId),
FreeQuerySet(QuerySetId),
CleanDevice {
device: WebGPUDevice,
pipeline_id: PipelineId,
},
UncapturedError {
device: WebGPUDevice,
pipeline_id: PipelineId,
error: Error,
},
DeviceLost {
device: WebGPUDevice,
pipeline_id: PipelineId,
reason: DeviceLostReason,
msg: String,
},
Exit,
}

View file

@ -20,7 +20,7 @@ use webrender_api::{DirtyRect, DocumentId};
use webrender_traits::{WebrenderExternalImageRegistry, WebrenderImageHandlerType};
use wgc::command::{ImageCopyBuffer, ImageCopyTexture};
use wgc::device::queue::SubmittedWorkDoneClosure;
use wgc::device::{DeviceDescriptor, HostMap, ImplicitPipelineIds};
use wgc::device::{DeviceDescriptor, DeviceLostClosure, HostMap, ImplicitPipelineIds};
use wgc::id::DeviceId;
use wgc::instance::parse_backends_from_comma_list;
use wgc::pipeline::ShaderModuleDescriptor;
@ -68,7 +68,7 @@ pub(crate) struct WGPU {
script_sender: IpcSender<WebGPUMsg>,
global: Arc<wgc::global::Global>,
adapters: Vec<WebGPUAdapter>,
devices: HashMap<DeviceId, DeviceScope>,
devices: Arc<Mutex<HashMap<DeviceId, DeviceScope>>>,
// Track invalid adapters https://gpuweb.github.io/gpuweb/#invalid
_invalid_adapters: Vec<WebGPUAdapter>,
//TODO: Remove this (https://github.com/gfx-rs/wgpu/issues/867)
@ -114,7 +114,7 @@ impl WGPU {
script_sender,
global,
adapters: Vec::new(),
devices: HashMap::new(),
devices: Arc::new(Mutex::new(HashMap::new())),
_invalid_adapters: Vec::new(),
error_command_encoders: RefCell::new(HashMap::new()),
webrender_api: Arc::new(Mutex::new(webrender_api_sender.create_api())),
@ -525,6 +525,8 @@ impl WGPU {
WebGPURequest::DestroyDevice(device) => {
let global = &self.global;
gfx_select!(device => global.device_destroy(device));
// Wake poller thread to trigger DeviceLostClosure
self.poller.wake();
},
WebGPURequest::DestroySwapChain {
external_id,
@ -580,15 +582,8 @@ impl WGPU {
};
},
WebGPURequest::DropDevice(device_id) => {
let device = WebGPUDevice(device_id);
let pipeline_id = self.devices.remove(&device_id).unwrap().pipeline_id;
if let Err(e) = self.script_sender.send(WebGPUMsg::CleanDevice {
device,
pipeline_id,
}) {
warn!("Unable to send CleanDevice({:?}) ({:?})", device_id, e);
}
let global = &self.global;
// device_drop also calls device lost callback
gfx_select!(device_id => global.device_drop(device_id));
if let Err(e) = self.script_sender.send(WebGPUMsg::FreeDevice(device_id)) {
warn!("Unable to send FreeDevice({:?}) ({:?})", device_id, e);
@ -686,8 +681,44 @@ impl WGPU {
};
let device = WebGPUDevice(device_id);
let queue = WebGPUQueue(queue_id);
self.devices
.insert(device_id, DeviceScope::new(device_id, pipeline_id));
{
self.devices
.lock()
.unwrap()
.insert(device_id, DeviceScope::new(device_id, pipeline_id));
}
let script_sender = self.script_sender.clone();
let devices = Arc::clone(&self.devices);
let callback =
DeviceLostClosure::from_rust(Box::from(move |reason, msg| {
let _ = devices.lock().unwrap().remove(&device_id);
let reason = match reason {
wgt::DeviceLostReason::Unknown => {
Some(crate::DeviceLostReason::Unknown)
},
wgt::DeviceLostReason::Destroyed => {
Some(crate::DeviceLostReason::Destroyed)
},
wgt::DeviceLostReason::Dropped => None,
wgt::DeviceLostReason::ReplacedCallback => {
panic!("DeviceLost callback should only be set once")
},
wgt::DeviceLostReason::DeviceInvalid => {
Some(crate::DeviceLostReason::Unknown)
},
};
if let Some(reason) = reason {
if let Err(e) = script_sender.send(WebGPUMsg::DeviceLost {
device: WebGPUDevice(device_id),
pipeline_id,
reason,
msg,
}) {
warn!("Failed to send WebGPUMsg::DeviceLost: {e}");
}
}
}));
gfx_select!(device_id => global.device_set_device_lost_closure(device_id, callback));
if let Err(e) = sender.send(Some(Ok(WebGPUResponse::RequestDevice {
device_id: device,
queue_id: queue,
@ -744,6 +775,7 @@ impl WGPU {
"Invalid command buffer submitted",
)))
} else {
let _guard = self.poller.lock();
gfx_select!(queue_id => global.queue_submit(queue_id, &command_buffers))
.map_err(Error::from_error)
};
@ -951,6 +983,7 @@ impl WGPU {
data,
} => {
let global = &self.global;
let _guard = self.poller.lock();
//TODO: Report result to content process
let result = gfx_select!(queue_id => global.queue_write_texture(
queue_id,
@ -959,6 +992,7 @@ impl WGPU {
&data_layout,
&size
));
drop(_guard);
self.maybe_dispatch_wgpu_error(queue_id.transmute(), result.err());
},
WebGPURequest::QueueOnSubmittedWorkDone { sender, queue_id } => {
@ -1070,18 +1104,18 @@ impl WGPU {
},
WebGPURequest::PushErrorScope { device_id, filter } => {
// <https://www.w3.org/TR/webgpu/#dom-gpudevice-pusherrorscope>
let device_scope = self
.devices
.get_mut(&device_id)
.expect("Using invalid device");
device_scope.error_scope_stack.push(ErrorScope::new(filter));
let mut devices = self.devices.lock().unwrap();
if let Some(device_scope) = devices.get_mut(&device_id) {
device_scope.error_scope_stack.push(ErrorScope::new(filter));
} // else device is lost
},
WebGPURequest::DispatchError { device_id, error } => {
self.dispatch_error(device_id, error);
},
WebGPURequest::PopErrorScope { device_id, sender } => {
// <https://www.w3.org/TR/webgpu/#dom-gpudevice-poperrorscope>
if let Some(device_scope) = self.devices.get_mut(&device_id) {
if let Some(device_scope) = self.devices.lock().unwrap().get_mut(&device_id)
{
if let Some(error_scope) = device_scope.error_scope_stack.pop() {
if let Err(e) =
sender.send(Some(Ok(WebGPUResponse::PoppedErrorScope(Ok(
@ -1133,7 +1167,7 @@ impl WGPU {
/// <https://www.w3.org/TR/webgpu/#abstract-opdef-dispatch-error>
fn dispatch_error(&mut self, device_id: id::DeviceId, error: Error) {
if let Some(device_scope) = self.devices.get_mut(&device_id) {
if let Some(device_scope) = self.devices.lock().unwrap().get_mut(&device_id) {
if let Some(error_scope) = device_scope
.error_scope_stack
.iter_mut()