Skip to content

Commit

Permalink
Merge pull request #1 from views-platform/ensemble_manager
Browse files Browse the repository at this point in the history
Ensemble manager
  • Loading branch information
xiaolong0728 authored Dec 10, 2024
2 parents b684356 + 57904ab commit 23370cf
Show file tree
Hide file tree
Showing 6 changed files with 712 additions and 213 deletions.
2 changes: 1 addition & 1 deletion tests/test_model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def test_save_model_outputs(mock_model_path):
df_output = pd.DataFrame({"col1": [5, 6], "col2": [7, 8]})
path_generated = Path("/path/to/generated")
with patch("builtins.open", new_callable=mock_open), patch("pathlib.Path.mkdir"):
manager._save_model_outputs(df_evaluation, df_output, path_generated)
manager._save_model_outputs(df_evaluation, df_output, path_generated, sequence_number=1)
with patch("pathlib.Path.exists", return_value=True):
assert Path(f"{path_generated}/output_1_test_run_20220101.pkl").exists()
assert Path(f"{path_generated}/evaluation_1_test_run_20220101.pkl").exists()
Expand Down
20 changes: 16 additions & 4 deletions views_pipeline_core/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def parse_args():
help="Enable drift-detection self_test at data-fetch"
)

parser.add_argument(
"-et", "--eval_type", type=str, default="standard",
help="Type of evaluation to be performed"
)

return parser.parse_args()


Expand Down Expand Up @@ -134,10 +139,17 @@ def validate_arguments(args):
)
sys.exit(1)

if not args.train and not args.saved:
# if not training, then we need to use saved data
if (not args.train and not args.sweep) and not args.saved:
# if not training or sweeping, then we need to use saved data
print(
"Error: if --train or --sweep is not set, you should only use --saved flag. Exiting."
)
print("To fix: Add --train or --sweep or --saved flag.")
sys.exit(1)

if args.eval_type not in ["standard", "long", "complete", "live"]:
print(
"Error: if --train is not set, you should only use --saved flag. Exiting."
"Error: --eval_type should be one of 'standard', 'long', 'complete', or 'live'. Exiting."
)
print("To fix: Add --train or --saved flag.")
print("To fix: Set --eval_type to one of the above options.")
sys.exit(1)
322 changes: 322 additions & 0 deletions views_pipeline_core/managers/ensemble_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,322 @@
from views_pipeline_core.managers.path_manager import ModelPath, EnsemblePath
from views_pipeline_core.managers.model_manager import ModelManager
from views_pipeline_core.wandb.utils import add_wandb_monthly_metrics, log_wandb_log_dict
from views_pipeline_core.models.check import ensemble_model_check
from views_pipeline_core.files.utils import read_log_file, create_log_file
from views_pipeline_core.models.outputs import generate_output_dict
from views_pipeline_core.evaluation.metrics import generate_metric_dict
from typing import Union, Optional, List, Dict
import wandb
import logging
import time
import pickle
from pathlib import Path
import subprocess
from datetime import datetime
import pandas as pd

logger = logging.getLogger(__name__)


class EnsembleManager(ModelManager):

def __init__(self, ensemble_path: EnsemblePath) -> None:
super().__init__(ensemble_path)

@staticmethod
def _get_shell_command(model_path: ModelPath,
run_type: str,
train: bool,
evaluate: bool,
forecast: bool,
use_saved: bool = False,
eval_type: str = "standard"
) -> list:
"""
Args:
model_path (ModelPath): model path object for the model
run_type (str): the type of run (calibration, testing, forecasting)
train (bool): if the model should be trained
evaluate (bool): if the model should be evaluated
forecast (bool): if the model should be used for forecasting
use_saved (bool): if the model should use locally stored data
Returns:
"""

shell_command = [f"{str(model_path.model_dir)}/run.sh"]
shell_command.append("--run_type")
shell_command.append(run_type)

if train:
shell_command.append("--train")
if evaluate:
shell_command.append("--evaluate")
if forecast:
shell_command.append("--forecast")
if use_saved:
shell_command.append("--saved")

shell_command.append("--eval_type")
shell_command.append(eval_type)

return shell_command

@staticmethod
def _get_aggregated_df(df_to_aggregate, aggregation):
"""
Aggregates the DataFrames of model outputs based on the specified aggregation method.
Args:
- df_to_aggregate (list of pd.DataFrame): A list of DataFrames of model outputs.
- aggregation (str): The aggregation method to use (either "mean" or "median").
Returns:
- df (pd.DataFrame): The aggregated DataFrame of model outputs.
"""

if aggregation == "mean":
return pd.concat(df_to_aggregate).groupby(level=[0, 1]).mean()
elif aggregation == "median":
return pd.concat(df_to_aggregate).groupby(level=[0, 1]).median()
else:
logger.error(f"Invalid aggregation: {aggregation}")

def execute_single_run(self, args) -> None:
"""
Executes a single run of the model, including data fetching, training, evaluation, and forecasting.
Args:
args: Command line arguments.
"""
self.config = self._update_single_config(args)
self._project = f"{self.config['name']}_{args.run_type}"
self._eval_type = args.eval_type

try:
if not args.train:
ensemble_model_check(self.config)

self._execute_model_tasks(
config=self.config,
train=args.train,
eval=args.evaluate,
forecast=args.forecast,
use_saved=args.saved
)

except Exception as e:
logger.error(f"Error during single run execution: {e}")

def _execute_model_tasks(
self,
config: Optional[Dict] = None,
train: Optional[bool] = None,
eval: Optional[bool] = None,
forecast: Optional[bool] = None,
use_saved: Optional[bool] = None
) -> None:
"""
Executes various model-related tasks including training, evaluation, and forecasting.
Args:
config (dict, optional): Configuration object containing parameters and settings.
train (bool, optional): Flag to indicate if the model should be trained.
eval (bool, optional): Flag to indicate if the model should be evaluated.
forecast (bool, optional): Flag to indicate if forecasting should be performed.
"""
start_t = time.time()
try:
with wandb.init(project=self._project, entity=self._entity, config=config):
add_wandb_monthly_metrics()
self.config = wandb.config

if train:
logger.info(f"Training model {self.config['name']}...")
self._train_ensemble(use_saved)

if eval:
logger.info(f"Evaluating model {self.config['name']}...")
self._evaluate_ensemble(self._eval_type)

if forecast:
logger.info(f"Forecasting model {self.config['name']}...")
self._forecast_ensemble()

wandb.finish()
except Exception as e:
logger.error(f"Error during model tasks execution: {e}")

end_t = time.time()
minutes = (end_t - start_t) / 60
logger.info(f"Done. Runtime: {minutes:.3f} minutes.\n")

def _train_model_artifact(self, model_name:str, run_type: str, use_saved: bool) -> None:
logger.info(f"Training single model {model_name}...")

model_path = ModelPath(model_name)
model_config = ModelManager(model_path).configs
model_config["run_type"] = run_type

shell_command = EnsembleManager._get_shell_command(model_path,
run_type,
train=True,
evaluate=False,
forecast=False,
use_saved=use_saved)

# print(shell_command)
try:
subprocess.run(shell_command, check=True)
except Exception as e:
logger.error(f"Error during shell command execution for model {model_name}: {e}")

def _evaluate_model_artifact(self, model_name:str, run_type: str, eval_type: str) -> None:
logger.info(f"Evaluating single model {model_name}...")

model_path = ModelPath(model_name)
path_raw = model_path.data_raw
path_generated = model_path.data_generated
path_artifacts = model_path.artifacts
path_artifact = self._get_latest_model_artifact(path_artifacts, run_type)

ts = path_artifact.stem[-15:]

preds = []

for sequence_number in range(ModelManager._resolve_evaluation_sequence_number(eval_type)):

pkl_path = f"{path_generated}/predictions_{run_type}_{ts}_{str(sequence_number).zfill(2)}.pkl"
if Path(pkl_path).exists():
logger.info(f"Loading existing {run_type} predictions from {pkl_path}")
with open(pkl_path, "rb") as file:
pred = pickle.load(file)
else:
logger.info(f"No existing {run_type} predictions found. Generating new {run_type} predictions...")
model_config = ModelManager(model_path).configs
model_config["run_type"] = run_type
shell_command = EnsembleManager._get_shell_command(model_path,
run_type,
train=False,
evaluate=True,
forecast=False,
use_saved=True,
eval_type=eval_type)

try:
subprocess.run(shell_command, check=True)
except Exception as e:
logger.error(f"Error during shell command execution for model {model_name}: {e}")

with open(pkl_path, "rb") as file:
pred = pickle.load(file)

data_generation_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
date_fetch_timestamp = read_log_file(path_raw / f"{run_type}_data_fetch_log.txt").get("Data Fetch Timestamp", None)

create_log_file(path_generated, model_config, ts, data_generation_timestamp, date_fetch_timestamp)

preds.append(pred)

return preds

def _forecast_model_artifact(self, model_name:str, run_type: str) -> None:
logger.info(f"Forecasting single model {model_name}...")

model_path = ModelPath(model_name)
path_raw = model_path.data_raw
path_generated = model_path.data_generated
path_artifacts = model_path.artifacts
path_artifact = self._get_latest_model_artifact(path_artifacts, run_type)

ts = path_artifact.stem[-15:]

pkl_path = f"{path_generated}/predictions_{run_type}_{ts}.pkl"
if Path(pkl_path).exists():
logger.info(f"Loading existing {run_type} predictions from {pkl_path}")
with open(pkl_path, "rb") as file:
df = pickle.load(file)
else:
logger.info(f"No existing {run_type} predictions found. Generating new {run_type} predictions...")
model_config = ModelManager(model_path).configs
model_config["run_type"] = run_type
shell_command = EnsembleManager._get_shell_command(model_path,
run_type,
train=False,
evaluate=False,
forecast=True,
use_saved=True)
# print(shell_command)
try:
subprocess.run(shell_command, check=True)
except Exception as e:
logger.error(f"Error during shell command execution for model {model_name}: {e}")

with open(pkl_path, "rb") as file:
df = pickle.load(file)

data_generation_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
date_fetch_timestamp = read_log_file(path_raw / f"{run_type}_data_fetch_log.txt").get("Data Fetch Timestamp", None)

create_log_file(path_generated, model_config, ts, data_generation_timestamp, date_fetch_timestamp)

return df

def _train_ensemble(self, use_saved: bool) -> None:
run_type = self.config["run_type"]

for model_name in self.config["models"]:
self._train_model_artifact(model_name, run_type, use_saved)

def _evaluate_ensemble(self, eval_type: str) -> None:
path_generated_e = self._model_path.data_generated
run_type = self.config["run_type"]
dfs = []
dfs_agg = []

for model_name in self.config["models"]:
dfs.append(self._evaluate_model_artifact(model_name, run_type, eval_type))

for i in range(len(dfs[0])):
df_to_aggregate = [df[i] for df in dfs]
df_agg = EnsembleManager._get_aggregated_df(df_to_aggregate, self.config["aggregation"])
dfs_agg.append(df_agg)

data_generation_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# _, df_output = generate_output_dict(df_agg, self.config)
# evaluation, df_evaluation = generate_metric_dict(df_agg, self.config)
# log_wandb_log_dict(self.config, evaluation)

# Timestamp of single models is more important than ensemble model timestamp
self.config["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
# self._save_model_outputs(df_evaluation, df_output, path_generated_e)
for i, df_agg in enumerate(dfs_agg):
self._save_predictions(df_agg, path_generated_e, i)

# How to define an ensemble model timestamp? Currently set as data_generation_timestamp.
create_log_file(path_generated_e, self.config, data_generation_timestamp, data_generation_timestamp, data_fetch_timestamp=None,
model_type="ensemble", models=self.config["models"])

def _forecast_ensemble(self) -> None:
path_generated_e = self._model_path.data_generated
run_type = self.config["run_type"]
dfs = []

for model_name in self.config["models"]:

dfs.append(self._forecast_model_artifact(model_name, run_type))

df_prediction = EnsembleManager._get_aggregated_df(dfs, self.config["aggregation"])
data_generation_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

self.config["timestamp"] = datetime.now().strftime("%Y%m%d_%H%M%S")
self._save_predictions(df_prediction, path_generated_e)

# How to define an ensemble model timestamp? Currently set as data_generation_timestamp.
create_log_file(path_generated_e, self.config, data_generation_timestamp, data_generation_timestamp, data_fetch_timestamp=None,
model_type="ensemble", models=self.config["models"])


Loading

0 comments on commit 23370cf

Please sign in to comment.