From 92eebb0d82c58b0941e65d04c5bc8c693f25e796 Mon Sep 17 00:00:00 2001 From: Romain Picard Date: Tue, 5 Mar 2024 09:16:39 +0100 Subject: [PATCH] fix conversion error when a boolean feature has only one value --- ebm2onnx/convert.py | 8 ++++++-- tests/test_convert.py | 8 +++++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/ebm2onnx/convert.py b/ebm2onnx/convert.py index 011d204..8ad7b8d 100644 --- a/ebm2onnx/convert.py +++ b/ebm2onnx/convert.py @@ -18,6 +18,10 @@ 'str': onnx.TensorProto.STRING, } +bool_remap = { + 'False': '0', + 'True': '1', +} def infer_features_dtype(dtype, feature_name): feature_dtype = onnx.TensorProto.DOUBLE @@ -126,8 +130,8 @@ def to_onnx(model, dtype, name="ebm", if feature_dtype == onnx.TensorProto.BOOL: # ONNX converts booleans to strings 0/1, not False/True col_mapping = { - '0': col_mapping['False'], - '1': col_mapping['True'], + bool_remap[k]: v + for k, v in col_mapping.items() } # replace inplace to re-use it in interactions model.bins_[feature_group[0]][0] = col_mapping diff --git a/tests/test_convert.py b/tests/test_convert.py index 0f790dd..55ffc59 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -12,7 +12,7 @@ from .utils import infer_model, create_session -def train_titanic_binary_classification(interactions=0, with_categorical=False): +def train_titanic_binary_classification(interactions=0, with_categorical=False, old_th=65): df = pd.read_csv( os.path.join('examples','titanic_train.csv'), #dtype= { @@ -22,7 +22,7 @@ def train_titanic_binary_classification(interactions=0, with_categorical=False): #} ) df = df.dropna() - df['Old'] = df['Age'] > 65 + df['Old'] = df['Age'] > old_th if with_categorical is False: feature_types=['continuous', 'continuous', 'continuous', 'continuous'] feature_columns = ['Age', 'Fare', 'Pclass', 'Old'] @@ -168,10 +168,12 @@ def test_predict_regression_without_interactions(interactions, explain): @pytest.mark.parametrize("explain", [False, True]) @pytest.mark.parametrize("interactions", [0, 2, [(0, 1, 2)], [(0, 1, 2, 3)]]) -def test_predict_binary_classification_with_categorical(interactions, explain): +@pytest.mark.parametrize("old_th", [65, 0]) +def test_predict_binary_classification_with_categorical(interactions, explain, old_th): model_ebm, x_test, y_test = train_titanic_binary_classification( interactions=interactions, with_categorical=True, + old_th=old_th, ) pred_ebm = model_ebm.predict(x_test)