From 669c7eee1b0e21824e794ae233fffc8684e5cf38 Mon Sep 17 00:00:00 2001
From: Romain Picard
Date: Mon, 4 Mar 2024 18:29:21 +0100
Subject: [PATCH] fix predictions with boolean features (#12)

fixes #11

ONNX converts booleans to the strings 0 and 1, while the EBM Python
implementation expects False/True. To fix this, we replace the bin keys
with the values returned by ONNX.

Signed-off-by: Romain Picard
---
 ebm2onnx/convert.py     |  8 +++++
 ebm2onnx/operators.py   |  2 +-
 tests/test_convert.py   |  2 +-
 tests/test_operators.py | 66 +++++++++++++++++++++++++++--------------
 tests/utils.py          | 14 +++++++--
 5 files changed, 66 insertions(+), 26 deletions(-)

diff --git a/ebm2onnx/convert.py b/ebm2onnx/convert.py
index b92a58e..011d204 100644
--- a/ebm2onnx/convert.py
+++ b/ebm2onnx/convert.py
@@ -123,6 +123,14 @@ def to_onnx(model, dtype, name="ebm",
             feature_dtype = infer_features_dtype(dtype, feature_name)
             part = graph.create_input(root, feature_name, feature_dtype, [None])
 
+            if feature_dtype == onnx.TensorProto.BOOL:
+                # ONNX converts booleans to strings 0/1, not False/True
+                col_mapping = {
+                    '0': col_mapping['False'],
+                    '1': col_mapping['True'],
+                }
+                # replace inplace to re-use it in interactions
+                model.bins_[feature_group[0]][0] = col_mapping
             if feature_dtype != onnx.TensorProto.STRING:
                 part = ops.cast(onnx.TensorProto.STRING)(part)
             part = ops.flatten()(part)
diff --git a/ebm2onnx/operators.py b/ebm2onnx/operators.py
index c545676..7c1f431 100644
--- a/ebm2onnx/operators.py
+++ b/ebm2onnx/operators.py
@@ -41,7 +41,7 @@ def _argmax(g):
 
 
 def cast(to):
-    def _cast(g):
+    def _cast(g):
         cast_result_name = g.generate_name('cast_result')
         nodes = [
             onnx.helper.make_node("Cast", [g.transients[0].name], [cast_result_name], to=to),
diff --git a/tests/test_convert.py b/tests/test_convert.py
index 00060f5..0f790dd 100644
--- a/tests/test_convert.py
+++ b/tests/test_convert.py
@@ -27,7 +27,7 @@ def train_titanic_binary_classification(interactions=0, with_categorical=False):
         feature_types=['continuous', 'continuous', 'continuous', 'continuous']
         feature_columns = ['Age', 'Fare', 'Pclass', 'Old']
     else:
-        feature_types=['continuous', 'continuous', 'nominal', 'continuous', 'nominal']
+        feature_types=['continuous', 'continuous', 'nominal', 'nominal', 'nominal']
         feature_columns = ['Age', 'Fare', 'Pclass', 'Old', 'Embarked']
 
     label_column = "Survived"
diff --git a/tests/test_operators.py b/tests/test_operators.py
index c22d911..f7f1657 100644
--- a/tests/test_operators.py
+++ b/tests/test_operators.py
@@ -1,3 +1,4 @@
+import pytest
 import ebm2onnx.graph as graph
 import ebm2onnx.operators as ops
 
@@ -24,29 +25,50 @@ def test_add():
     )
 
 
-def test_cast():
+@pytest.mark.parametrize(
+    "from_type,to_type,input,output",
+    [
+        pytest.param(
+            onnx.TensorProto.INT64,
+            onnx.TensorProto.FLOAT,
+            {'i': [[1], [2], [11], [4]]},
+            [[[1.0], [2.0], [11.0], [4.0]]],
+            id='int64_to_float'
+        ),
+        pytest.param(
+            onnx.TensorProto.INT64,
+            onnx.TensorProto.STRING,
+            {'i': [[1], [2], [11], [4]]},
+            [[["1"], ["2"], ["11"], ["4"]]],
+            id='int64_to_string'
+        ),
+        pytest.param(
+            onnx.TensorProto.BOOL,
+            onnx.TensorProto.UINT8,
+            {'i': [[False], [True]]},
+            [[[0], [1]]],
+            id='bool_to_uint8'
+        ),
+        pytest.param(
+            onnx.TensorProto.BOOL,
+            onnx.TensorProto.STRING,
+            {'i': [[False], [True]]},
+            [[["0"], ["1"]]],
+            id='bool_to_string'
+        ),
+    ]
+)
+def test_cast(from_type, to_type, input, output):
     g = graph.create_graph()
-
-    i = graph.create_input(g, "i", onnx.TensorProto.INT64, [None, 1])
-
-    l = ops.cast(onnx.TensorProto.FLOAT)(i)
-    l = graph.add_output(l, l.transients[0].name, onnx.TensorProto.FLOAT, [None, 1])
-
-    assert_model_result(l,
-        input={
-            'i': [
-                [1],
-                [2],
-                [11],
-                [4],
-            ]
-        },
-        expected_result=[[
-            [1.0],
-            [2.0],
-            [11.0],
-            [4.0]
-        ]]
+    i = graph.create_input(g, "i", from_type, [None, 1])
+    l = ops.cast(to_type)(i)
+    l = graph.add_output(l, l.transients[0].name, to_type, [None, 1])
+
+    assert_model_result(
+        l,
+        input=input,
+        expected_result=output,
+        exact_match=to_type in [onnx.TensorProto.INT64, onnx.TensorProto.STRING]
     )
 
 
diff --git a/tests/utils.py b/tests/utils.py
index 01835cf..de34add 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -33,7 +33,13 @@ def infer_model(model, input):
         os.unlink(filename)
 
 
-def assert_model_result(g, input, expected_result, atol=1e-08, save_path=None):
+def assert_model_result(
+    g, input,
+    expected_result,
+    exact_match=False,
+    atol=1e-08,
+    save_path=None
+):
     model = graph.compile(g, target_opset=13)
     _, filename = tempfile.mkstemp()
     try:
@@ -45,8 +51,12 @@
         pred = sess.run(None, input)
 
         print(pred)
+        print(expected_result)
         for i, p in enumerate(pred):
-            assert np.allclose(p, np.array(expected_result[i]))
+            if exact_match:
+                assert p.tolist() == expected_result[i]
+            else:
+                assert np.allclose(p, np.array(expected_result[i]))
 
     finally:
         os.unlink(filename)
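
Reviewer note (not part of the patch): the fix relies on the ONNX Cast operator turning BOOL tensors into the strings "0"/"1" rather than "False"/"True", which is what the bool_to_string test case above asserts through ops.cast. The standalone sketch below reproduces that behavior with plain onnx/onnxruntime, independently of ebm2onnx; the graph name, tensor names, and opset choice are illustrative assumptions, not code from this repository.

import numpy as np
import onnx
import onnxruntime as rt
from onnx import TensorProto, helper

# Single-node graph: Cast a boolean input tensor to a string tensor.
node = helper.make_node("Cast", ["i"], ["o"], to=TensorProto.STRING)
demo_graph = helper.make_graph(
    [node],
    "bool_cast_demo",
    [helper.make_tensor_value_info("i", TensorProto.BOOL, [None, 1])],
    [helper.make_tensor_value_info("o", TensorProto.STRING, [None, 1])],
)
demo_model = helper.make_model(demo_graph, opset_imports=[helper.make_opsetid("", 13)])
onnx.checker.check_model(demo_model)

sess = rt.InferenceSession(demo_model.SerializeToString(), providers=["CPUExecutionProvider"])
out = sess.run(None, {"i": np.array([[False], [True]])})
print(out[0])  # expected: the strings "0" and "1", not "False"/"True"

This is why convert.py remaps the 'False'/'True' bin keys to '0'/'1' before building the category lookup: the values coming out of the Cast node are the string digits, so the bin mapping has to be keyed on them.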