Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add input and output validation #104

Open
wants to merge 34 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
dae5794
Refactor variable definitions to support validation
t-bz Dec 6, 2024
24655a2
Update variable tests
t-bz Dec 6, 2024
768a440
Fix utils tests
t-bz Dec 6, 2024
3e80ceb
Fix base tests
t-bz Dec 6, 2024
370362d
Fix torch_model tests
t-bz Dec 6, 2024
708b9e5
Fix torch_module tests
t-bz Dec 6, 2024
b817d44
Add validation to models
t-bz Dec 6, 2024
47ea12a
Add validation tests
t-bz Dec 6, 2024
c9dd02c
Clean up conftest
t-bz Dec 6, 2024
6fbd75a
Remove SerializeAsAny annotation since variable serialization is expl…
t-bz Dec 6, 2024
297e2da
make default value required
pluflou Dec 10, 2024
50bd4d3
add single and double precision support in variable class
pluflou Dec 10, 2024
892de33
add support for numpy floats
pluflou Dec 10, 2024
6990180
add validation for input_dict and support for precision setting in mo…
pluflou Dec 12, 2024
0f786d0
make input dict validation strict
pluflou Dec 12, 2024
2c0eaab
catch bools in torch tensor inputs
pluflou Dec 12, 2024
eb8425d
drop np.float32 until we have a use-case
pluflou Dec 12, 2024
dc08f80
add dynamic checking for default vals, and strict flag for range chec…
pluflou Dec 12, 2024
629b77d
make type casting more consistent during input validation
pluflou Dec 14, 2024
bd322f7
make default required for inputs only and validate in base class
pluflou Dec 14, 2024
9552dd9
fix range validation tests
pluflou Dec 16, 2024
d2c2da6
remove range check within tolerance for now
pluflou Dec 16, 2024
fa508eb
add is_constant flag and default range, fix unit tests
pluflou Dec 16, 2024
b6aaddf
update example nbs
pluflou Dec 16, 2024
cbb67e1
update example notebooks
pluflou Dec 18, 2024
e409459
add nicer onnx graphs
pluflou Dec 18, 2024
e6c7b6a
simplify validation config
pluflou Dec 18, 2024
45dfde8
Merge branch 'main' of https://github.com/slaclab/lume-model into val…
pluflou Dec 18, 2024
cba3073
fix tests after adjusting config validation format
pluflou Dec 19, 2024
2de7dd2
remove setting torch default dtype
pluflou Dec 19, 2024
6c83403
add some tests
pluflou Dec 20, 2024
b10521a
adjust docstrings
pluflou Dec 20, 2024
6bb2e07
update README
pluflou Dec 20, 2024
f34926f
reset precision to double to fix tests
pluflou Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 131 additions & 13 deletions examples/custom_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"outputs": [],
"source": [
"from lume_model.base import LUMEBaseModel\n",
"from lume_model.variables import ScalarInputVariable, ScalarOutputVariable"
"from lume_model.variables import ScalarVariable"
]
},
{
Expand All @@ -37,7 +37,7 @@
"outputs": [],
"source": [
"class ExampleModel(LUMEBaseModel):\n",
" def evaluate(self, input_dict):\n",
" def _evaluate(self, input_dict):\n",
" output_dict = {\n",
" \"output1\": input_dict[self.input_variables[0].name] ** 2,\n",
" \"output2\": input_dict[self.input_variables[1].name] ** 2,\n",
Expand All @@ -56,26 +56,26 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"id": "97946e64-062d-47d4-8d0c-d7e02a335a56",
"metadata": {},
"outputs": [],
"source": [
"input_variables = [\n",
" ScalarInputVariable(name=\"input1\", default=0.1, value_range=[0.0, 1.0]),\n",
" ScalarInputVariable(name=\"input2\", default=0.2, value_range=[0.0, 1.0]),\n",
" ScalarVariable(name=\"input1\", default_value=0.1),\n",
" ScalarVariable(name=\"input2\", default_value=0.2, value_range=[0.0, 1.0]),\n",
"]\n",
"output_variables = [\n",
" ScalarOutputVariable(name=\"output1\"),\n",
" ScalarOutputVariable(name=\"output2\"),\n",
" ScalarVariable(name=\"output1\"),\n",
" ScalarVariable(name=\"output2\"),\n",
"]\n",
"\n",
"m = ExampleModel(input_variables=input_variables, output_variables=output_variables)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"id": "50aae4be-0d6e-456f-83e8-3a84d6d78f84",
"metadata": {},
"outputs": [
Expand All @@ -85,7 +85,89 @@
"{'output1': 0.09, 'output2': 0.36}"
]
},
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"input_dict = {\n",
" \"input1\": 0.3,\n",
" \"input2\": 0.6,\n",
"}\n",
"m.evaluate(input_dict)"
]
},
{
"cell_type": "markdown",
"id": "ac3f47b9-b316-4baa-b747-dbc57ab921e5",
"metadata": {},
"source": [
"# Save to YAML"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "da644ee0-9bac-4343-81bf-41b2d7571283",
"metadata": {},
"outputs": [],
"source": [
"m.dump(\"example_model.yml\")"
]
},
{
"cell_type": "markdown",
"id": "01e95117-a2f1-4cd6-93ed-70946ee8e8a4",
"metadata": {},
"source": [
"# Load from YAML"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6b9eb663-f8b1-449e-90cc-ebde0f311d02",
"metadata": {},
"outputs": [],
"source": [
"m = ExampleModel(\"example_model.yml\")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "d49d8368-4b68-4bfb-868d-a61164d65724",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ExampleModel(input_variables=[ScalarVariable(name='input1', default_value=0.1, value_range=(-inf, inf), is_constant=False, unit=None), ScalarVariable(name='input2', default_value=0.2, value_range=(0.0, 1.0), is_constant=False, unit=None)], output_variables=[ScalarVariable(name='output1', default_value=None, value_range=(-inf, inf), is_constant=False, unit=None), ScalarVariable(name='output2', default_value=None, value_range=(-inf, inf), is_constant=False, unit=None)], input_validation_config=None, output_validation_config=None)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "92f1e6e1-53c6-4195-8dd7-00163ec13e73",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'output1': 0.09, 'output2': 0.36}"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -98,20 +180,56 @@
"m.evaluate(input_dict)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "33d97167-6a84-42f7-a40a-1a0169dafc5d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\"model_class\": \"ExampleModel\", \"input_variables\": {\"input1\": {\"variable_class\": \"ScalarVariable\", \"default_value\": 0.1, \"value_range\": [-Infinity, Infinity], \"is_constant\": false}, \"input2\": {\"variable_class\": \"ScalarVariable\", \"default_value\": 0.2, \"value_range\": [0.0, 1.0], \"is_constant\": false}}, \"output_variables\": {\"output1\": {\"variable_class\": \"ScalarVariable\", \"value_range\": [-Infinity, Infinity], \"is_constant\": false}, \"output2\": {\"variable_class\": \"ScalarVariable\", \"value_range\": [-Infinity, Infinity], \"is_constant\": false}}, \"input_validation_config\": null, \"output_validation_config\": null}\n"
]
}
],
"source": [
"print(m.json())"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "b07f8b91-d4a2-430f-ae24-6aea581fad96",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'model_class': 'ExampleModel', 'input_variables': [{'variable_class': 'ScalarVariable', 'name': 'input1', 'default_value': 0.1, 'value_range': (-inf, inf), 'is_constant': False, 'unit': None}, {'variable_class': 'ScalarVariable', 'name': 'input2', 'default_value': 0.2, 'value_range': (0.0, 1.0), 'is_constant': False, 'unit': None}], 'output_variables': [{'variable_class': 'ScalarVariable', 'name': 'output1', 'default_value': None, 'value_range': (-inf, inf), 'is_constant': False, 'unit': None}, {'variable_class': 'ScalarVariable', 'name': 'output2', 'default_value': None, 'value_range': (-inf, inf), 'is_constant': False, 'unit': None}], 'input_validation_config': None, 'output_validation_config': None}\n"
]
}
],
"source": [
"print(m.dict())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a547f3c-1706-4b32-bab6-9687627f6a78",
"id": "a35c6fe3-d61b-42ed-b835-f79669ce1508",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:lume-model-dev]",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "conda-env-lume-model-dev-py"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -123,7 +241,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.20"
}
},
"nbformat": 4,
Expand Down
Loading
Loading