Skip to content

Commit

Permalink
feat: adding assertions and matching shader module names with compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
SkillerRaptor committed Aug 26, 2024
1 parent 4894590 commit 040ba96
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 75 deletions.
4 changes: 2 additions & 2 deletions crates/hyper_engine/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ impl Engine {

let opaque_vertex_shader = graphics_device.create_shader_module(&ShaderModuleDescriptor {
path: "./assets/shaders/opaque_shader.hlsl",
entry: "vs_main",
entry_point: "vs_main",
stage: ShaderStage::Vertex,
});

let opaque_fragment_shader =
graphics_device.create_shader_module(&ShaderModuleDescriptor {
path: "./assets/shaders/opaque_shader.hlsl",
entry: "fs_main",
entry_point: "fs_main",
stage: ShaderStage::Fragment,
});

Expand Down
13 changes: 6 additions & 7 deletions crates/hyper_rhi/src/d3d12/shader_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use crate::{

#[derive(Debug)]
pub(crate) struct ShaderModule {
entry: String,
entry_point: String,
stage: ShaderStage,
code: Vec<u8>,
}
Expand All @@ -20,14 +20,13 @@ impl ShaderModule {
pub(super) fn new(descriptor: &ShaderModuleDescriptor) -> Self {
let code = shader_compiler::compile(
descriptor.path,
descriptor.entry,
descriptor.entry_point,
descriptor.stage,
OutputApi::D3D12,
)
.unwrap();
);

Self {
entry: descriptor.entry.to_owned(),
entry_point: descriptor.entry_point.to_owned(),
stage: descriptor.stage,
code,
}
Expand All @@ -39,8 +38,8 @@ impl ShaderModule {
}

impl crate::shader_module::ShaderModule for ShaderModule {
fn entry(&self) -> &str {
&self.entry
fn entry_point(&self) -> &str {
&self.entry_point
}

fn stage(&self) -> ShaderStage {
Expand Down
83 changes: 31 additions & 52 deletions crates/hyper_rhi/src/shader_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,80 +4,59 @@
// SPDX-License-Identifier: MIT
//

use std::{fs, io, path::Path};

use hassle_rs::HassleError;
use thiserror::Error;
use std::{fs, path::Path};

use crate::shader_module::ShaderStage;

#[derive(Debug, Error)]
pub enum ShaderCompilationError {
#[error("failed to find shader '{0}'")]
NotFound(String),

#[error("failed to read directory as shader '{0}'")]
NotAFile(String),

#[error("failed to read shader '{1}'")]
Io(io::Error, String),

#[error("failed to compile shader '{1}'")]
Compilation(HassleError, String),

#[error("failed to validate shader '{1}'")]
Validation(HassleError, String),
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum OutputApi {
D3D12,
Vulkan,
}

pub(crate) fn compile(
path: &str,
entry: &str,
pub(crate) fn compile<P>(
path: P,
entry_point: &str,
stage: ShaderStage,
output_api: OutputApi,
) -> Result<Vec<u8>, ShaderCompilationError> {
let fs_path = Path::new(path);
) -> Vec<u8>
where
P: AsRef<Path>,
{
assert!(!entry_point.is_empty());

if !fs_path.exists() {
return Err(ShaderCompilationError::NotFound(path.to_owned()));
}

if !fs_path.is_file() {
return Err(ShaderCompilationError::NotAFile(path.to_owned()));
}
let path = path.as_ref();

let file_name_os = fs_path.file_name().unwrap();
let file_name = file_name_os.to_str().unwrap();
assert!(path.exists());
assert!(path.is_file());

let profile = match stage {
let source_name = path.file_name().unwrap().to_str().unwrap();
let shader_text = fs::read_to_string(path).unwrap();
let target_profile = match stage {
ShaderStage::Compute => "cs_6_6",
ShaderStage::Fragment => "ps_6_6",
ShaderStage::Vertex => "vs_6_6",
};

let source = fs::read_to_string(fs_path)
.map_err(|error| ShaderCompilationError::Io(error, path.to_owned()))?;

let mut args = Vec::new();
let mut defines = Vec::new();
if output_api == OutputApi::Vulkan {
args.push("-spirv");
defines.push(("HYPER_ENGINE_VULKAN", None));
}
let (args, defines) = if output_api == OutputApi::Vulkan {
(vec!["-spirv"], vec![("HYPER_ENGINE_VULKAN", None)])
} else {
(Vec::new(), vec![("HYPER_ENGINE_D3D12", None)])
};

let mut shader_bytes =
hassle_rs::compile_hlsl(file_name, &source, entry, profile, &args, &defines)
.map_err(|error| ShaderCompilationError::Compilation(error, path.to_owned()))?;
let mut shader_bytes = hassle_rs::compile_hlsl(
source_name,
&shader_text,
entry_point,
&target_profile,
&args,
&defines,
)
.unwrap();

if output_api == OutputApi::D3D12 {
shader_bytes = hassle_rs::validate_dxil(&shader_bytes)
.map_err(|error| ShaderCompilationError::Validation(error, path.to_owned()))?;
shader_bytes = hassle_rs::validate_dxil(&shader_bytes).unwrap();
}

Ok(shader_bytes)
shader_bytes
}
4 changes: 2 additions & 2 deletions crates/hyper_rhi/src/shader_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ pub enum ShaderStage {
#[derive(Clone, Debug)]
pub struct ShaderModuleDescriptor<'a> {
pub path: &'a str,
pub entry: &'a str,
pub entry_point: &'a str,
pub stage: ShaderStage,
}

pub trait ShaderModule: Debug + Downcast {
fn entry(&self) -> &str;
fn entry_point(&self) -> &str;
fn stage(&self) -> ShaderStage;
}

Expand Down
4 changes: 2 additions & 2 deletions crates/hyper_rhi/src/vulkan/graphics_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ impl GraphicsPipeline {
resource_handler: &Arc<ResourceHandler>,
descriptor: &GraphicsPipelineDescriptor,
) -> Self {
let vertex_shader_entry = CString::new(descriptor.vertex_shader.entry()).unwrap();
let vertex_shader_entry = CString::new(descriptor.vertex_shader.entry_point()).unwrap();
let vertex_shader_stage_create_info = vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::VERTEX)
.module(
Expand All @@ -41,7 +41,7 @@ impl GraphicsPipeline {
)
.name(&vertex_shader_entry);

let fragment_shader_entry = CString::new(descriptor.fragment_shader.entry()).unwrap();
let fragment_shader_entry = CString::new(descriptor.fragment_shader.entry_point()).unwrap();
let fragment_shader_stage_create_info = vk::PipelineShaderStageCreateInfo::default()
.stage(vk::ShaderStageFlags::FRAGMENT)
.module(
Expand Down
18 changes: 8 additions & 10 deletions crates/hyper_rhi/src/vulkan/shader_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{

#[derive(Debug)]
pub(crate) struct ShaderModule {
entry: String,
entry_point: String,
stage: ShaderStage,
shader_module: vk::ShaderModule,

Expand All @@ -31,16 +31,14 @@ impl ShaderModule {
) -> Self {
let bytes = shader_compiler::compile(
descriptor.path,
descriptor.entry,
descriptor.entry_point,
descriptor.stage,
OutputApi::Vulkan,
)
.unwrap();
);

let (prefix, code, suffix) = unsafe { bytes.align_to::<u32>() };
if !prefix.is_empty() || !suffix.is_empty() {
panic!("unaligned shader module code");
}
assert!(prefix.is_empty());
assert!(suffix.is_empty());

let create_info = vk::ShaderModuleCreateInfo::default().code(code);

Expand All @@ -52,7 +50,7 @@ impl ShaderModule {
.unwrap();

Self {
entry: descriptor.entry.to_owned(),
entry_point: descriptor.entry_point.to_owned(),
stage: descriptor.stage,
shader_module,

Expand All @@ -76,8 +74,8 @@ impl Drop for ShaderModule {
}

impl crate::shader_module::ShaderModule for ShaderModule {
fn entry(&self) -> &str {
&self.entry
fn entry_point(&self) -> &str {
&self.entry_point
}

fn stage(&self) -> ShaderStage {
Expand Down

0 comments on commit 040ba96

Please sign in to comment.