Skip to content

Commit

Permalink
chore: make internal functions and structs private
Browse files Browse the repository at this point in the history
  • Loading branch information
pixelspark committed Jul 31, 2022
1 parent e07b942 commit 7ef3572
Show file tree
Hide file tree
Showing 6 changed files with 15 additions and 35 deletions.
3 changes: 2 additions & 1 deletion wonnx/src/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Compiles individual ONNX ops to a WebGPU shader using WGSL templates
use crate::utils::{
ceil, get_attribute, AttributeNotFoundError, DataTypeError, MultiType, ScalarType, Shape,
};
Expand All @@ -12,7 +13,7 @@ pub const MAX_COMPUTE_WORKGROUPS_PER_DIMENSION: u32 = 65535;
/// The maximum workgroup size per dimension (see <https://www.w3.org/TR/webgpu/#dom-supported-limits-maxcomputeworkgroupsizex>)
pub const MAX_WORKGROUP_SIZE_X: u32 = 256;
pub const MAX_WORKGROUP_SIZE_Y: u32 = 256;
pub const MAX_WORKGROUP_SIZE_Z: u32 = 64;
// pub const MAX_WORKGROUP_SIZE_Z: u32 = 64;

lazy_static! {
// Templates for shader source code that we generate for nodes
Expand Down
1 change: 1 addition & 0 deletions wonnx/src/gpu.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Manages execution of shader code and buffer allocation on the GPU
use std::{
borrow::Cow,
collections::{HashMap, HashSet},
Expand Down
28 changes: 2 additions & 26 deletions wonnx/src/ir.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! DAG representation of ONNX ops allowing for transformations and optimizations before compilation
use crate::onnx::{ModelProto, NodeProto, TensorProto, ValueInfoProto};
use crate::utils::{DataTypeError, ScalarType, Shape};
use crate::utils::{DataTypeError, Shape};
use std::borrow::Cow;
use std::fmt::Debug;
use std::hash::Hash;
Expand Down Expand Up @@ -82,18 +83,6 @@ impl<'m> NodeDefinition<'m> {
NodeDefinition::Missing => Cow::from(""),
}
}

pub fn output_name(&self, output_index: usize) -> Cow<'_, str> {
match self {
NodeDefinition::Operator(op_def) => {
Cow::Borrowed(&op_def.proto.get_output()[output_index])
}
NodeDefinition::Tensor(proto) => Cow::from(proto.get_name()),
NodeDefinition::Input(proto) => Cow::from(proto.get_name()),
NodeDefinition::Outputs { .. } => panic!("can't get output name for outputs node"),
NodeDefinition::Missing => panic!("can't get output name for missing node"),
}
}
}

impl NodeProto {
Expand Down Expand Up @@ -294,19 +283,6 @@ impl<'model> Node<'model> {
inputs: output_nodes?,
}))
}

pub fn output_shape(&self, output_index: usize) -> Result<Shape, IrError> {
Ok(match (&self.definition, output_index) {
(NodeDefinition::Operator(op_def), index) => op_def.output_shapes[index].clone(),
(NodeDefinition::Tensor(tensor_proto), 0) => Shape::from(
ScalarType::from_i32(tensor_proto.get_data_type())?,
tensor_proto.get_dims(),
),
(NodeDefinition::Input(input_proto), 0) => input_proto.get_shape()?,
(NodeDefinition::Outputs { .. }, _) => panic!("output node has no outputs!"),
(_, _) => panic!("node has no output at index {}", output_index),
})
}
}

impl<'model> Debug for NodeDefinition<'model> {
Expand Down
8 changes: 4 additions & 4 deletions wonnx/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pub mod compiler;
mod compiler;
mod gpu;
pub mod ir;
mod ir;
pub mod onnx;
pub mod optimizer;
pub mod resource;
mod optimizer;
mod resource;
pub mod utils;

#[macro_use]
Expand Down
1 change: 1 addition & 0 deletions wonnx/src/optimizer.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Optimizer that walks the DAG and transforms or coalesces ops for quicker execution
use protobuf::RepeatedField;
use std::{borrow::Cow, collections::HashMap, sync::Arc};
use thiserror::Error;
Expand Down
9 changes: 5 additions & 4 deletions wonnx/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//! Various utilities to deal with the ONNX format structure
use protobuf::ProtobufEnum;
use protobuf::RepeatedField;
use serde::Serialize;
Expand All @@ -17,7 +18,7 @@ use thiserror::Error;

/* Minimum size of a buffer you can create with wgpu. Creating buffers smaller than this leads to panic "Validation
* error: buffer binding size X is less than minimum 64" in Device::create_bind_group */
pub const MINIMUM_BUFFER_SIZE_BYTES: u64 = 64;
pub(crate) const MINIMUM_BUFFER_SIZE_BYTES: u64 = 64;

#[derive(Debug, Clone, PartialEq)]
pub struct Shape {
Expand Down Expand Up @@ -278,7 +279,7 @@ impl Display for ScalarType {
/// struct Block {
/// data: [[stride( dt.size_bytes() )]] dt.wgsl_type_name();
/// };
pub enum MultiType {
pub(crate) enum MultiType {
Scalar(ScalarType),
Vec(ScalarType, usize),
Mat(ScalarType, usize, usize),
Expand Down Expand Up @@ -350,7 +351,7 @@ pub struct AttributeNotFoundError {
node_name: String,
}

pub fn get_attribute<T: std::convert::From<onnx::AttributeProto>>(
pub(crate) fn get_attribute<T: std::convert::From<onnx::AttributeProto>>(
attribute: &str,
default: Option<T>,
node: &onnx::NodeProto,
Expand All @@ -371,7 +372,7 @@ pub fn get_attribute<T: std::convert::From<onnx::AttributeProto>>(
}

/// Divide a number by the indicated dividend, then round up to the next multiple of the dividend if there is a rest.
pub fn ceil(num: u64, div: u64) -> u64 {
pub(crate) fn ceil(num: u64, div: u64) -> u64 {
num / div + (num % div != 0) as u64
}

Expand Down

0 comments on commit 7ef3572

Please sign in to comment.