diff --git a/macros/hooks/model_audit.sql b/macros/hooks/model_audit.sql index 80ba8e1..437b9c3 100644 --- a/macros/hooks/model_audit.sql +++ b/macros/hooks/model_audit.sql @@ -131,9 +131,24 @@ onnx: {} {% endmacro %} +{% macro _get_model_type(ml_config) %} + + {%- set ns = namespace(model_type=none) %} + {%- if ml_config %} + {%- for key, value in ml_config.items() %} + {%- if key.lower() == 'model_type' %} + {%- set ns.model_type = value | string | lower %} + {%- endif %} + {%- endfor %} + {%- endif %} + {% do return(ns.model_type) %} + +{% endmacro %} + {% macro model_audit() %} - {% set model_type = config.get('ml_config')['model_type'].lower() if config.get('ml_config')['model_type'] else None %} + {%- set ml_config = config.get('ml_config', {}) -%} + {% set model_type = dbt_ml._get_model_type(ml_config) %} {% set model_type_repr = model_type if model_type in dbt_ml._audit_insert_templates().keys() else 'default' %} {% set info_types = ['training_info', 'feature_info', 'weights', 'evaluate'] %}