Skip to content

Commit

Permalink
fix(wgpu): Raise validation error instead of panicking in get_bind_gr…
Browse files Browse the repository at this point in the history
…oup_layout.

Fixes #4167.
  • Loading branch information
BGR360 committed Sep 18, 2024
1 parent 0d339fc commit da474bc
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 54 deletions.
205 changes: 166 additions & 39 deletions tests/tests/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,44 +1,16 @@
use wgpu_test::{fail, gpu_test, FailureCase, GpuTestConfiguration, TestParameters};
use wgpu_test::{fail, gpu_test, GpuTestConfiguration, TestParameters};

// Create an invalid shader and a compute pipeline that uses it
// with a default bindgroup layout, and then ask for that layout.
// Validation should fail, but wgpu should not panic.
#[gpu_test]
static PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
// https://github.com/gfx-rs/wgpu/issues/4167
.expect_fail(FailureCase::always().panic("Error reflecting bind group")),
)
.run_sync(|ctx| {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

fail(
&ctx.device,
|| {
let module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
source: wgpu::ShaderSource::Wgsl("not valid wgsl".into()),
});

let pipeline =
ctx.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("mandelbrot compute pipeline"),
layout: None,
module: &module,
entry_point: Some("doesn't exist"),
compilation_options: Default::default(),
cache: None,
});
const INVALID_SHADER_DESC: wgpu::ShaderModuleDescriptor = wgpu::ShaderModuleDescriptor {
label: Some("invalid shader"),
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed("not valid wgsl")),
};

pipeline.get_bind_group_layout(0);
},
None,
);
});
const TRIVIAL_COMPUTE_SHADER_DESC: wgpu::ShaderModuleDescriptor = wgpu::ShaderModuleDescriptor {
label: Some("trivial compute shader"),
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(
"@compute @workgroup_size(1) fn main() {}",
)),
};

const TRIVIAL_VERTEX_SHADER_DESC: wgpu::ShaderModuleDescriptor = wgpu::ShaderModuleDescriptor {
label: Some("trivial vertex shader"),
Expand All @@ -47,6 +19,161 @@ const TRIVIAL_VERTEX_SHADER_DESC: wgpu::ShaderModuleDescriptor = wgpu::ShaderMod
)),
};

const TRIVIAL_FRAGMENT_SHADER_DESC: wgpu::ShaderModuleDescriptor = wgpu::ShaderModuleDescriptor {
label: Some("trivial fragment shader"),
source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(
"@fragment fn main() -> @location(0) vec4<f32> { return vec4<f32>(0); }",
)),
};

// Create an invalid shader and a compute pipeline that uses it
// with a default bindgroup layout, and then ask for that layout.
// Validation should fail, but wgpu should not panic.
#[gpu_test]
static COMPUTE_PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default())
.run_sync(|ctx| {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

fail(
&ctx.device,
|| {
let module = ctx.device.create_shader_module(INVALID_SHADER_DESC);

let pipeline =
ctx.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("compute pipeline"),
layout: None,
module: &module,
entry_point: Some("doesn't exist"),
compilation_options: Default::default(),
cache: None,
});

// https://github.com/gfx-rs/wgpu/issues/4167 this used to panic
pipeline.get_bind_group_layout(0);
},
Some("Shader 'invalid shader' parsing error"),
);
});

#[gpu_test]
static COMPUTE_PIPELINE_DEFAULT_LAYOUT_BAD_BGL_INDEX: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_sync(|ctx| {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

fail(
&ctx.device,
|| {
let module = ctx.device.create_shader_module(TRIVIAL_COMPUTE_SHADER_DESC);

let pipeline =
ctx.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("compute pipeline"),
layout: None,
module: &module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});

pipeline.get_bind_group_layout(0);
},
Some("Invalid group index 0"),
);
});

#[gpu_test]
static RENDER_PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default())
.run_sync(|ctx| {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

fail(
&ctx.device,
|| {
let module = ctx.device.create_shader_module(INVALID_SHADER_DESC);

let pipeline =
ctx.device
.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("render pipeline"),
layout: None,
vertex: wgpu::VertexState {
module: &module,
entry_point: Some("doesn't exist"),
compilation_options: Default::default(),
buffers: &[],
},
primitive: Default::default(),
depth_stencil: None,
multisample: Default::default(),
fragment: None,
multiview: None,
cache: None,
});

pipeline.get_bind_group_layout(0);
},
Some("Shader 'invalid shader' parsing error"),
);
});

#[gpu_test]
static RENDER_PIPELINE_DEFAULT_LAYOUT_BAD_BGL_INDEX: GpuTestConfiguration =
GpuTestConfiguration::new()
.parameters(TestParameters::default().test_features_limits())
.run_sync(|ctx| {
ctx.device.push_error_scope(wgpu::ErrorFilter::Validation);

fail(
&ctx.device,
|| {
let vs_module = ctx.device.create_shader_module(TRIVIAL_VERTEX_SHADER_DESC);
let fs_module = ctx
.device
.create_shader_module(TRIVIAL_FRAGMENT_SHADER_DESC);

let pipeline =
ctx.device
.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("render pipeline"),
layout: None,
vertex: wgpu::VertexState {
module: &vs_module,
entry_point: Some("main"),
compilation_options: Default::default(),
buffers: &[],
},
primitive: Default::default(),
depth_stencil: None,
multisample: Default::default(),
fragment: Some(wgpu::FragmentState {
module: &fs_module,
entry_point: Some("main"),
compilation_options: Default::default(),
targets: &[Some(wgpu::ColorTargetState {
format: wgpu::TextureFormat::Rgba8Unorm,
blend: None,
write_mask: wgpu::ColorWrites::ALL,
})],
}),
multiview: None,
cache: None,
});

pipeline.get_bind_group_layout(0);
},
Some("Invalid group index 0"),
);
});

#[gpu_test]
static NO_TARGETLESS_RENDER: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default())
Expand Down
56 changes: 41 additions & 15 deletions wgpu/src/backend/wgpu_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,18 @@ pub struct Queue {
error_sink: ErrorSink,
}

#[derive(Debug)]
pub struct ComputePipeline {
id: wgc::id::ComputePipelineId,
error_sink: ErrorSink,
}

#[derive(Debug)]
pub struct RenderPipeline {
id: wgc::id::RenderPipelineId,
error_sink: ErrorSink,
}

#[derive(Debug)]
pub struct ComputePass {
pass: wgc::command::ComputePass,
Expand Down Expand Up @@ -505,8 +517,8 @@ impl crate::Context for ContextWgpuCore {
type TextureData = Texture;
type QuerySetData = wgc::id::QuerySetId;
type PipelineLayoutData = wgc::id::PipelineLayoutId;
type RenderPipelineData = wgc::id::RenderPipelineId;
type ComputePipelineData = wgc::id::ComputePipelineId;
type RenderPipelineData = RenderPipeline;
type ComputePipelineData = ComputePipeline;
type PipelineCacheData = wgc::id::PipelineCacheId;
type CommandEncoderData = CommandEncoder;
type ComputePassData = ComputePass;
Expand Down Expand Up @@ -1097,7 +1109,10 @@ impl crate::Context for ContextWgpuCore {
"Device::create_render_pipeline",
);
}
id
RenderPipeline {
id,
error_sink: Arc::clone(&device_data.error_sink),
}
}
fn device_create_compute_pipeline(
&self,
Expand Down Expand Up @@ -1139,7 +1154,10 @@ impl crate::Context for ContextWgpuCore {
"Device::create_compute_pipeline",
);
}
id
ComputePipeline {
id,
error_sink: Arc::clone(&device_data.error_sink),
}
}

unsafe fn device_create_pipeline_cache(
Expand Down Expand Up @@ -1531,11 +1549,11 @@ impl crate::Context for ContextWgpuCore {
}

fn compute_pipeline_drop(&self, pipeline_data: &Self::ComputePipelineData) {
self.0.compute_pipeline_drop(*pipeline_data)
self.0.compute_pipeline_drop(pipeline_data.id)
}

fn render_pipeline_drop(&self, pipeline_data: &Self::RenderPipelineData) {
self.0.render_pipeline_drop(*pipeline_data)
self.0.render_pipeline_drop(pipeline_data.id)
}

fn pipeline_cache_drop(&self, cache_data: &Self::PipelineCacheData) {
Expand All @@ -1549,9 +1567,13 @@ impl crate::Context for ContextWgpuCore {
) -> Self::BindGroupLayoutData {
let (id, error) =
self.0
.compute_pipeline_get_bind_group_layout(*pipeline_data, index, None);
.compute_pipeline_get_bind_group_layout(pipeline_data.id, index, None);
if let Some(err) = error {
panic!("Error reflecting bind group {index}: {err}");
self.handle_error_nolabel(
&pipeline_data.error_sink,
err,
"ComputePipeline::get_bind_group_layout",
)
}
id
}
Expand All @@ -1561,11 +1583,15 @@ impl crate::Context for ContextWgpuCore {
pipeline_data: &Self::RenderPipelineData,
index: u32,
) -> Self::BindGroupLayoutData {
let (id, error) = self
.0
.render_pipeline_get_bind_group_layout(*pipeline_data, index, None);
let (id, error) =
self.0
.render_pipeline_get_bind_group_layout(pipeline_data.id, index, None);
if let Some(err) = error {
panic!("Error reflecting bind group {index}: {err}");
self.handle_error_nolabel(
&pipeline_data.error_sink,
err,
"RenderPipeline::get_bind_group_layout",
)
}
id
}
Expand Down Expand Up @@ -2108,7 +2134,7 @@ impl crate::Context for ContextWgpuCore {
) {
if let Err(cause) = self
.0
.compute_pass_set_pipeline(&mut pass_data.pass, *pipeline_data)
.compute_pass_set_pipeline(&mut pass_data.pass, pipeline_data.id)
{
self.handle_error(
&pass_data.error_sink,
Expand Down Expand Up @@ -2311,7 +2337,7 @@ impl crate::Context for ContextWgpuCore {
encoder_data: &mut Self::RenderBundleEncoderData,
pipeline_data: &Self::RenderPipelineData,
) {
wgpu_render_bundle_set_pipeline(encoder_data, *pipeline_data)
wgpu_render_bundle_set_pipeline(encoder_data, pipeline_data.id)
}

fn render_bundle_encoder_set_bind_group(
Expand Down Expand Up @@ -2434,7 +2460,7 @@ impl crate::Context for ContextWgpuCore {
) {
if let Err(cause) = self
.0
.render_pass_set_pipeline(&mut pass_data.pass, *pipeline_data)
.render_pass_set_pipeline(&mut pass_data.pass, pipeline_data.id)
{
self.handle_error(
&pass_data.error_sink,
Expand Down

0 comments on commit da474bc

Please sign in to comment.