-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
73 lines (64 loc) · 2.59 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''
Este módulo es un script que se usa para
hacer inferencia con un modelo de machine learning
previamente entrenado y que fue guardado en la carpeta ./models
La salida de este script es un archivo .csv con las predicciones
'''
# Importar librerías
import os
import logging
from datetime import datetime
import argparse
from src.scripts_inference import inference, get_user_input
if not os.path.exists("logs/"):
os.makedirs("logs/")
# Setup Logging
now = datetime.now()
date_time = now.strftime("%Y%m%d_%H%M%S")
log_inference_file_name = f"logs/{date_time}_inference.log"
logging.basicConfig(
filename=log_inference_file_name,
level=logging.DEBUG,
filemode='w',
format='%(name)s - %(levelname)s - %(message)s')
logging.info("Inferencia iniciada ...")
def main(command_line_args):
'''
Función principal que ejecuta la inferencia
'''
logging.info("Cargando el modelo ...")
# Directorios de entrada y salida
output_pred = command_line_args.output_path
logging.debug("Ruta del modelo: %s", command_line_args.model_path)
logging.info("El modelo fue cargado exitosamente")
if not os.path.exists(output_pred):
os.makedirs(output_pred)
output_file = os.path.join(output_pred, "predictions.csv")
logging.debug("Ruta de salida para las predicciones: %s", output_file)
logging.info("La predicción ya fue guardada en ./data/predictions")
# Definir las características necesarias para la predicción
feature_columns = ['Id', 'OverallQual',
'GrLivArea', 'FullBath',
'YearBuilt', 'GarageCars',
'GarageArea', 'ExterQual',
'BsmtQual']
logging.debug("Características para la predicción: %s",
', '.join(feature_columns))
# Solicitar entrada del usuario
user_input = get_user_input(feature_columns)
logging.debug("Entrada del usuario: %s", str(user_input))
logging.info("Entrada del usuario obtenida")
# Se ejecuta la inferencia
inference(output_file, user_input)
logging.info("Inferencia finalizada")
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description='Script para hacer inferencia con un modelo de ML')
parser.add_argument('--model_path', type=str,
default='./models/rfr_model.joblib',
help='Ruta del modelo entrenado')
parser.add_argument('--output_path', type=str,
default='./data/predictions',
help='Ruta de salida para las predicciones')
args = parser.parse_args()
main(args)