Skip to content

Commit

Permalink
PyO3: Add optional candle.onnx module (#1282)
Browse files Browse the repository at this point in the history
* Start onnx integration

* Merge remote-tracking branch 'upstream/main' into feat/pyo3-onnx

* Implement ONNXModel

* `fmt`

* add `onnx` flag to python ci

* Pin `protoc` to `25.0`

* Setup `protoc` in wheel builds

* Build wheels with `onnx`

* Install `protoc` in manylinux containers

* `apt` -> `yum`

* Download `protoc` via bash script

* Back to `manylinux: auto`

* Disable `onnx` builds for linux
  • Loading branch information
LLukas22 authored Nov 8, 2023
1 parent 7920b45 commit f3a4f3d
Show file tree
Hide file tree
Showing 10 changed files with 343 additions and 6 deletions.
Binary file modified .github/workflows/maturin.yml
Binary file not shown.
8 changes: 7 additions & 1 deletion .github/workflows/python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,20 @@ jobs:
path: ~/.cargo/registry
key: ${{ runner.os }}-cargo-registry-${{ hashFiles('**/Cargo.lock') }}

- name: Install Protoc
uses: arduino/setup-protoc@v2
with:
version: "25.0"
repo-token: ${{ secrets.GITHUB_TOKEN }}

- name: Install
working-directory: ./candle-pyo3
run: |
python -m venv .env
source .env/bin/activate
pip install -U pip
pip install pytest maturin black
python -m maturin develop -r
python -m maturin develop -r --features onnx
- name: Check style
working-directory: ./candle-pyo3
Expand Down
2 changes: 1 addition & 1 deletion candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ fn get_attr_opt<'a, T: Attr + ?Sized>(
}
}

fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
pub fn get_tensor(t: &onnx::TensorProto, name: &str) -> Result<Tensor> {
let dims: Vec<usize> = t.dims.iter().map(|&x| x as usize).collect();
match DataType::try_from(t.data_type) {
Ok(DataType::Int32) => {
Expand Down
2 changes: 1 addition & 1 deletion candle-onnx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}

mod eval;
pub mod eval;
pub use eval::{dtype, simple_eval};

pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
Expand Down
3 changes: 3 additions & 0 deletions candle-pyo3/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ crate-type = ["cdylib"]
accelerate-src = { workspace = true, optional = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
candle-onnx = {path= "../candle-onnx", version = "0.3.0", optional = true}
half = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
pyo3 = { version = "0.20.0", features = ["extension-module", "abi3-py38"] }
Expand All @@ -29,3 +30,5 @@ default = []
accelerate = ["dep:accelerate-src", "candle/accelerate"]
cuda = ["candle/cuda"]
mkl = ["dep:intel-mkl-src","candle/mkl"]
onnx = ["dep:candle-onnx"]

5 changes: 5 additions & 0 deletions candle-pyo3/py_src/candle/onnx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Generated content DO NOT EDIT
from .. import onnx

ONNXModel = onnx.ONNXModel
ONNXTensorDescription = onnx.ONNXTensorDescription
89 changes: 89 additions & 0 deletions candle-pyo3/py_src/candle/onnx/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Generated content DO NOT EDIT
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Sequence
from os import PathLike
from candle.typing import _ArrayLike, Device, Scalar, Index, Shape
from candle import Tensor, DType, QTensor

class ONNXModel:
"""
A wrapper around an ONNX model.
"""

def __init__(self, path: str):
pass
@property
def doc_string(self) -> str:
"""
The doc string of the model.
"""
pass
@property
def domain(self) -> str:
"""
The domain of the operator set of the model.
"""
pass
def initializers(self) -> Dict[str, Tensor]:
"""
Get the weights of the model.
"""
pass
@property
def inputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The inputs of the model.
"""
pass
@property
def ir_version(self) -> int:
"""
The version of the IR this model targets.
"""
pass
@property
def model_version(self) -> int:
"""
The version of the model.
"""
pass
@property
def outputs(self) -> Optional[Dict[str, ONNXTensorDescription]]:
"""
The outputs of the model.
"""
pass
@property
def producer_name(self) -> str:
"""
The producer of the model.
"""
pass
@property
def producer_version(self) -> str:
"""
The version of the producer of the model.
"""
pass
def run(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
"""
Run the model on the given inputs.
"""
pass

class ONNXTensorDescription:
"""
A wrapper around an ONNX tensor description.
"""

@property
def dtype(self) -> DType:
"""
The data type of the tensor.
"""
pass
@property
def shape(self) -> Tuple[Union[int, str, Any]]:
"""
The shape of the tensor.
"""
pass
22 changes: 19 additions & 3 deletions candle-pyo3/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ extern crate accelerate_src;

use ::candle::{quantized::QTensor, DType, Device, Tensor, WithDType};

mod utils;
use utils::wrap_err;

mod shape;
use shape::{PyShape, PyShapeWithHole};

pub fn wrap_err(err: ::candle::Error) -> PyErr {
PyErr::new::<PyValueError, _>(format!("{err:?}"))
}
#[cfg(feature = "onnx")]
mod onnx;

#[derive(Clone, Debug)]
#[pyclass(name = "Tensor")]
Expand Down Expand Up @@ -1559,6 +1561,14 @@ fn candle_functional_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}

#[cfg(feature = "onnx")]
fn candle_onnx_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
use onnx::{PyONNXModel, PyONNXTensorDescriptor};
m.add_class::<PyONNXModel>()?;
m.add_class::<PyONNXTensorDescriptor>()?;
Ok(())
}

#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
Expand All @@ -1567,6 +1577,12 @@ fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let nn = PyModule::new(py, "functional")?;
candle_functional_m(py, nn)?;
m.add_submodule(nn)?;
#[cfg(feature = "onnx")]
{
let onnx = PyModule::new(py, "onnx")?;
candle_onnx_m(py, onnx)?;
m.add_submodule(onnx)?;
}
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;
Expand Down
Loading

0 comments on commit f3a4f3d

Please sign in to comment.