diff --git a/tests/unit-tests/gconstruct/test_transform.py b/tests/unit-tests/gconstruct/test_transform.py index 54bed2d382..83cd7a16ae 100644 --- a/tests/unit-tests/gconstruct/test_transform.py +++ b/tests/unit-tests/gconstruct/test_transform.py @@ -722,6 +722,30 @@ def test_bucket_transform(out_dtype): feats_tar = np.array([[1, 1], [1, 1], [1, 1], [1, 1]], dtype=out_dtype) assert_equal(bucket_feats['test'], feats_tar) + +def test_multicolumn(): + # Just get the features without transformation. + feat_op1 = [{ + "feature_col": ["test1", "test2"], + "feature_name": "test3", + }] + (res, _, _) = parse_feat_ops(feat_op1) + assert len(res) == 1 + assert res[0].col_name == feat_op1[0]["feature_col"] + assert res[0].feat_name == feat_op1[0]["feature_name"] + assert isinstance(res[0], Noop) + + data = { + "test1": np.random.rand(4, 2), + "test2": np.random.rand(4, 2) + } + data["test3"] = np.column_stack((data['test1'], data['test2'])) + proc_res = process_features(data, res) + assert "test3" in proc_res + assert proc_res["test3"].dtype == np.float32 + np.testing.assert_allclose(proc_res["test3"], data["test3"]) + + if __name__ == '__main__': test_categorize_transform() test_get_output_dtype() @@ -745,3 +769,4 @@ def test_bucket_transform(out_dtype): test_collect_label_stats() test_custom_label_processor() test_classification_processor() + test_multicolumn() \ No newline at end of file