Skip to content

Commit

Permalink
update exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
zshiqiang committed Dec 5, 2023
2 parents 18ffcc5 + b43ac5e commit f923482
Show file tree
Hide file tree
Showing 17 changed files with 1,026 additions and 123 deletions.
65 changes: 60 additions & 5 deletions .readthedocs.yml
Original file line number Diff line number Diff line change
@@ -1,9 +1,64 @@
# Read the Docs configuration file
# Read the Docs configuration file for Sphinx projects

# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

conda:
file: docs/environment.yml

# Required

version: 2


# Set the OS, Python version and other tools you might need

build:

os: ubuntu-22.04

tools:

python: "3.8"

# You can also specify other tool versions:

# nodejs: "20"

# rust: "1.70"

# golang: "1.20"


# Build documentation in the "docs/" directory with Sphinx

sphinx:

configuration: docs/conf.py

# You can configure Sphinx to use a different builder, for instance use the dirhtml builder for simpler URLs

# builder: "dirhtml"

# Fail on all warnings to avoid broken references

# fail_on_warning: true


# Optionally build your docs in additional formats such as PDF and ePub

# formats:

# - pdf

# - epub


# Optional but recommended, declare the Python requirements required

# to build your documentation

# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html

python:
version: 3.8
setup_py_install: true

install:

- requirements: docs/requirements.txt
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
:align: center
:width: 200px

.. image:: https://github.com/cog-imperial/OMLT/workflows/CI/badge.svg?branch=main
.. image:: https://github.com/cog-imperial/OMLT/actions/workflows/main.yml/badge.svg
:target: https://github.com/cog-imperial/OMLT/actions?workflow=CI
:alt: CI Status

Expand Down
11 changes: 0 additions & 11 deletions docs/environment.yml

This file was deleted.

10 changes: 10 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Required dependencies for Sphinx documentation
sphinx
sphinx-rtd-theme
numpy
pyomo
networkx
onnx
tensorflow
linear-tree
importlib-metadata
14 changes: 12 additions & 2 deletions src/omlt/gbt/gbt_formulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,13 +197,23 @@ def _branching_y(tree_id, branch_node_id):
node_mask = (nodes_tree_ids == tree_id) & (nodes_node_ids == branch_node_id)
feature_id = nodes_feature_ids[node_mask]
branch_value = nodes_values[node_mask]
assert len(feature_id) == 1 and len(branch_value) == 1
if len(branch_value) != 1:
raise ValueError(
f"The given tree_id and branch_node_id do not uniquely identify a branch value."
)
if len(feature_id) != 1:
raise ValueError(
f"The given tree_id and branch_node_id do not uniquely identify a feature."
)
feature_id = feature_id[0]
branch_value = branch_value[0]
(branch_y_idx,) = np.where(
branch_value_by_feature_id[feature_id] == branch_value
)
assert len(branch_y_idx) == 1
if len(branch_y_idx) != 1:
raise ValueError(
f"The given tree_id and branch_node_id do not uniquely identify a branch index."
)
return block.y[feature_id, branch_y_idx[0]]

def _sum_of_z_l(tree_id, start_node_id):
Expand Down
26 changes: 19 additions & 7 deletions src/omlt/gbt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,24 +56,36 @@ def scaling_object(self, scaling_object):
def _model_num_inputs(model):
"""Returns the number of input variables"""
graph = model.graph
assert len(graph.input) == 1
if len(graph.input) != 1:
raise ValueError(
f"Model graph input field is multi-valued {graph.input}. A single value is required."
)
return _tensor_size(graph.input[0])


def _model_num_outputs(model):
"""Returns the number of output variables"""
graph = model.graph
assert len(graph.output) == 1
if len(graph.output) != 1:
raise ValueError(
f"Model graph output field is multi-valued {graph.output}. A single value is required."
)
return _tensor_size(graph.output[0])


def _tensor_size(tensor):
"""Returns the size of an input tensor"""
tensor_type = tensor.type.tensor_type
size = None
for dim in tensor_type.shape.dim:
if dim.dim_value is not None and dim.dim_value > 0:
assert size is None
size = dim.dim_value
assert size is not None
dim_values = [
dim.dim_value
for dim in tensor_type.shape.dim
if dim.dim_value is not None and dim.dim_value > 0
]
if len(dim_values) == 1:
size = dim_values[0]
elif dim_values == []:
raise ValueError(f"Tensor {tensor} has no positive dimensions.")
else:
raise ValueError(f"Tensor {tensor} has multiple positive dimensions.")
return size
Loading

0 comments on commit f923482

Please sign in to comment.