diff --git a/Cargo.lock b/Cargo.lock index 92ce0871..4a59492a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4202,7 +4202,6 @@ dependencies = [ "env_logger", "futures", "image", - "lazy_static", "log", "ndarray", "num", diff --git a/wonnx/Cargo.toml b/wonnx/Cargo.toml index dc8c150c..27eba8d8 100644 --- a/wonnx/Cargo.toml +++ b/wonnx/Cargo.toml @@ -27,7 +27,6 @@ bytemuck = { version = "1.9.1", features = ["extern_crate_alloc"] } protobuf = { version = "2.27.1", features = ["with-bytes"] } log = "0.4.17" tera = { version = "1.15.0", default-features = false } -lazy_static = "1.4.0" thiserror = "1.0.31" serde_derive = "1.0.137" serde = { version = "1.0.137", features = ["derive"] } diff --git a/wonnx/src/compiler.rs b/wonnx/src/compiler.rs index ba757c20..49126346 100644 --- a/wonnx/src/compiler.rs +++ b/wonnx/src/compiler.rs @@ -1,4 +1,6 @@ //! Compiles individual ONNX ops to a WebGPU shader using WGSL templates +use std::sync::OnceLock; + use crate::utils::{ ceil, AttributeNotFoundError, DataTypeError, MultiType, NodeAttributes, ScalarType, Shape, }; @@ -15,9 +17,10 @@ pub const MAX_WORKGROUP_SIZE_X: u32 = 256; pub const MAX_WORKGROUP_SIZE_Y: u32 = 256; // pub const MAX_WORKGROUP_SIZE_Z: u32 = 64; -lazy_static! { - // Templates for shader source code that we generate for nodes - pub static ref TEMPLATES: Tera = { +static TEMPLATES: OnceLock = OnceLock::new(); + +fn get_templates() -> &'static Tera { + TEMPLATES.get_or_init(|| { let mut tera = Tera::default(); tera.add_raw_template( "endomorphism/activation.wgsl", @@ -66,8 +69,9 @@ lazy_static! { .unwrap(); tera.add_raw_template( "matrix/pad.wgsl", - include_str!("../templates/matrix/pad.wgsl") - ).unwrap(); + include_str!("../templates/matrix/pad.wgsl"), + ) + .unwrap(); tera.add_raw_template( "matrix/resize.wgsl", include_str!("../templates/matrix/resize.wgsl"), @@ -136,7 +140,7 @@ lazy_static! { ) .unwrap(); tera - }; + }) } pub struct CompiledNode { @@ -1428,7 +1432,7 @@ pub fn compile( context.insert("mat3x3_stride", &(48)); // Render template - let shader = TEMPLATES + let shader = get_templates() .render(node_template.template, &context) .expect("failed to render shader"); diff --git a/wonnx/src/lib.rs b/wonnx/src/lib.rs index de622522..929edf23 100644 --- a/wonnx/src/lib.rs +++ b/wonnx/src/lib.rs @@ -6,9 +6,6 @@ mod optimizer; mod resource; pub mod utils; -#[macro_use] -extern crate lazy_static; - pub use compiler::CompileError; pub use gpu::GpuError; use ir::IrError;