diff --git a/apps/anomalist/__init__.py b/apps/anomalist/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/apps/anomalist/anomalist_api.py b/apps/anomalist/anomalist_api.py new file mode 100644 index 00000000000..47194c29015 --- /dev/null +++ b/apps/anomalist/anomalist_api.py @@ -0,0 +1,452 @@ +import tempfile +import time +from pathlib import Path +from typing import List, Literal, Optional, Tuple, cast, get_args + +import numpy as np +import pandas as pd +import structlog +from owid.catalog import find +from sqlalchemy.engine import Engine +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import Session + +from apps.anomalist.detectors import ( + AnomalyDetector, + AnomalyTimeChange, + AnomalyUpgradeChange, + AnomalyUpgradeMissing, + get_long_format_score_df, +) +from apps.anomalist.gp_detector import AnomalyGaussianProcessOutlier +from apps.wizard.utils.paths import WIZARD_ANOMALIES_RELATIVE +from etl import grapher_model as gm +from etl.config import OWID_ENV +from etl.db import get_engine, read_sql +from etl.files import create_folder, upload_file_to_server +from etl.grapher_io import variable_data_df_from_s3 + +log = structlog.get_logger() + +# Name of index columns for dataframe. +INDEX_COLUMNS = ["entity_name", "year"] + +# TODO: this is repeated in detector classes, is there a way to DRY this? +ANOMALY_TYPE = Literal["time_change", "upgrade_change", "upgrade_missing", "gp_outlier"] + +# Define mapping of available anomaly types to anomaly detectors. +ANOMALY_DETECTORS = { + detector.anomaly_type: detector + for detector in [ + AnomalyTimeChange, + AnomalyUpgradeChange, + AnomalyUpgradeMissing, + AnomalyGaussianProcessOutlier, + ] +} + + +def load_detector(anomaly_type: ANOMALY_TYPE) -> AnomalyDetector: + """Load detector.""" + return ANOMALY_DETECTORS[anomaly_type] + + +def load_latest_population(): + # NOTE: The "channels" parameter of the find function is not working well. + candidates = find("population", channels=("grapher",), dataset="population", namespace="demography").sort_values( + "version", ascending=False + ) + population = ( + candidates[(candidates["table"] == "population") & (candidates["channel"] == "grapher")] + .iloc[0] + .load() + .reset_index()[["country", "year", "population"]] + ).rename(columns={"country": "entity_name"}, errors="raise") + + return population + + +def get_variables_views_in_charts( + variable_ids: List[int], +) -> pd.DataFrame: + # Assumed base url for all charts. + base_url = "https://ourworldindata.org/grapher/" + + # SQL query to join variables, charts, and analytics pageviews data + query = f"""\ + SELECT + v.id AS variable_id, + c.id AS chart_id, + cc.slug AS chart_slug, + ap.views_7d, + ap.views_14d, + ap.views_365d + FROM + charts c + JOIN + chart_dimensions cd ON c.id = cd.chartId + JOIN + variables v ON cd.variableId = v.id + JOIN + chart_configs cc ON c.configId = cc.id + LEFT JOIN + analytics_pageviews ap ON ap.url = CONCAT('{base_url}', cc.slug) + WHERE + v.id IN ({', '.join([str(v_id) for v_id in variable_ids])}) + ORDER BY + v.id ASC; + """ + df = read_sql(query) + # Handle potential duplicated rows + df = df.drop_duplicates().reset_index(drop=True) + + if len(df) == 0: + df = pd.DataFrame(columns=["variable_id", "chart_id", "chart_slug", "views_7d", "views_14d", "views_365d"]) + + return df + + +def add_population_score(df_reduced: pd.DataFrame) -> pd.DataFrame: + # To normalize the analytics score to the range 0, 1, divide by an absolute maximum number of people. 
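+    # For example, with this cap a country of 1 million people gets a population score of log(1e6) / log(1e10) = 0.6, and one of 100 million gets 0.8.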
+ # NOTE: This should a safe assumption before ~2060. + absolute_maximum_population = 1e10 + # Value to use to fill missing values in the population score (e.g. for regions like "Middle East" that are not included in our population dataset). + fillna_value = 0.5 + + # Load the latest population data. + df_population = load_latest_population() + error = f"Expected a maximum population below {absolute_maximum_population}." + assert df_population[df_population["year"] < 2040]["population"].max() < absolute_maximum_population, error + + # First, get the unique combinations of country-years in the scores dataframe, and add population to it. + df_score_population = ( + df_reduced[["entity_name", "year"]] # type: ignore + .drop_duplicates() + .merge(df_population, on=["entity_name", "year"], how="left") + ) + + # To normalize the population score to the range 0, 1, divide by an absolute maximum population. + # To have more convenient numbers, take the natural logarithm of the population. + df_score_population["score_population"] = np.log(df_score_population["population"]) / np.log( + absolute_maximum_population + ) + + # Add the population score to the scores dataframe. + df_reduced = df_reduced.merge(df_score_population, on=["entity_name", "year"], how="left").drop( + columns="population", errors="raise" + ) + + # Variables that do not have population data will have a population score nan. Fill them with a low value. + df_reduced["score_population"] = df_reduced["score_population"].fillna(fillna_value) + + return df_reduced + + +def add_analytics_score(df_reduced: pd.DataFrame) -> pd.DataFrame: + # Focus on the following specific analytics column. + analytics_column = "views_14d" + # To normalize the analytics score to the range 0, 1, divide by an absolute maximum number of views. + absolute_maximum_views = 1e6 + # Value to use to fill missing values in the analytics score (e.g. for variables that are not used in charts). + fillna_value = 0.1 + + # Get number of views in charts for each variable id. + df_views = get_variables_views_in_charts(list(df_reduced["variable_id"].unique())) + # Sanity check. + if not df_views.empty: + error = f"Expected a maximum number of views below {absolute_maximum_views}. Change this limit." + assert df_views[analytics_column].max() < absolute_maximum_views, error + + # Get the sum of the number of views in charts for each variable id in the last 14 days. + # So, if an indicator is used in multiple charts, their views are summed. + # This rewards indicators that are used multiple times, and on popular charts. + # NOTE: The analytics table often contains nans. For now, for convenience, fill them with 1.1 views (to avoid zeros when calculating the log). + df_score_analytics = ( + df_views.groupby("variable_id") + .agg({analytics_column: "sum"}) + .reset_index() + .rename(columns={analytics_column: "views"}) + ) + # To have more convenient numbers, take the natural logarithm of the views. + df_score_analytics["score_analytics"] = np.log(df_score_analytics["views"]) / np.log(absolute_maximum_views) + + # Add the analytics score to the scores dataframe. + df_reduced = df_reduced.merge(df_score_analytics, on=["variable_id"], how="left") + + # Variables that do not have charts will have an analytics score nan. + # Fill them with a low value (e.g. 0.1) to avoid zeros when calculating the final score. 
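+    # For example, an indicator whose charts add up to 10,000 views gets log(1e4) / log(1e6) ≈ 0.67,
+    # while 100 views give ≈ 0.33, so this fallback ranks chart-less indicators below almost any charted one.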
+ df_reduced["score_analytics"] = df_reduced["score_analytics"].fillna(fillna_value) + + return df_reduced + + +def add_weighted_score(df: pd.DataFrame) -> pd.DataFrame: + """Add a weighted combined score.""" + w_score = 1 + w_pop = 1 + w_views = 1 + df["score_weighted"] = ( + w_score * df["score"] + w_pop * df["score_population"] + w_views * df["score_analytics"] + ) / (w_score + w_pop + w_views) + + return df + + +def add_auxiliary_scores(df: pd.DataFrame) -> pd.DataFrame: + # Add a population score. + df = add_population_score(df_reduced=df) + + # Add an analytics score. + df = add_analytics_score(df_reduced=df) + + # Rename columns for convenience. + df = df.rename(columns={"variable_id": "indicator_id", "anomaly_score": "score"}, errors="raise") + + # Create a weighted combined score. + df = add_weighted_score(df) + + return df + + +def anomaly_detection( + anomaly_types: Optional[Tuple[str, ...]] = None, + variable_mapping: Optional[dict[int, int]] = None, + variable_ids: Optional[list[int]] = None, + dry_run: bool = False, + force: bool = False, + reset_db: bool = False, +) -> None: + """Detect anomalies.""" + engine = get_engine() + + # Ensure the 'anomalies' table exists. Optionally reset it if reset_db is True. + gm.Anomaly.create_table(engine, if_exists="replace" if reset_db else "skip") + + # If no anomaly types are provided, default to all available types + if not anomaly_types: + anomaly_types = get_args(ANOMALY_TYPE) + + # Parse the variable_mapping if any provided. + if not variable_mapping: + variable_mapping = dict() + + if variable_ids is None: + variable_ids = [] + + # Load metadata for: + # * All variables in dataset_ids (if any dataset_id is given). + # * All variables in variable_ids. + # * All variables in variable_mapping (both old and new). + variable_ids_mapping = ( + (set(variable_mapping.keys()) | set(variable_mapping.values())) if variable_mapping else set() + ) + variable_ids_all = list(variable_ids_mapping | set(variable_ids or [])) + # Dictionary variable_id: Variable object, for all variables (old and new). + variables = { + variable.id: variable for variable in _load_variables_meta(engine=engine, variable_ids=variable_ids_all) + } + + # Create a dictionary of all variable_ids for each dataset_id (only for new variables). + dataset_variable_ids = {} + # TODO: Ensure variable_ids always corresponds to new variables. + # Note that currently, if dataset_id is passed and variable_ids is not, this will not load anything. + for variable_id in variable_ids: + variable = variables[variable_id] + if variable.datasetId not in dataset_variable_ids: + dataset_variable_ids[variable.datasetId] = [] + dataset_variable_ids[variable.datasetId].append(variable) + + for dataset_id, variables_in_dataset in dataset_variable_ids.items(): + # Get dataset's checksum + with Session(engine) as session: + dataset = gm.Dataset.load_dataset(session, dataset_id) + + log.info("loading_data_from_s3.start") + variables_old = [ + variables[variable_id_old] + for variable_id_old in variable_mapping.keys() + if variable_mapping[variable_id_old] in [variable.id for variable in variables_in_dataset] + ] + variables_old_and_new = variables_in_dataset + variables_old + t = time.time() + df = load_data_for_variables(engine=engine, variables=variables_old_and_new) + log.info("loading_data_from_s3.end", t=time.time() - t) + + for anomaly_type in anomaly_types: + # Instantiate the anomaly detector. 
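+            # Validate the requested anomaly type before doing any work; the detector itself is instantiated further below.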
+ if anomaly_type not in ANOMALY_DETECTORS: + raise ValueError(f"Unsupported anomaly type: {anomaly_type}") + + if not force: + if not needs_update(engine, dataset, anomaly_type): + log.info(f"Anomaly type {anomaly_type} for dataset {dataset_id} already exists in the database.") + continue + + log.info(f"Detecting anomaly type {anomaly_type} for dataset {dataset_id}") + + # Instantiate the anomaly detector. + detector = ANOMALY_DETECTORS[anomaly_type]() + + # Select the variable ids that are included in the current dataset. + variable_ids_for_current_dataset = [variable.id for variable in variables_in_dataset] + # Select the subset of the mapping that is relevant for the current dataset. + variable_mapping_for_current_dataset = { + variable_old: variable_new + for variable_old, variable_new in variable_mapping.items() + if variable_new in variable_ids_for_current_dataset + } + + # Get the anomaly score dataframe for the current dataset and anomaly type. + df_score = detector.get_score_df( + df=df, + variable_ids=variable_ids_for_current_dataset, + variable_mapping=variable_mapping_for_current_dataset, + ) + + if df_score.empty: + log.info("No anomalies detected.`") + continue + + # Create a long format score dataframe. + df_score_long = get_long_format_score_df(df_score) + + # TODO: validate format of the output dataframe + anomaly = gm.Anomaly( + datasetId=dataset_id, + datasetSourceChecksum=dataset.sourceChecksum, + anomalyType=anomaly_type, + ) + # We could store the full dataframe in the database, but it ends up making the load quite slow. + # Since we are not using it for now, we will store only the reduced dataframe. + # anomaly.dfScore = df_score_long + + # Reduce dataframe + df_score_long_reduced = ( + df_score_long.sort_values("anomaly_score", ascending=False) + .drop_duplicates(subset=["entity_name", "variable_id"], keep="first") + .reset_index(drop=True) + ) + anomaly.dfReduced = df_score_long_reduced + + ################################################################## + # TODO: Use this as an alternative to storing binary files in the DB + # anomaly = gm.Anomaly( + # datasetId=dataset_id, + # anomalyType=detector.anomaly_type, + # ) + # anomaly.dfScore = None + + # # Export anomaly file + # anomaly.path_file = export_anomalies_file(df_score, dataset_id, detector.anomaly_type) + ################################################################## + + if not dry_run: + with Session(engine) as session: + # log.info("Deleting existing anomalies") + session.query(gm.Anomaly).filter( + gm.Anomaly.datasetId == dataset_id, + gm.Anomaly.anomalyType == anomaly_type, + ).delete(synchronize_session=False) + session.commit() + + # Don't save anomalies if there are none + if df_score_long.empty: + log.info(f"No anomalies found for anomaly type {anomaly_type} in dataset {dataset_id}") + else: + # Insert new anomalies + log.info("Writing anomaly to database") + session.add(anomaly) + session.commit() + + +def needs_update(engine: Engine, dataset: gm.Dataset, anomaly_type: str) -> bool: + """If there's an anomaly with the dataset checksum in DB, it doesn't need + to be updated.""" + with Session(engine) as session: + try: + anomaly = gm.Anomaly.load( + session, + dataset_id=dataset.id, + anomaly_type=anomaly_type, + ) + except NoResultFound: + return True + + return anomaly.datasetSourceChecksum != dataset.sourceChecksum + + +def export_anomalies_file(df: pd.DataFrame, dataset_id: int, anomaly_type: str) -> str: + """Export anomaly df to local file (and upload to staging server if 
applicable).""" + filename = f"{dataset_id}_{anomaly_type}.feather" + path = Path(f".anomalies/{filename}") + path_str = str(path) + if OWID_ENV.env_local == "staging": + create_folder(path.parent) + df.to_feather(path_str) + elif OWID_ENV.env_local == "dev": + # tmp_filename = Path("tmp.feather") + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_file_path = Path(tmp_dir) / filename + df.to_feather(tmp_file_path) + upload_file_to_server(tmp_file_path, f"owid@{OWID_ENV.name}:/home/owid/etl/{WIZARD_ANOMALIES_RELATIVE}") + else: + raise ValueError( + f"Unsupported environment: {OWID_ENV.env_local}. Did you try production? That's not supported!" + ) + return path_str + + +# @memory.cache +def load_data_for_variables(engine: Engine, variables: list[gm.Variable]) -> pd.DataFrame: + # TODO: cache this on disk & re-validate with etags + df_long = variable_data_df_from_s3(engine, [v.id for v in variables], workers=None) + + df_long = df_long.rename(columns={"variableId": "variable_id", "entityName": "entity_name"}) + + # pivot dataframe + df = df_long.pivot(index=["entity_name", "year"], columns="variable_id", values="value") + + # reorder in the same order as variables + df = df[[v.id for v in variables]] + + # set non-numeric values to NaN + df = df.apply(pd.to_numeric, errors="coerce") + + # remove variables with all nulls or all zeros or constant values + df = df.loc[:, df.fillna(0).std(axis=0) != 0] + + df = df.reset_index().astype({"entity_name": str}) + + return df # type: ignore + + +# @memory.cache +def _load_variables_meta(engine: Engine, variable_ids: list[int]) -> list[gm.Variable]: + q = """ + select id from variables + where id in %(variable_ids)s + """ + df = read_sql(q, engine, params={"variable_ids": variable_ids}) + + # select all variables using SQLAlchemy + with Session(engine) as session: + return gm.Variable.load_variables(session, list(df["id"])) + + +def combine_and_reduce_scores_df(anomalies: List[gm.Anomaly]) -> pd.DataFrame: + """Get the combined dataframe with scores for all anomalies, and reduce it to include only the largest anomaly for each contry-indicator.""" + # Combine the reduced dataframes for all anomalies into a single dataframe. 
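+    # Each anomaly already stores, in dfReduced, the maximum score per entity-indicator pair; here those frames are concatenated and tagged with their anomaly type.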
+ dfs = [] + for anomaly in anomalies: + df = anomaly.dfReduced + if df is None: + log.warning(f"Anomaly {anomaly} has no reduced dataframe.") + continue + df["type"] = anomaly.anomalyType + dfs.append(df) + + df_reduced = cast(pd.DataFrame, pd.concat(dfs, ignore_index=True)) + # Dtypes + # df = df.astype({"year": int}) + + return df_reduced diff --git a/apps/anomalist/cli.py b/apps/anomalist/cli.py new file mode 100644 index 00000000000..8c3610b8497 --- /dev/null +++ b/apps/anomalist/cli.py @@ -0,0 +1,127 @@ +import json +from typing import Optional, Tuple, get_args + +import click +import structlog +from joblib import Memory +from rich_click.rich_command import RichCommand + +from apps.anomalist.anomalist_api import ANOMALY_TYPE, anomaly_detection +from etl.db import get_engine, read_sql +from etl.paths import CACHE_DIR + +log = structlog.get_logger() + +memory = Memory(CACHE_DIR, verbose=0) + + +@click.command(name="anomalist", cls=RichCommand, help=anomaly_detection.__doc__) +@click.option( + "--anomaly-types", + type=click.Choice(list(get_args(ANOMALY_TYPE))), + multiple=True, + default=None, + help="Type (or types) of anomaly detection algorithm to use.", +) +@click.option( + "--dataset-ids", + type=int, + multiple=True, + default=None, + help="Generate anomalies for the variables of a specific dataset ID (or multiple dataset IDs).", +) +@click.option( + "--variable-mapping", + type=str, + default="", + help="Optional JSON dictionary mapping variable IDs from a previous to a new version (where at least some of the new variable IDs must belong to the datasets whose IDs were given).", +) +@click.option( + "--variable-ids", + type=int, + multiple=True, + default=None, + help="Generate anomalies for a list of variable IDs (in addition to the ones from dataset ID, if any dataset was given).", +) +@click.option( + "--dry-run/--no-dry-run", + default=False, + type=bool, + help="Do not write to target database.", +) +@click.option( + "--force", + "-f", + is_flag=True, + help="TBD", +) +@click.option( + "--reset-db/--no-reset-db", + default=False, + type=bool, + help="Drop anomalies table and recreate it. This is useful for development when the schema changes.", +) +def cli( + anomaly_types: Optional[Tuple[str, ...]], + dataset_ids: Optional[list[int]], + variable_mapping: str, + variable_ids: Optional[list[int]], + dry_run: bool, + force: bool, + reset_db: bool, +) -> None: + """TBD + + TBD + + **Example 1:** Create random anomaly for a dataset + + ``` + $ etl anomalist --anomaly-type sample --dataset-ids 6369 + ``` + + **Example 2:** Create GP anomalies + + ``` + $ etl anomalist --anomaly-type gp --dataset-ids 6369 + ``` + + **Example 3:** Create anomalies by comparing dataset to its previous version + + ``` + $ etl anomalist --anomaly-type gp --dataset-ids 6589 + ``` + """ + # Convert variable mapping from JSON to dictionary. + if variable_mapping: + try: + variable_mapping_dict = { + int(variable_old): int(variable_new) + for variable_old, variable_new in json.loads(variable_mapping).items() + } + except json.JSONDecodeError: + raise ValueError("Invalid JSON format for variable_mapping.") + else: + variable_mapping_dict = {} + + # Load all variables from given datasets + if dataset_ids: + assert not variable_ids, "Cannot specify both dataset IDs and variable IDs." 
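+        # Resolve the given dataset IDs to the IDs of all their variables, since anomaly detection works at the variable level.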
+ q = """ + select id from variables + where datasetId in %(dataset_ids)s + """ + variable_ids = list(read_sql(q, get_engine(), params={"dataset_ids": dataset_ids})["id"]) + + anomaly_detection( + anomaly_types=anomaly_types, + variable_mapping=variable_mapping_dict, + variable_ids=list(variable_ids) if variable_ids else None, + dry_run=dry_run, + force=force, + reset_db=reset_db, + ) + + +if __name__ == "__main__": + cli() diff --git a/apps/anomalist/detectors.py b/apps/anomalist/detectors.py new file mode 100644 index 00000000000..c5d2846cfa7 --- /dev/null +++ b/apps/anomalist/detectors.py @@ -0,0 +1,232 @@ +from typing import Dict, List + +import numpy as np +import pandas as pd +import structlog +from sklearn.ensemble import IsolationForest +from sklearn.impute import SimpleImputer +from sklearn.preprocessing import StandardScaler +from sklearn.svm import OneClassSVM +from tqdm.auto import tqdm + +from etl.data_helpers.misc import bard + +log = structlog.get_logger() + +# Name of index columns for dataframe. +INDEX_COLUMNS = ["entity_name", "year"] + + +def estimate_bard_epsilon(series: pd.Series) -> float: + # Make all values positive, and ignore zeros. + positive_values = abs(series.dropna()) + # Ignore zeros, since they can lead to epsilon being zero, hence allowing division by zero in BARD. + positive_values = positive_values.loc[positive_values > 0] + # Estimate epsilon as the absolute range of values divided by 10. + # eps = (positive_values.max() - positive_values.min()) / 10 + # Instead of just taking maximum and minimum, take 95th percentile and 5th percentile. + eps = (positive_values.quantile(0.95) - positive_values.quantile(0.05)) / 10 + + return eps + + +def get_long_format_score_df(df_score: pd.DataFrame) -> pd.DataFrame: + # Output is already in long format + if set(df_score.columns) == {"entity_name", "year", "variable_id", "anomaly_score"}: + df_score_long = df_score + else: + # Create a reduced score dataframe. + df_score_long = df_score.melt( + id_vars=["entity_name", "year"], var_name="variable_id", value_name="anomaly_score" + ) + + # Drop NaN anomalies. + df_score_long = df_score_long.dropna(subset=["anomaly_score"]) + + # Drop zero anomalies. + df_score_long = df_score_long[df_score_long["anomaly_score"] != 0] + + # Save memory by converting to categoricals. + df_score_long = df_score_long.astype({"entity_name": "category", "year": "category", "variable_id": "category"}) + + return df_score_long + + +class AnomalyDetector: + anomaly_type: str + + @staticmethod + def get_text(entity: str, year: int) -> str: + return f"Anomaly happened in {entity} in {year}!" + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + raise NotImplementedError + + def get_zeros_df(self, df: pd.DataFrame, variable_ids: List[int]) -> pd.DataFrame: + # Create a dataframe of zeros. + df_zeros = pd.DataFrame(np.zeros_like(df), columns=df.columns)[INDEX_COLUMNS + variable_ids] + df_zeros[INDEX_COLUMNS] = df[INDEX_COLUMNS].copy() + return df_zeros + + def get_nans_df(self, df: pd.DataFrame, variable_ids: List[int]) -> pd.DataFrame: + # Create a dataframe of nans. 
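+        # NOTE: np.empty_like returns uninitialized values, so the variable columns are explicitly overwritten with NaN below.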
+ df_nans = pd.DataFrame(np.empty_like(df), columns=df.columns)[INDEX_COLUMNS + variable_ids] + df_nans[variable_ids] = np.nan + df_nans[INDEX_COLUMNS] = df[INDEX_COLUMNS].copy() + return df_nans + + +class AnomalyUpgradeMissing(AnomalyDetector): + """New data misses entity-years that used to exist in old version.""" + + anomaly_type = "upgrade_missing" + + @staticmethod + def get_text(entity: str, year: int) -> str: + return f"There are missing values for {entity}! There might be other data points affected." + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Create a dataframe of zeros. + df_lost = self.get_zeros_df(df, variable_ids) + + # Make 1 all cells that used to have data in the old version, but are missing in the new version. + for variable_id_old, variable_id_new in variable_mapping.items(): + affected_rows = df[(df[variable_id_old].notnull()) & (df[variable_id_new].isnull())].index + df_lost.loc[affected_rows, variable_id_new] = 1 + + return df_lost + + +class AnomalyUpgradeChange(AnomalyDetector): + """New dataframe has changed abruptly with respect to the old version.""" + + anomaly_type = "upgrade_change" + + @staticmethod + def get_text(entity: str, year: int) -> str: + return f"There are abrupt changes for {entity} in {year}! There might be other data points affected." + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Create a dataframe of zeros. + df_version_change = self.get_zeros_df(df, variable_ids) + + for variable_id_old, variable_id_new in variable_mapping.items(): + # Calculate the BARD epsilon for each variable. + eps = estimate_bard_epsilon(series=df[variable_id_new]) + # Calculate the BARD for each variable. + variable_bard = bard(a=df[variable_id_old], b=df[variable_id_new], eps=eps) + # Add bard to the dataframe. + df_version_change[variable_id_new] = variable_bard + + return df_version_change + + +class AnomalyTimeChange(AnomalyDetector): + """New dataframe has abrupt changes in time series.""" + + anomaly_type = "time_change" + + @staticmethod + def get_text(entity: str, year: int) -> str: + return f"There are significant changes for {entity} in {year} compared to the old version of the indicator. There might be other data points affected." + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Create a dataframe of zeros. + df_time_change = self.get_zeros_df(df, variable_ids) + + # Sanity check. + error = "The function that detects abrupt time changes assumes the data is sorted by entity_name and year. But this is not the case. Either ensure the data is sorted, or fix the function." + assert (df.sort_values(by=INDEX_COLUMNS).index == df.index).all(), error + for variable_id in variable_ids: + series = df[variable_id].copy() + # Calculate the BARD epsilon for this variable. + eps = estimate_bard_epsilon(series=series) + # Calculate the BARD for this variable. + _bard = bard(series, series.shift(), eps).fillna(0) + + # Add bard to the dataframe. + df_time_change[variable_id] = _bard + # The previous procedure includes the calculation of the deviation between the last point of an entity and the first point of the next, which is meaningless, and can lead to a high BARD. + # Therefore, make zero the first point of each entity_name for all columns. 
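+        # (Comparing entity_name with its shifted value flags the first row of every entity, including the very first row of the dataframe.)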
+ # df_time_change.loc[df_time_change["entity_name"].diff().fillna(1) > 0, self.variable_ids] = 0 + df_time_change.loc[df_time_change["entity_name"] != df_time_change["entity_name"].shift(), variable_ids] = 0 + + return df_time_change + + +class AnomalyIsolationForest(AnomalyDetector): + """Anomaly detection using Isolation Forest, applied separately to each country-variable time series.""" + + anomaly_type = "isolation_forest" + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Initialize a dataframe of zeros. + df_anomalies = self.get_zeros_df(df, variable_ids) + + # Initialize an imputer to handle missing values. + imputer = SimpleImputer(strategy="mean") + + for variable_id in tqdm(variable_ids): + for country, group in df.groupby("entity_name", observed=True): + # Get the time series for the current country and variable. + series = group[[variable_id]].copy() + + # Skip if the series is all zeros or nans. + if ((series == 0).all().values) or (series.dropna().shape[0] == 0): + continue + + # Impute missing values for this country's time series. + series_imputed = imputer.fit_transform(series) + + # Scale the data. + scaler = StandardScaler() + series_scaled = scaler.fit_transform(series_imputed) + + # Initialize the Isolation Forest model. + isolation_forest = IsolationForest(contamination=0.05, random_state=1) # type: ignore + + # Fit the model and calculate anomaly scores. + isolation_forest.fit(series_scaled) + scores = isolation_forest.decision_function(series_scaled) + df_anomalies.loc[df["entity_name"] == country, variable_id] = scores + + return df_anomalies + + +class AnomalyOneClassSVM(AnomalyDetector): + """Anomaly detection using One-Class SVM, applied separately to each country-variable time series.""" + + anomaly_type = "one_class_svm" + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Initialize a dataframe of zeros. + df_anomalies = self.get_zeros_df(df, variable_ids) + + # Initialize an imputer to handle missing values. + imputer = SimpleImputer(strategy="mean") + + for variable_id in tqdm(variable_ids): + for country, group in df.groupby("entity_name", observed=True): + # Get the time series for the current country and variable. + series = group[[variable_id]].copy() + + # Skip if the series is all zeros or nans. + if ((series == 0).all().values) or (series.dropna().shape[0] == 0): + continue + + # Impute missing values for this country's time series. + series_imputed = imputer.fit_transform(series) + + # Scale the data for better performance. + scaler = StandardScaler() + series_scaled = scaler.fit_transform(series_imputed) + + # Initialize the One-Class SVM model for this country's time series. + svm = OneClassSVM(kernel="rbf", gamma="scale", nu=0.05) + + # Fit the model and calculate anomaly scores. 
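+                # NOTE: decision_function returns signed distances to the separating boundary; negative values correspond to the most anomalous points.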
+ svm.fit(series_scaled) + scores = svm.decision_function(series_scaled) + df_anomalies.loc[df["entity_name"] == country, variable_id] = scores + + return df_anomalies diff --git a/apps/anomalist/explore.ipynb b/apps/anomalist/explore.ipynb new file mode 100644 index 00000000000..18d379c0ac1 --- /dev/null +++ b/apps/anomalist/explore.ipynb @@ -0,0 +1,49 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# from myml.nbinit import *\n", + "import pandas as pd\n", + "from apps.anomalist.gp_anomaly import GPAnomalyDetector\n", + "from apps.anomalist.cli import load_data_for_variable\n", + "from etl.db import get_engine\n", + "from etl import grapher_model as gm\n", + "from sqlalchemy.orm import Session\n", + "\n", + "engine = get_engine()\n", + "\n", + "# get random dataset and random variable\n", + "q = \"\"\"\n", + "with t as (\n", + " select id from datasets order by rand() limit 1\n", + ")\n", + "select id from variables\n", + "where datasetId in (select id from t)\n", + "order by rand()\n", + "limit 1\n", + "\"\"\"\n", + "\n", + "mf = pd.read_sql(q, engine)\n", + "variable_id = mf.id[0]\n", + "\n", + "with Session(engine) as session:\n", + " variable = gm.Variable.load_variable(session, variable_id)\n", + "df = load_data_for_variable(engine, variable)\n", + "\n", + "gp = GPAnomalyDetector()\n", + "gp.viz(df, variable)" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/apps/anomalist/gp_detector.py b/apps/anomalist/gp_detector.py new file mode 100644 index 00000000000..bec71936713 --- /dev/null +++ b/apps/anomalist/gp_detector.py @@ -0,0 +1,253 @@ +import random +import time +import warnings +from multiprocessing import Pool +from typing import Dict, List, Optional + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import structlog +from joblib import Memory +from sklearn.exceptions import ConvergenceWarning +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import RBF, WhiteKernel +from tqdm.auto import tqdm + +from etl import grapher_model as gm +from etl.paths import CACHE_DIR + +from .detectors import AnomalyDetector + +log = structlog.get_logger() + + +memory = Memory(CACHE_DIR, verbose=0) + + +@memory.cache +def _load_population(): + from .anomalist_api import load_latest_population + + # Load the latest population data from the API + pop = load_latest_population() + # Filter the population data to only include the year 2024 + pop = pop[pop.year == 2024] + # Set 'entity_name' as the index and return the population series + return pop.set_index("entity_name")["population"] + + +def _processing_queue(items: list[tuple[str, int]]) -> List[tuple]: + """ + Create a processing queue of (entity, variable_id) pairs, weighted by population probability. 
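+    Entities with larger populations get a proportionally higher probability, so they tend to appear earlier in the queue
+    and are more likely to be processed before any time budget runs out.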
+ """ + # Load the population data (cached for efficiency) + population = _load_population().to_dict() + + # Create a probability array for each (entity, variable_id) pair based on the entity probability + probs = np.array([population.get(entity, np.nan) for entity, variable_id in items]) + + # Fill any missing values with the mean probability + probs = np.nan_to_num(probs, nan=np.nanmean(probs)) # type: ignore + + # Randomly shuffle the items based on their probabilities + items_index = np.random.choice( + len(items), + size=len(items), + replace=False, + p=probs / probs.sum(), + ) + + # Return the shuffled list of items + return np.array(items, dtype=object)[items_index] # type: ignore + + +class AnomalyGaussianProcessOutlier(AnomalyDetector): + anomaly_type = "gp_outlier" + + # TODO: max_time is hard-coded to 10, but it should be configurable in production + def __init__(self, max_time: Optional[float] = 10, n_jobs: int = 1): + self.max_time = max_time + self.n_jobs = n_jobs + + @staticmethod + def get_text(entity: str, year: int) -> str: + return f"There are some outliers for {entity}! These were detected using Gaussian processes. There might be other data points affected." + + def get_score_df(self, df: pd.DataFrame, variable_ids: List[int], variable_mapping: Dict[int, int]) -> pd.DataFrame: + # Convert to long format + df_wide = df.melt(id_vars=["entity_name", "year"]) + # Filter to only include the specified variable IDs. + df_wide = ( + df_wide[df_wide["variable_id"].isin(variable_ids)] + .set_index(["entity_name", "variable_id"]) + .dropna() + .sort_index() + ) + + # Create a processing queue with (entity_name, variable_id) pairs + items = _processing_queue( + list(df_wide.index.unique()), + ) + + start_time = time.time() + + results = [] + + # Iterate through each (entity_name, variable_id) pair in the processing queue + for entity_name, variable_id in tqdm(items): + # Stop processing if the maximum time is reached + if self.max_time is not None and (time.time() - start_time) > self.max_time: + log.info("Max processing time reached, stopping further processing.") + break + + # Get the data for the current entity and variable + group = df_wide.loc[(entity_name, variable_id)] + + # Skip if the series has only one or fewer data points + if isinstance(group, pd.Series) or len(group) <= 1: + continue + + # Prepare the input features (X) and target values (y) for Gaussian Process + X, y = self.get_Xy(pd.Series(group["value"].values, index=group["year"])) + + # Skip if the target values have zero standard deviation (i.e., all values are identical) + if y.std() == 0: + continue + + if self.n_jobs == 1: + # Fit the Gaussian Process model and make predictions + z = self.fit_predict_z(X, y) + z = pd.DataFrame({"anomaly_score": np.abs(z), "year": group["year"].values}, index=group.index) + results.append(z) + else: + # Add it to a list for parallel processing + results.append((self, X, y, group, start_time)) + + # Process results in parallel + # NOTE: There's a lot of overhead in parallelizing this, so the gains are minimal on my laptop. It could be + # better on a staging server. 
+ if self.n_jobs != 1: + with Pool(self.n_jobs) as pool: + # Split the workload evenly across the number of jobs + chunksize = len(items) // self.n_jobs + 1 + results = pool.starmap(self._fit_parallel, results, chunksize=chunksize) + + log.info("Finished processing", elapsed=round(time.time() - start_time, 2)) + + if not results: + return pd.DataFrame() + + df_score_long = pd.concat(results).reset_index() + + # Normalize the anomaly scores by mapping interval (0, 3+) to (0, 1) + df_score_long["anomaly_score"] = np.minimum(df_score_long["anomaly_score"] / 3, 1) + + return df_score_long + + @staticmethod + def _fit_parallel(obj: "AnomalyGaussianProcessOutlier", X, y, group, start_time): + # Stop early + if obj.max_time and (time.time() - start_time) > obj.max_time: + return pd.DataFrame() + z = obj.fit_predict_z(X, y) + z = pd.DataFrame({"anomaly_score": np.abs(z), "year": group["year"].values}, index=group.index) + return z + + def get_Xy(self, series: pd.Series) -> tuple[np.ndarray, np.ndarray]: + X = series.index.values.reshape(-1, 1) + y = series.values + return X, y # type: ignore + + def fit_predict_z(self, X, y) -> np.ndarray: + # t = time.time() + mean_pred, std_pred = self.fit_predict(X, y) + # Calculate the Z-score for each point (standard score) + z = (y - mean_pred) / std_pred + # log.info( + # "Fitted GP", + # variable_id=variable_id, + # entity_name=entity_name, + # n_samples=len(X), + # elapsed=round(time.time() - t, 2), + # ) + return z + + def fit_predict(self, X, y): + # normalize data... but is it necessary? + X_mean = np.mean(X) + y_mean = np.mean(y) + y_std = np.std(y) + assert y_std > 0, "Standard deviation of y is zero" + + X_normalized = X - X_mean + y_normalized = (y - y_mean) / y_std + + x_range = X_normalized.max() - X_normalized.min() + + # TODO: we could also preprocess data by: + # - applying power transform to un-log data + # - removing linear trend + # - use Nystroem kernel approximation + + # Bounds are set to prevent overfitting to the data and missing outliers, especially + # the lower bounds for the length scale and noise level + length_scale_bounds = (min(1e1, x_range), max(1e3, x_range)) + noise_level_bounds = (1e-1, 1e1) + + kernel = 1.0 * RBF(length_scale_bounds=length_scale_bounds) + WhiteKernel(noise_level_bounds=noise_level_bounds) + + self.gp = gp = GaussianProcessRegressor(kernel=kernel, n_restarts_optimizer=0) + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ConvergenceWarning) + gp.fit(X_normalized, y_normalized) + + # print kernel + # log.info(f"Optimized Kernel: {gp.kernel_}") + + # Make predictions with confidence intervals + mean_pred, std_pred = gp.predict(X_normalized, return_std=True) # type: ignore + + # Denormalize + mean_pred = mean_pred * y_std + y_mean + std_pred = std_pred * y_std # type: ignore + + return mean_pred, std_pred + + def viz(self, df: pd.DataFrame, meta: gm.Variable, country: Optional[str] = None): + if df.empty: + log.warning("No data to visualize") + return + + country = country or random.choice(df.columns) + series = df[country].dropna() + + if len(series) <= 1: + log.warning(f"Insufficient data for {country}") + return + + X, y = self.get_Xy(series) + log.info("Fitting Gaussian Process", country=country, n_samples=len(X)) + mean_prediction, std_prediction = self.fit_predict(X, y) + + log.info(f"Optimized Kernel: {self.gp.kernel_}") + + plt.figure(figsize=(10, 6)) + plt.scatter(X, y, label="Observations") + plt.plot(X, y, linestyle="dotted", label="Observed Data") + plt.plot(X, mean_prediction, 
label="Mean Prediction", color="orange") + plt.fill_between( + X.ravel(), + mean_prediction - 1.96 * std_prediction, + mean_prediction + 1.96 * std_prediction, + alpha=0.3, + color="lightblue", + label=r"95% Confidence Interval", + ) + plt.legend() + plt.title(f"{meta.name}: {country}") + + z = (y - mean_prediction) / std_prediction + print("Max Z-score: ", np.abs(z).max()) + + plt.show() diff --git a/apps/cli/__init__.py b/apps/cli/__init__.py index 8d056da3db8..eb9a8e6d80c 100644 --- a/apps/cli/__init__.py +++ b/apps/cli/__init__.py @@ -191,6 +191,12 @@ def cli_back() -> None: "owidbot": "apps.owidbot.cli.cli", }, }, + { + "name": "Anomalist", + "commands": { + "anomalist": "apps.anomalist.cli.cli", + }, + }, ] # Add subgroups (don't modify) + subgroups diff --git a/apps/owidbot/anomalist.py b/apps/owidbot/anomalist.py new file mode 100644 index 00000000000..dab84205269 --- /dev/null +++ b/apps/owidbot/anomalist.py @@ -0,0 +1,61 @@ +from structlog import get_logger + +from apps.anomalist.anomalist_api import anomaly_detection +from apps.wizard.app_pages.anomalist.utils import load_variable_mapping +from etl import grapher_model as gm +from etl.config import OWIDEnv +from etl.db import Engine, read_sql + +from .chart_diff import production_or_master_engine + +log = get_logger() + + +def run(branch: str) -> None: + """Compute all anomalist for new and updated datasets.""" + # Get engines for branch and production + source_engine = OWIDEnv.from_staging(branch).get_engine() + target_engine = production_or_master_engine() + + # Create table with anomalist if it doesn't exist + gm.Anomaly.create_table(source_engine, if_exists="skip") + + # Load new dataset ids + datasets_new_ids = _load_datasets_new_ids(source_engine, target_engine) + + if not datasets_new_ids: + log.info("No new datasets found.") + return + + log.info(f"New datasets: {datasets_new_ids}") + + # Load all their variables + q = """SELECT id FROM variables WHERE datasetId IN %(dataset_ids)s""" + variable_ids = list(read_sql(q, source_engine, params={"dataset_ids": datasets_new_ids})["id"]) + + # Load variable mapping + variable_mapping_dict = load_variable_mapping(datasets_new_ids) + + # Run anomalist + anomaly_detection( + variable_mapping=variable_mapping_dict, + variable_ids=variable_ids, + ) + + +def _load_datasets_new_ids(source_engine: Engine, target_engine: Engine) -> list[int]: + # Get new datasets + # TODO: replace by real catalogPath when we have it in MySQL + q = """SELECT + id, + CONCAT(namespace, "/", version, "/", shortName) as catalogPath + FROM datasets + """ + source_datasets = read_sql(q, source_engine) + target_datasets = read_sql(q, target_engine) + + return list( + source_datasets[ + source_datasets.catalogPath.isin(set(source_datasets["catalogPath"]) - set(target_datasets["catalogPath"])) + ]["id"] + ) diff --git a/apps/owidbot/chart_diff.py b/apps/owidbot/chart_diff.py index 660e972b9bd..944dffb5fad 100644 --- a/apps/owidbot/chart_diff.py +++ b/apps/owidbot/chart_diff.py @@ -3,6 +3,7 @@ from apps.wizard.app_pages.chart_diff.chart_diff import ChartDiffsLoader from etl.config import OWID_ENV, OWIDEnv, get_container_name +from etl.db import Engine from . 
import github_utils as gh_utils @@ -66,14 +67,18 @@ def run(branch: str, charts_df: pd.DataFrame) -> str: return body -def call_chart_diff(branch: str) -> pd.DataFrame: - source_engine = OWIDEnv.from_staging(branch).get_engine() - +def production_or_master_engine() -> Engine: + """Return the production engine if available, otherwise connect to staging-site-master.""" if OWID_ENV.env_remote == "production": - target_engine = OWID_ENV.get_engine() + return OWID_ENV.get_engine() else: log.warning("ENV file doesn't connect to production DB, comparing against staging-site-master") - target_engine = OWIDEnv.from_staging("master").get_engine() + return OWIDEnv.from_staging("master").get_engine() + + +def call_chart_diff(branch: str) -> pd.DataFrame: + source_engine = OWIDEnv.from_staging(branch).get_engine() + target_engine = production_or_master_engine() df = ChartDiffsLoader(source_engine, target_engine).get_diffs_summary_df( config=True, diff --git a/apps/owidbot/cli.py b/apps/owidbot/cli.py index 498d5533e2a..c21a950c239 100644 --- a/apps/owidbot/cli.py +++ b/apps/owidbot/cli.py @@ -8,7 +8,7 @@ from rich import print from rich_click.rich_command import RichCommand -from apps.owidbot import chart_diff, data_diff, grapher +from apps.owidbot import anomalist, chart_diff, data_diff, grapher from etl.config import get_container_name from . import github_utils as gh_utils @@ -16,12 +16,13 @@ log = structlog.get_logger() REPOS = Literal["etl", "owid-grapher"] -SERVICES = Literal["data-diff", "chart-diff", "grapher"] +SERVICES = Literal["data-diff", "chart-diff", "grapher", "anomalist"] @click.command("owidbot", cls=RichCommand, help=__doc__) @click.argument("repo_branch", type=str) -@click.option("--services", type=click.Choice(get_args(SERVICES)), multiple=True) +# @click.option("--services", type=click.Choice(get_args(SERVICES)), multiple=True) +@click.option("--services", type=str, multiple=True) @click.option( "--include", type=str, @@ -36,7 +37,7 @@ ) def cli( repo_branch: str, - services: List[Literal[SERVICES]], + services: List[str], include: str, dry_run: bool, ) -> None: @@ -76,8 +77,16 @@ def cli( elif service == "grapher": services_body["grapher"] = grapher.run(branch) + + elif service == "anomalist": + # TODO: anomalist could post a summary of anomalies to the PR + anomalist.run(branch) + else: - raise AssertionError("Invalid service") + # We raise a warning instead of an error to make it backward compatible on old + # staging servers when adding a new service. + log.warning("Invalid service", service=service) + continue # get existing comment (do this as late as possible to avoid race conditions) comment = gh_utils.get_comment_from_pr(pr) diff --git a/apps/utils/map_datasets.py b/apps/utils/map_datasets.py index bd22bb21848..9c1976693f6 100644 --- a/apps/utils/map_datasets.py +++ b/apps/utils/map_datasets.py @@ -67,7 +67,11 @@ def get_grapher_changes(files_changed: Dict[str, Dict[str, str]], steps_df: pd.D steps_affected.append(candidate["step"].item()) steps_affected.extend(candidate["all_usages"].item()) - if (candidate["channel"].item() == "grapher") & (file_status == "A"): + # NOTE: I'm not sure why I originally imposed that file_status needed to be "A". + # I think it's possible that one commits changes to a new grapher dataset, and therefore they would appear as "M". + # So, I'll remove this condition for now. But if we detect issues, we may need to add it back. 
+ # if (candidate["channel"].item() == "grapher") & (file_status == "A"): + if candidate["channel"].item() == "grapher": current_grapher_step = candidate["step"].item() ## Get grapher dataset id and name of the new dataset. diff --git a/apps/wizard/app_pages/anomalist/app.py b/apps/wizard/app_pages/anomalist/app.py new file mode 100644 index 00000000000..edf0c90d057 --- /dev/null +++ b/apps/wizard/app_pages/anomalist/app.py @@ -0,0 +1,685 @@ +"""Anomalist app page. + +The main structure of the app is implemented. Its main logic is: + +1. User fills the form with datasets. +2. User submits the form. +3. The app loads the datasets and checks the database if there are already anomalies for this dataset. + 3.1 If yes: the app loads already existing anomalies. (will show a warning with option to refresh) + 3.2 If no: the app estimates the anomalies. +4. The app shows the anomalies, along with filters to sort / re-order them. + +TODO: +- Test with upgrade flow more extensively. +- What happens with data that do not have years (e.g. dates)? +- We can infer if the anomalies are out of sync (because user has updated the data) by checking the dataset checksum. NOTE: checksum might change bc of metadata changes, so might show several false positives. +- Further explore LLM summary: + - We should store the LLM summary in the DB. We need a new table for this. Each summary is associated with a set of anomalies (Anomaly table), at a precise moment. We should detect out-of-sync here too. +- Hiding capabilities. Option to hide anomalies would be great. Idea: have a button in each anomaly box to hide it. We need a register of the hidden anomalies. We then could have a st.popover element in the filter section which only appears if there are anomalies hidden. Then, we can list them there, in case the user wants to unhide some. +- New dataset detection. We should explore if this can be done quicker. 
+ +""" + +from typing import List, Tuple, cast + +import pandas as pd +import streamlit as st + +from apps.anomalist.anomalist_api import anomaly_detection, load_detector +from apps.utils.gpt import OpenAIWrapper, get_cost_and_tokens, get_number_tokens +from apps.wizard.app_pages.anomalist.utils import ( + AnomalyTypeEnum, + create_tables, + get_datasets_and_mapping_inputs, + get_scores, +) +from apps.wizard.utils import cached, set_states, url_persist +from apps.wizard.utils.chart_config import bake_chart_config +from apps.wizard.utils.components import Pagination, grapher_chart, st_horizontal, tag_in_md +from apps.wizard.utils.db import WizardDB +from etl.config import OWID_ENV +from etl.grapher_io import load_variables + +# PAGE CONFIG +st.set_page_config( + page_title="Wizard: Anomalist", + page_icon="🪄", + layout="wide", +) + + +# OTHER CONFIG +ANOMALY_TYPES = { + AnomalyTypeEnum.TIME_CHANGE.value: { + "tag_name": "Time change", + "color": "gray", + "icon": ":material/timeline", + }, + AnomalyTypeEnum.UPGRADE_CHANGE.value: { + "tag_name": "Version change", + "color": "orange", + "icon": ":material/upgrade", + }, + AnomalyTypeEnum.UPGRADE_MISSING.value: { + "tag_name": "Missing point", + "color": "red", + "icon": ":material/hide_source", + }, + AnomalyTypeEnum.GP_OUTLIER.value: { + "tag_name": "Gaussian Process", + "color": "blue", + "icon": ":material/notifications", + }, +} +ANOMALY_TYPE_NAMES = {k: v["tag_name"] for k, v in ANOMALY_TYPES.items()} +ANOMALY_TYPES_TO_DETECT = tuple(ANOMALY_TYPES.keys()) + +# GPT +MODEL_NAME = "gpt-4o" + +SORTING_STRATEGIES = { + "relevance": "Relevance", + "score": "Anomaly score", + "population": "Population", + "views": "Chart views", + "population+views": "Population+views", +} + +# SESSION STATE +# Datasets selected by the user in first multiselect +st.session_state.anomalist_datasets_selected = st.session_state.get("anomalist_datasets_selected", []) + +# Indicators corresponding to datasets selected by the user (plus variable mapping) +st.session_state.anomalist_indicators = st.session_state.get("anomalist_indicators", {}) +st.session_state.anomalist_mapping = st.session_state.get("anomalist_mapping", {}) + +# FLAG: True when user clicks submits form with datasets. Set to false by the end of the execution. 
+st.session_state.anomalist_datasets_submitted = st.session_state.get("anomalist_datasets_submitted", False) + +# List with anomalies found in the selected datasets (dataset last submitted in the form by the user) +st.session_state.anomalist_anomalies = st.session_state.get("anomalist_anomalies", []) +st.session_state.anomalist_df = st.session_state.get("anomalist_df", None) +# FLAG: True if the anomalies were directly loaded from DB (not estimated) +st.session_state.anomalist_anomalies_out_of_date = st.session_state.get("anomalist_anomalies_out_of_date", False) + +# Filter: Entities and indicators +st.session_state.anomalist_filter_entities = st.session_state.get("anomalist_filter_entities", []) +st.session_state.anomalist_filter_indicators = st.session_state.get("anomalist_filter_indicators", []) + +# Sorting +st.session_state.anomalist_sorting_columns = st.session_state.get("anomalist_sorting_columns", []) + +# FLAG: True to trigger anomaly detection manually +st.session_state.anomalist_trigger_detection = st.session_state.get("anomalist_trigger_detection", False) + + +###################################################################### +# FUNCTIONS +###################################################################### + + +@st.cache_data(ttl=60) +def get_variable_mapping(variable_ids): + """Get variable mapping for specific variable IDs.""" + # Get variable mapping, if exists. Then keep only 'relevant' variables + mapping = WizardDB.get_variable_mapping() + mapping = {k: v for k, v in mapping.items() if v in variable_ids} + return mapping + + +@st.fragment() +def llm_ask(df: pd.DataFrame): + st.button( + "AI Summary", + on_click=lambda: llm_dialog(df), + icon=":material/robot:", + help=f"Ask GPT {MODEL_NAME} to summarize the anomalies. This is experimental.", + ) + + +@st.dialog("AI summary of anomalies", width="large") +def llm_dialog(df: pd.DataFrame): + """Ask LLM for summary of the anomalies.""" + ask_llm_for_summary(df) + + +@st.cache_data +def ask_llm_for_summary(df: pd.DataFrame): + NUM_ANOMALIES_PER_TYPE = 2_000 + + variable_ids = list(df["indicator_id"].unique()) + metadata = load_variables(variable_ids) + # Get metadata summary + metadata_summary = "" + for m in metadata: + _summary = f"- {m.name}\n" f"- {m.descriptionShort}\n" f"- {m.unit}" + metadata_summary += f"{_summary}\n-------------\n" + + # df = st.session_state.anomalist_df + # Get dataframe + df = df[["entity_name", "year", "type", "indicator_id", "score_weighted"]] + # Keep top anomalies based on weighed score + df = df.sort_values("score_weighted", ascending=False) + df = cast(pd.DataFrame, df.head(NUM_ANOMALIES_PER_TYPE)) + + # Round score (reduce token number) + df["score_weighted"] = df["score_weighted"].apply(lambda x: int(round(100 * x))) + # Reshape, pivot indicator_score to have one score column per id + df = df.pivot_table( + index=["entity_name", "year", "type"], columns="indicator_id", values="score_weighted" + ).reset_index() + # As string (one per anomaly type) + groups = df.groupby("type") + df_str = "" + for group in groups: + _text = group[0] + _df = group[1].set_index(["entity_name", "year"]).drop(columns="type") + # Dataframe as string + _df_str = cast(str, _df.to_csv()).replace(".0,", ",") + text = f"### Anomalies of type '{_text}'\n\n{_df_str}\n\n-------------------\n\n" + df_str += text + + # Ask LLM for summary + client = OpenAIWrapper() + + # Prepare messages for Insighter + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": ( + f""" + The user has 
obtained anomalies for a list of indicators. This list comes in the format of a dataframe with columns: + - 'entity_name': Typically a country name. + - 'year': The year in which the anomaly was detected. + - 'type': The type of anomaly detected. Allowed types are: + - 'time_change': A significant change in the indicator over time. + - 'upgrade_change': A significant change in the indicator after an upgrade. + - 'upgrade_missing': A missing value in the indicator after an upgrade. + - 'gp_outlier': An outlier detected using Gaussian processes. + + Additionally, there is a column per indicator (identified by the indicator ID), with the weighed score of the anomaly. The weighed score is an estimate on how relevant the anomaly is, based on the anomaly score, population in the country, and views of charts using this indicator. + + The user will provide this dataframe. + + You should try to summarise this list of anomalies, so that the information is more digestable. Some ideas: + + - Try to find if there are common patterns across entities or years. + - Try to remove redundant information as much as possible. For instance: if the same entity has multiple anomalies of the same type, you can group them together. Or if the same entity has multiple anomalies of different types, you can group them together. + - Try to find the most relevant anomalies. Either because these affect multiple entities or because they have a high weighed score. + + Indicators are identified by column 'indicator_id'. To do a better judgement, find below the name, description and units details for each indicator. Use this information to provide a more insightful summary. + + {metadata_summary} + """ + ), + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": df_str, + }, + ], + }, + ] + + text_in = "\n".join([m["content"][0]["text"] for m in messages]) + num_tokens = get_number_tokens(text_in, MODEL_NAME) + + # Check if the number of tokens is within limits + if num_tokens > 128_000: + st.warning( + f"There are too many tokens in the GPT query to model {MODEL_NAME}. The query has {num_tokens} tokens, while the maximum allowed is 128,000. We will support this in the future. Raise this issue to re-prioritize it." + ) + else: + # Ask GPT (stream) + stream = client.chat.completions.create( + model=MODEL_NAME, + messages=messages, # type: ignore + max_tokens=3000, + stream=True, + ) + response = cast(str, st.write_stream(stream)) + + cost, num_tokens = get_cost_and_tokens(text_in, response, cast(str, MODEL_NAME)) + cost_msg = f"**Cost**: ≥{cost} USD.\n\n **Tokens**: ≥{num_tokens}." + st.info(cost_msg) + + +# Functions to filter the results +def filter_df(df: pd.DataFrame): + """Apply filters from user to the dataframe. + + Filter parameters are stored in the session state: + + - `anomalist_filter_entities`: list of entities to filter. + - `anomalist_filter_indicators`: list of indicators to filter. + - `anomalist_filter_anomaly_types`: list of anomaly types to filter. + - `anomalist_min_year`: minimum year to filter. + - `anomalist_max_year`: maximum year to filter. + - `anomalist_sorting_strategy`: sorting strategy. 
+ """ + # Filter dataframe + df = _filter_df( + df=df, + year_min=st.session_state.anomalist_min_year, + year_max=st.session_state.anomalist_max_year, + anomaly_types=st.session_state.anomalist_filter_anomaly_types, + entities=st.session_state.anomalist_filter_entities, + indicators=st.session_state.anomalist_filter_indicators, + ) + ## Sort dataframe + df, st.session_state.anomalist_sorting_columns = _sort_df(df, st.session_state.anomalist_sorting_strategy) + return df + + +@st.cache_data +def _filter_df(df: pd.DataFrame, year_min, year_max, anomaly_types, entities, indicators) -> pd.DataFrame: + """Used in filter_df.""" + ## Year + df = df[(df["year"] >= year_min) & (df["year"] <= year_max)] + ## Anomaly type + if len(anomaly_types) > 0: + df = df[df["type"].isin(anomaly_types)] + ## Entities + if len(entities) > 0: + df = df[df["entity_name"].isin(entities)] + # Indicators + if len(indicators) > 0: + df = df[df["indicator_id"].isin(indicators)] + + return df + + +@st.cache_data +def _sort_df(df: pd.DataFrame, sort_strategy: str) -> Tuple[pd.DataFrame, List[str]]: + """Used in filter_df.""" + ## Sort + columns_sort = [] + match sort_strategy: + case "relevance": + columns_sort = ["score_weighted"] + case "score": + columns_sort = ["score"] + case "population": + columns_sort = ["score_population"] + case "views": + columns_sort = ["score_analytics"] + case "population+views": + columns_sort = ["score_population", "score_analytics"] + case _: + pass + if columns_sort != []: + df = df.sort_values(columns_sort, ascending=False) + + return df, columns_sort + + +# Functions to show the anomalies +@st.fragment +def show_anomaly_compact(index, df): + """Show anomaly compactly. + + Container with all anomalies of a certain type and for a concrete indicator. + """ + indicator_id, an_type = index + row = 0 + + key = f"{indicator_id}_{an_type}" + key_table = f"anomaly_table_{key}" + key_selection = f"selected_entities_{key}" + + # Get relevant metadata for this view + # By default, the entity with highest score, but user may have selected other ones! + entity_default = df.iloc[row]["entity_name"] + entities = st.session_state.get(f"selected_entities_{key}", entity_default) + entities = entities if entities != [] else [entity_default] + + # entities = df["entity_id"].tolist() + year_default = df.iloc[row]["year"] + indicator_uri = st.session_state.anomalist_indicators.get(indicator_id) + + # Generate descriptive text. Only contains information about top-scoring entity. + text = load_detector(an_type).get_text(entity_default, year_default) + + # Render + with st.container(border=True): + # Title + link = OWID_ENV.indicator_admin_site(indicator_id) + st.markdown(f"{tag_in_md(**ANOMALY_TYPES[an_type])} **[{indicator_uri}]({link})**") + col1, col2 = st.columns(2) + # Chart + with col1: + # Bake chart config + # If the anomaly is compared to previous indicator, then we need to show two indicators (old and new)! + if an_type in {AnomalyTypeEnum.UPGRADE_CHANGE.value, AnomalyTypeEnum.UPGRADE_MISSING.value}: + display = [ + { + "name": "New", + }, + { + "name": "Old", + }, + ] + assert indicator_id in st.session_state.anomalist_mapping_inv, "Indicator ID not found in mapping!" + indicator_id_old = st.session_state.anomalist_mapping_inv[indicator_id] + config = bake_chart_config( + variable_id=[indicator_id, indicator_id_old], + selected_entities=entities, + display=display, + ) + config["title"] = indicator_uri + config["subtitle"] = "Comparison of old and new indicator versions." 
+ + # config = bake_chart_config( + # variable_id=[indicator_id], + # selected_entities=entities, + # ) + else: + config = bake_chart_config(variable_id=indicator_id, selected_entities=entities) + config["hideAnnotationFieldsInTitle"]["time"] = True + # Actually plot + grapher_chart(chart_config=config, owid_env=OWID_ENV) + + # Description and other entities + with col2: + # Description + st.info(text) + # Other entities + with st.container(border=False): + st.markdown("**Select** other affected entities") + st.dataframe( + df[["entity_name"] + st.session_state.anomalist_sorting_columns], + selection_mode=["multi-row"], + key=key_table, + on_select=lambda df=df, key_table=key_table, key_selection=key_selection: _change_chart_selection( + df, key_table, key_selection + ), + hide_index=True, + ) + + # TODO: Enable anomaly-specific hiding + # key_btn = f"button_{key}" + # st.button("Hide anomaly", key=key_btn, icon=":material/hide:") + + +def _change_chart_selection(df, key_table, key_selection): + """Change selection in grapher chart.""" + # st.toast(f"Changing entity in indicator {indicator_id}") + # Get selected row number + rows = st.session_state[key_table]["selection"]["rows"] + + # Update entities in chart + st.session_state[key_selection] = df.iloc[rows]["entity_name"].tolist() + + +###################################################################### + + +# Load the main inputs: +# * List of all Grapher datasets. +# * List of newly created Grapher datasets (the ones we most likely want to inspect). +# * The variable mapping generated by "indicator upgrader", if there was any. +DATASETS_ALL, DATASETS_NEW, VARIABLE_MAPPING = get_datasets_and_mapping_inputs() + +# Create DB tables +create_tables() + +############################################################################ +# RENDER +# Below you can find the different elements of Anomalist being rendered. +############################################################################ + +# 1/ PAGE TITLE +# Show title +st.title(":material/planner_review: Anomalist") + + +# 2/ DATASET FORM +# Ask user to select datasets. By default, we select the new datasets (those that are new in the current PR compared to master). +st.markdown( + """ + """, + unsafe_allow_html=True, +) + +with st.form(key="dataset_search"): + query_dataset_ids = [int(v) for v in st.query_params.get_all("anomalist_datasets_selected")] + + st.session_state.anomalist_datasets_selected = st.multiselect( + "Select datasets", + # options=cached.load_dataset_uris(), + options=DATASETS_ALL.keys(), + # max_selections=1, + default=query_dataset_ids or DATASETS_NEW.keys(), + format_func=DATASETS_ALL.get, + ) + st.query_params["anomalist_datasets_selected"] = st.session_state.anomalist_datasets_selected # type: ignore + + st.form_submit_button( + "Detect anomalies", + type="primary", + help="This will load the indicators from the selected datasets and scan for anomalies. This can take some time.", + on_click=lambda: set_states({"anomalist_datasets_submitted": True}), + ) + + +# 3/ SCAN FOR ANOMALIES +# If anomalies for dataset already exist in DB, load them. 
Warn user that these are being loaded from DB
+if not st.session_state.anomalist_anomalies or st.session_state.anomalist_datasets_submitted:
+ # 3.1/ Check if anomalies are already there in DB
+ with st.spinner("Loading anomalies (already detected) from database..."):
+ st.session_state.anomalist_anomalies = WizardDB.load_anomalies(st.session_state.anomalist_datasets_selected)
+
+ # Load indicators in selected datasets
+ st.session_state.anomalist_indicators = cached.load_variables_display_in_dataset(
+ dataset_id=st.session_state.anomalist_datasets_selected,
+ only_slug=True,
+ )
+
+ # Get indicator IDs
+ variable_ids = list(st.session_state.anomalist_indicators.keys())
+ st.session_state.anomalist_mapping = {k: v for k, v in VARIABLE_MAPPING.items() if v in variable_ids}
+ st.session_state.anomalist_mapping_inv = {v: k for k, v in st.session_state.anomalist_mapping.items()}
+
+ # 3.2/ No anomalies found in DB, so detect them now
+ if (len(st.session_state.anomalist_anomalies) == 0) or st.session_state.anomalist_trigger_detection:
+ # Reset flag
+ st.session_state.anomalist_anomalies_out_of_date = False
+
+ with st.spinner("Scanning for anomalies... This can take some time."):
+ anomaly_detection(
+ anomaly_types=ANOMALY_TYPES_TO_DETECT,
+ variable_ids=variable_ids,
+ variable_mapping=st.session_state.anomalist_mapping,
+ dry_run=False,
+ reset_db=False,
+ )
+
+ # Fill list of anomalies
+ st.session_state.anomalist_anomalies = WizardDB.load_anomalies(st.session_state.anomalist_datasets_selected)
+
+ # Reset manual trigger
+ st.session_state.anomalist_trigger_detection = False
+
+ # 3.3/ Anomalies found in DB. If they are outdated, set the flag to True, so we can show a warning later on.
+ else:
+ # Check if data in DB is out of date
+ data_out_of_date = True
+
+ # Set flag (if data is out of date)
+ if data_out_of_date:
+ st.session_state.anomalist_anomalies_out_of_date = True
+ else:
+ st.session_state.anomalist_anomalies_out_of_date = False
+
+ # 3.4/ Parse the obtained anomalies into a dataframe
+ if len(st.session_state.anomalist_anomalies) > 0:
+ # Combine scores from all anomalies, reduce them (to get the maximum anomaly score for each entity-indicator),
+ # and add population and analytics scores.
+ df = get_scores(anomalies=st.session_state.anomalist_anomalies)
+
+ st.session_state.anomalist_df = df
+ else:
+ st.session_state.anomalist_df = None
+
+# 4/ SHOW ANOMALIES (only if any are found)
+if st.session_state.anomalist_df is not None:
+ ENTITIES_AVAILABLE = st.session_state.anomalist_df["entity_name"].unique()
+ YEAR_MIN = st.session_state.anomalist_df["year"].min()
+ YEAR_MAX = st.session_state.anomalist_df["year"].max()
+ INDICATORS_AVAILABLE = st.session_state.anomalist_df["indicator_id"].unique()
+ ANOMALY_TYPES_AVAILABLE = st.session_state.anomalist_df["type"].unique()
+
+ # 4.0/ WARNING: Show a warning if anomalies are loaded from DB without re-computing
+ # TODO: we could actually know if anomalies are out of sync with the datasets/indicators. Maybe based on dataset/indicator checksums? Starting to implement this idea with data_out_of_date
+ if st.session_state.anomalist_anomalies_out_of_date:
+ st.caption(
+ "Anomalies are being loaded from the database. These might be out of sync with the current datasets. Click the button below to re-run the anomaly-detection algorithm."
+ )
+ st.button(
+ "Re-scan datasets for anomalies",
+ icon="🔄",
+ on_click=lambda: set_states(
+ {
+ "anomalist_trigger_detection": True,
+ "anomalist_datasets_submitted": True,
+ }
+ ),
+ )
+
+ # 4.1/ ASK FOR FILTER PARAMS
+ # User can customize which anomalies are shown to them
+ with st.container(border=True):
+ st.markdown("##### Select filters")
+
+ # If there is a dataset selected, load the indicators
+ if len(st.session_state.anomalist_datasets_selected) > 0:
+ # Load indicators in selected datasets
+ st.session_state.anomalist_indicators = cached.load_variables_display_in_dataset(
+ dataset_id=st.session_state.anomalist_datasets_selected,
+ only_slug=True,
+ )
+
+ col1, col2 = st.columns([10, 4])
+ # Indicator
+ with col1:
+ options = [
+ indicator for indicator in INDICATORS_AVAILABLE if indicator in st.session_state.anomalist_indicators
+ ]
+
+ url_persist(st.multiselect)(
+ label="Indicators",
+ options=options,
+ format_func=st.session_state.anomalist_indicators.get,
+ help="Show anomalies affecting only a selection of indicators.",
+ placeholder="Select indicators",
+ key="anomalist_filter_indicators",
+ )
+
+ with col2:
+ # Entity
+ url_persist(st.multiselect)(
+ label="Entities",
+ options=ENTITIES_AVAILABLE,
+ help="Show anomalies affecting only a selection of entities.",
+ placeholder="Select entities",
+ key="anomalist_filter_entities",
+ )
+
+ # Sorting strategy and anomaly type
+ col1, col2, _ = st.columns(3)
+ with col1:
+ cols = st.columns(2)
+ with cols[0]:
+ st.selectbox(
+ label="Sort by",
+ options=SORTING_STRATEGIES.keys(),
+ format_func=SORTING_STRATEGIES.get,
+ help=(
+ """
+ Sort anomalies by a given criterion.
+
+ - **Relevance**: A combined score based on the population of the affected country, the views of charts using this indicator, and the anomaly score. The higher this score, the more relevant the anomaly.
+ - **Anomaly score**: The score assigned by the anomaly-detection algorithm, based on the significance of the anomaly.
+ - **Population**: Population score, based on the population of the affected country.
+ - **Views**: Views of charts using this indicator.
+ - **Population+views**: Combination of the population and chart-views scores.
+ """ + ), + key="anomalist_sorting_strategy", + ) + with cols[1]: + url_persist(st.multiselect)( + label="Detectors", + options=ANOMALY_TYPES_AVAILABLE, + format_func=ANOMALY_TYPE_NAMES.get, + help="Show anomalies of a certain type.", + placeholder="Select anomaly types", + key="anomalist_filter_anomaly_types", + ) + with col2: + with st_horizontal(): + url_persist(st.number_input)( + "Min year", + value=YEAR_MIN, + min_value=YEAR_MIN, + max_value=YEAR_MAX, + step=1, + key="anomalist_min_year", + ) + url_persist(st.number_input)( + "Max year", + value=YEAR_MAX, + min_value=YEAR_MIN, + max_value=YEAR_MAX, + step=1, + key="anomalist_max_year", + ) + + # 4.3/ APPLY FILTERS + df = filter_df(st.session_state.anomalist_df) + + # 5/ SHOW ANOMALIES + # Different types need formatting + # mask = df["type"] == "upgrade_missing" + # df_missing = df[mask] + # df_change = df[~mask] + + # Show anomalies with time and version changes + if not df.empty: + # LLM summary option + llm_ask(df) + + # st.dataframe(df_change) + groups = df.groupby(["indicator_id", "type"], sort=False) + items = list(groups) + items_per_page = 10 + + # Define pagination + pagination = Pagination( + items=items, + items_per_page=items_per_page, + pagination_key="pagination-demo", + ) + + # Show items (only current page) + for item in pagination.get_page_items(): + show_anomaly_compact(item[0], item[1]) + + # Show controls only if needed + if len(items) > items_per_page: + pagination.show_controls(mode="bar") + +# Reset state +set_states({"anomalist_datasets_submitted": False}) diff --git a/apps/wizard/app_pages/anomalist/mock.py b/apps/wizard/app_pages/anomalist/mock.py new file mode 100644 index 00000000000..b37af5edbc1 --- /dev/null +++ b/apps/wizard/app_pages/anomalist/mock.py @@ -0,0 +1,114 @@ +import random + +import pandas as pd +import streamlit as st + +from apps.wizard.app_pages.anomalist.utils import AnomalyTypeEnum +from apps.wizard.utils import cached + +# This should be removed and replaced with dynamic fields +ENTITIES_DEFAULT = [ + "Spain", + "France", + "Germany", + "Italy", + "United Kingdom", + "United States", + "China", + "India", + "Japan", + "Brazil", + "Russia", + "Canada", + "South Africa", + "Australia", + "Venezuela", + "Croatia", + "Azerbaijan", +] + + +def mock_anomalies_df_time_change(indicators_id, n=5): + records = [ + { + "entity": random.sample(ENTITIES_DEFAULT, 1)[0], + "year": random.randint(1950, 2020), + "score": round(random.random(), 2), + "indicator_id": random.sample(indicators_id, 1)[0], + } + for i in range(n) + ] + + df = pd.DataFrame(records) + return df + + +def mock_anomalies_df_upgrade_change(indicators_id_upgrade, n=5): + records = [ + { + "entity": random.sample(ENTITIES_DEFAULT, 1)[0], + "year": random.randint(1950, 2020), + "score": round(random.random(), 2), + "indicator_id": random.sample(indicators_id_upgrade, 1)[0], + } + for i in range(n) + ] + + df = pd.DataFrame(records) + return df + + +def mock_anomalies_df_upgrade_missing(indicators_id_upgrade, n=5): + records = [ + { + "entity": random.sample(ENTITIES_DEFAULT, 1)[0], + "year": random.randint(1950, 2020), + "score": random.randint(0, 1), + "indicator_id": random.sample(indicators_id_upgrade, 1)[0], + } + for i in range(n) + ] + + df = pd.DataFrame(records) + return df + + +@st.cache_data(ttl=60 * 60) +def mock_anomalies_df(indicators_id, indicators_id_upgrade, n=5): + # 1/ Get anomalies df + ## Time change + df_change = mock_anomalies_df_time_change(indicators_id, n) + df_change["type"] = 
AnomalyTypeEnum.TIME_CHANGE.value + ## Upgrade: value change + df_upgrade_change = mock_anomalies_df_upgrade_change(indicators_id_upgrade, n) + df_upgrade_change["type"] = AnomalyTypeEnum.UPGRADE_CHANGE.value + ## Upgrade: Missing data point + df_upgrade_miss = mock_anomalies_df_upgrade_missing(indicators_id_upgrade, n) + df_upgrade_miss["type"] = AnomalyTypeEnum.UPGRADE_MISSING.value + + # 2/ Combine + df = pd.concat([df_change, df_upgrade_change, df_upgrade_miss]) + + # Ensure there is only one row per entity, anomaly type and indicator + df = df.sort_values("score", ascending=False).drop_duplicates(["entity", "type", "indicator_id"]) + + # Replace entity name with entity ID + entity_mapping = cached.load_entity_ids() + entity_mapping_inv = {v: k for k, v in entity_mapping.items()} + df["entity_id"] = df["entity"].map(entity_mapping_inv) + # st.write(entity_mapping) + + # 3/ Add meta scores + num_scores = len(df) + df["score_population"] = [random.random() for i in range(num_scores)] + df["score_analytics"] = [random.random() for i in range(num_scores)] + + # 4/ Weighed combined score + # Weighed combined score + w_score = 1 + w_pop = 1 + w_views = 1 + df["score_weighted"] = ( + w_score * df["score"] + w_pop * df["score_population"] + w_views * df["score_analytics"] + ) / (w_score + w_pop + w_views) + return df diff --git a/apps/wizard/app_pages/anomalist.py b/apps/wizard/app_pages/anomalist/old.py similarity index 98% rename from apps/wizard/app_pages/anomalist.py rename to apps/wizard/app_pages/anomalist/old.py index cb76e03c806..7440961a9ab 100644 --- a/apps/wizard/app_pages/anomalist.py +++ b/apps/wizard/app_pages/anomalist/old.py @@ -1,3 +1,7 @@ +"""This code contains the first demo for Anomalist. + +This will be removed, but is kept alive since it may contain useful code snippets (e.g. streaming structured LLM outputs). +""" from typing import cast import streamlit as st diff --git a/apps/wizard/app_pages/anomalist/utils.py b/apps/wizard/app_pages/anomalist/utils.py new file mode 100644 index 00000000000..39f07080a9a --- /dev/null +++ b/apps/wizard/app_pages/anomalist/utils.py @@ -0,0 +1,123 @@ +"""Utils for chart revision tool.""" +import time +from enum import Enum +from typing import Dict, List, Optional, Tuple + +import pandas as pd +import streamlit as st +from sqlalchemy.orm import Session +from structlog import get_logger + +import etl.grapher_model as gm +from apps.anomalist.anomalist_api import add_auxiliary_scores, combine_and_reduce_scores_df +from apps.wizard.utils.db import WizardDB +from apps.wizard.utils.io import get_new_grapher_datasets_and_their_previous_versions +from etl.config import OWID_ENV, OWIDEnv +from etl.db import get_engine + +# Logger +log = get_logger() + + +class AnomalyTypeEnum(Enum): + TIME_CHANGE = "time_change" + UPGRADE_CHANGE = "upgrade_change" + UPGRADE_MISSING = "upgrade_missing" + GP_OUTLIER = "gp_outlier" + # AI = "ai" # Uncomment if needed + + +def infer_variable_mapping(dataset_id_new: int, dataset_id_old: int) -> Dict[int, int]: + engine = get_engine() + with Session(engine) as session: + variables_new = gm.Variable.load_variables_in_datasets(session=session, dataset_ids=[dataset_id_new]) + variables_old = gm.Variable.load_variables_in_datasets(session=session, dataset_ids=[dataset_id_old]) + # Create a mapping from old ids to new variable ids for variables whose shortNames are identical in the old and new versions. 
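# Illustrative sketch (not part of the patch; ids and shortNames are made up): the join on
# `shortName` that the code just below implements. Only variables whose shortName exists in
# both dataset versions end up in the mapping, keyed old variable id -> new variable id.
old_variables_example = {101: "gdp_per_capita", 102: "population_density"}  # old id -> shortName
new_variables_example = {201: "gdp_per_capita", 202: "life_expectancy"}  # new id -> shortName
_by_short_name = {short_name: new_id for new_id, short_name in new_variables_example.items()}
mapping_sketch = {
    old_id: _by_short_name[short_name]
    for old_id, short_name in old_variables_example.items()
    if short_name in _by_short_name
}
assert mapping_sketch == {101: 201}  # "population_density" has no identical counterpart, so it is dropped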
+ _variables = {variable.shortName: variable.id for variable in variables_new} + variable_mapping = { + old_variable.id: _variables[old_variable.shortName] + for old_variable in variables_old + if old_variable.shortName in _variables + } + return variable_mapping + + +@st.cache_data(show_spinner=False) +@st.spinner("Retrieving datasets...") +def get_datasets_and_mapping_inputs() -> Tuple[Dict[int, str], Dict[int, str], Dict[int, int]]: + t = time.time() + # Get all datasets from DB. + df_datasets = gm.Dataset.load_all_datasets(columns=["id", "name"]) + + # Detect local files that correspond to new or modified grapher steps, identify their corresponding grapher dataset ids, and the grapher dataset id of the previous version (if any). + # NOTE: this is quite slow taking ~4s, it would be faster to reuse function `_load_datasets_new_ids` from owidbot/anomalist.py + dataset_new_and_old = get_new_grapher_datasets_and_their_previous_versions() + + # List new dataset ids. + datasets_new_ids = list(dataset_new_and_old) + + # Load mapping created by indicator upgrader (if any). + variable_mapping = load_variable_mapping(datasets_new_ids, dataset_new_and_old) + + # For convenience, create a dataset name "[id] Name". + df_datasets["id_name"] = "[" + df_datasets["id"].astype(str) + "] " + df_datasets["name"] + # List all grapher datasets. + datasets_all = df_datasets[["id", "id_name"]].set_index("id").squeeze().to_dict() + # List new datasets. + datasets_new = {k: v for k, v in datasets_all.items() if k in datasets_new_ids} + + log.info("get_datasets_and_mapping_inputs", t=time.time() - t) + + return datasets_all, datasets_new, variable_mapping # type: ignore + + +def load_variable_mapping( + datasets_new_ids: List[int], dataset_new_and_old: Optional[Dict[int, Optional[int]]] = None +) -> Dict[int, int]: + mapping = WizardDB.get_variable_mapping_raw() + if len(mapping) > 0: + log.info("Using variable mapping created by indicator upgrader.") + # Set of ids of new datasets that appear in the mapping generated by indicator upgrader. + datasets_new_mapped = set(mapping["dataset_id_new"]) + # Set of ids of expected new datasets. + datasets_new_expected = set(datasets_new_ids) + # Sanity check. + if not (datasets_new_mapped <= datasets_new_expected): + log.error( + f"Indicator upgrader mapped indicators to new datasets ({datasets_new_mapped}) that are not among the datasets detected as new in the code ({datasets_new_expected}). Look into this." + ) + # Create a mapping dictionary. + variable_mapping = mapping.set_index("id_old")["id_new"].to_dict() + elif dataset_new_and_old: + log.info("Inferring variable mapping (since no mapping was created by indicator upgrader).") + # Infer the mapping of the new datasets (assuming no names have changed). + variable_mapping = dict() + for dataset_id_new, dataset_id_old in dataset_new_and_old.items(): + if dataset_id_old is None: + continue + # Infer + variable_mapping.update(infer_variable_mapping(dataset_id_new, dataset_id_old)) + else: + # No mapping available. + variable_mapping = dict() + + return variable_mapping # type: ignore + + +def create_tables(_owid_env: OWIDEnv = OWID_ENV): + """Create all required tables. + + If exist, nothing is created. 
+ """ + gm.Anomaly.create_table(_owid_env.engine, if_exists="skip") + + +@st.cache_data(show_spinner=False) +def get_scores(anomalies: List[gm.Anomaly]) -> pd.DataFrame: + """Combine and reduce scores dataframe.""" + df = combine_and_reduce_scores_df(anomalies) + + # Add a population score, an analytics score, and a weighted score. + df = add_auxiliary_scores(df=df) + + return df diff --git a/apps/wizard/app_pages/anomalist_2.py b/apps/wizard/app_pages/anomalist_2.py deleted file mode 100644 index 61369b2880a..00000000000 --- a/apps/wizard/app_pages/anomalist_2.py +++ /dev/null @@ -1,110 +0,0 @@ -import pandas as pd -import streamlit as st - -from apps.wizard.utils import cached -from apps.wizard.utils.components import grapher_chart, st_horizontal - -# PAGE CONFIG -st.set_page_config( - page_title="Wizard: Anomalist", - page_icon="🪄", -) -# OTHER CONFIG -ANOMALY_TYPES = [ - "Upgrade", - "Abrupt change", - "Context change", -] - -# SESSION STATE -st.session_state.datasets_selected = st.session_state.get("datasets_selected", []) -st.session_state.filter_indicators = st.session_state.get("filter_indicators", []) -st.session_state.indicators = st.session_state.get("indicators", []) - -# PAGE TITLE -st.title(":material/planner_review: Anomalist") - - -# DATASET SEARCH -st.markdown( - """ - """, - unsafe_allow_html=True, -) -with st.form(key="dataset_search"): - st.session_state.datasets_selected = st.multiselect( - "Select datasets", - options=cached.load_dataset_uris(), - max_selections=1, - ) - - st.form_submit_button("Detect anomalies", type="primary") - - -# FILTER PARAMS -with st.container(border=True): - st.markdown("##### Filter Parameters") - options = [] - if len(st.session_state.datasets_selected) > 0: - st.session_state.indicators = cached.load_variables_in_dataset(st.session_state.datasets_selected) - options = [o.catalogPath for o in st.session_state.indicators] - - st.session_state.filter_indicators = st.multiselect( - label="Indicator", - options=options, - ) - - with st_horizontal(): - st.session_state.filter_indicators = st.multiselect( - label="Indicator type", - options=["New indicator", "Indicator upgrade"], - ) - st.session_state.filter_indicators = st.multiselect( - label="Anomaly type", - options=ANOMALY_TYPES, - ) - - # st.multiselect("Anomaly type", min_value=0.0, max_value=1.0, value=0.5, step=0.01) - st.number_input("Minimum score", min_value=0.0, max_value=1.0, value=0.5, step=0.01) - -# SHOW ANOMALIES -data = { - "anomaly": ["Anomaly 1", "Anomaly 2", "Anomaly 3"], - "description": ["Description 1", "Description 2", "Description 3"], -} - - -# SHOW ANOMALIES -def show_anomaly(df: pd.DataFrame): - if len(st.session_state.anomalies["selection"]["rows"]) > 0: - # Get selected row number - row_num = st.session_state.anomalies["selection"]["rows"][0] - # Get indicator id - indicator_id = df.index[row_num] - action(indicator_id) - - -@st.dialog("Show anomaly", width="large") -def action(indicator_id): - grapher_chart(variable_id=indicator_id) - - -if len(st.session_state.indicators) > 0: - df = pd.DataFrame( - { - "indicator_id": [i.id for i in st.session_state.indicators], - "reviewed": [False for i in st.session_state.indicators], - }, - ).set_index("indicator_id") - - st.dataframe( - df, - key="anomalies", - selection_mode="single-row", - on_select=lambda df=df: show_anomaly(df), - use_container_width=True, - ) diff --git a/apps/wizard/app_pages/indicator_upgrade/app.py b/apps/wizard/app_pages/indicator_upgrade/app.py index 73344caef7f..8b0ee1391a8 100644 --- 
a/apps/wizard/app_pages/indicator_upgrade/app.py +++ b/apps/wizard/app_pages/indicator_upgrade/app.py @@ -26,7 +26,12 @@ from structlog import get_logger from apps.wizard import utils -from apps.wizard.app_pages.indicator_upgrade.charts_update import get_affected_charts_and_preview, push_new_charts +from apps.wizard.app_pages.indicator_upgrade.charts_update import ( + get_affected_charts_and_preview, + push_new_charts, + save_variable_mapping, + undo_upgrade_dialog, +) from apps.wizard.app_pages.indicator_upgrade.dataset_selection import build_dataset_form from apps.wizard.app_pages.indicator_upgrade.indicator_mapping import render_indicator_mapping from apps.wizard.app_pages.indicator_upgrade.utils import get_datasets @@ -40,7 +45,7 @@ page_title="Wizard: Indicator Upgrader", layout="wide", page_icon="🪄", - initial_sidebar_state="collapsed", + # initial_sidebar_state="collapsed", menu_items={ "Report a bug": "https://github.com/owid/etl/issues/new?assignees=marigold%2Clucasrodes&labels=wizard&projects=&template=wizard-issue---.md&title=wizard%3A+meaningful+title+for+the+issue", "About": """ @@ -148,14 +153,40 @@ if st.session_state.submitted_charts: if isinstance(charts, list) and len(charts) > 0: st.toast("Updating charts...") + + # Push charts push_new_charts(charts) + # Save variable mapping + save_variable_mapping( + indicator_mapping=indicator_mapping, + dataset_id_new=int(search_form.dataset_new_id), + dataset_id_old=int(search_form.dataset_old_id), + comments="Done with indicator-upgrader", + ) + + # Undo upgrade + st.markdown("Do you want to undo the indicator upgrade?") + st.button( + "Undo upgrade", + on_click=undo_upgrade_dialog, + icon=":material/undo:", + help="Undo all indicator upgrades", + key="btn_undo_upgrade_end", + ) + + ########################################################################################## -# 4 UPDATE CHARTS +# 5 UNDO UPGRADE # -# TODO: add description +# You may have accidentally upgraded the wrong indicators. Here you can undo the upgrade. ########################################################################################## -# if st.session_state.submitted_datasets and st.session_state.submitted_charts and st.session_state.submitted_indicators: -# if isinstance(charts, list) and len(charts) > 0: -# st.toast("Updating charts...") -# push_new_charts(charts, SCHEMA_CHART_CONFIG) +with st.sidebar: + st.markdown("### Advanced tools") + st.button( + "Undo upgrade", + on_click=undo_upgrade_dialog, + icon=":material/undo:", + help="Undo all indicator upgrades", + key="btn_undo_upgrade_sidebar", + ) diff --git a/apps/wizard/app_pages/indicator_upgrade/charts_update.py b/apps/wizard/app_pages/indicator_upgrade/charts_update.py index 5fd210f578d..b88800f4649 100644 --- a/apps/wizard/app_pages/indicator_upgrade/charts_update.py +++ b/apps/wizard/app_pages/indicator_upgrade/charts_update.py @@ -10,6 +10,7 @@ import etl.grapher_model as gm from apps.chart_sync.admin_api import AdminAPI from apps.wizard.utils import set_states, st_page_link, st_toast_error +from apps.wizard.utils.db import WizardDB from etl.config import OWID_ENV from etl.helpers import get_schema_from_url from etl.indicator_upgrade.indicator_update import find_charts_from_variable_ids, update_chart_config @@ -130,3 +131,59 @@ def push_new_charts(charts: List[gm.Chart]) -> None: "The charts were successfully updated! If indicators from other datasets also need to be upgraded, simply refresh this page, otherwise move on to `chart diff` to review all changes." 
) st_page_link("chart-diff")
+
+
+def save_variable_mapping(
+ indicator_mapping: Dict[int, int], dataset_id_new: int, dataset_id_old: int, comments: str = ""
+) -> None:
+ WizardDB.add_variable_mapping(
+ mapping=indicator_mapping,
+ dataset_id_new=dataset_id_new,
+ dataset_id_old=dataset_id_old,
+ comments=comments,
+ )
+
+
+def undo_indicator_upgrade(indicator_mapping):
+ mapping_inverted = {v: k for k, v in indicator_mapping.items()}
+ with st.spinner("Undoing upgrade..."):
+ # Get affected charts
+ charts = get_affected_charts_and_preview(
+ mapping_inverted,
+ )
+
+ # TODO: instead of pushing new charts, we should revert the changes!
+ # To do this, we should have kept a copy of (or a reference to) the original chart revision.
+ # Idea: when 'push_new_charts' is called, store the original revision of the chart in a table.
+ push_new_charts(charts)
+
+ # Reset variable mapping
+ WizardDB.delete_variable_mapping()
+
+
+@st.dialog("Undo upgrade", width="large")
+def undo_upgrade_dialog():
+ mapping = WizardDB.get_variable_mapping()
+
+ if mapping != {}:
+ st.markdown(
+ "The following table shows the indicator mapping that has been applied to the charts. Undoing means inverting this mapping."
+ )
+ data = {
+ "id_old": list(mapping.keys()),
+ "id_new": list(mapping.values()),
+ }
+ st.dataframe(data)
+ st.button(
+ "Undo upgrade",
+ on_click=lambda m=mapping: undo_indicator_upgrade(m),
+ icon=":material/undo:",
+ help="Undo all indicator upgrades",
+ type="primary",
+ key="btn_undo_upgrade_2",
+ )
+ st.warning(
+ "Charts will still appear in chart-diff. This is because the chart configs have actually changed (their version has been bumped). In the future, we do not want to show these charts in chart-diff. For the time being, you should reject these chart diffs."
+ )
+ else:
+ st.markdown("No indicator mapping found.
Nothing to undo.") diff --git a/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py b/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py index 695738adb5b..e5e6c03b370 100644 --- a/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py +++ b/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py @@ -362,7 +362,7 @@ def __init__(self, indicator_upgrade: "IndicatorUpgrade"): @st.fragment def render(self, indicator_id_to_display, df_data=None): with st.container(border=True): - cols = [100, 10, 10] + cols = [100, 5, 10] cols = st.columns(cols, vertical_alignment="bottom") # Indicators (old, new) @@ -451,7 +451,7 @@ def _set_states_checkbox(): st.session_state[k][self.iu.key] = not st.session_state[k][self.iu.key] st.checkbox( - label="Ignore", + label="Skip", key=self.iu.key_ignore, # label_visibility="collapsed", value=self.iu.skip, diff --git a/apps/wizard/app_pages/indicator_upgrade/utils.py b/apps/wizard/app_pages/indicator_upgrade/utils.py index e737bf534b8..f432ca6cffc 100644 --- a/apps/wizard/app_pages/indicator_upgrade/utils.py +++ b/apps/wizard/app_pages/indicator_upgrade/utils.py @@ -1,5 +1,5 @@ """Utils for chart revision tool.""" -from typing import Any, Dict, List, Tuple, cast +from typing import Dict, Tuple, cast import pandas as pd import streamlit as st @@ -7,49 +7,19 @@ from rapidfuzz import fuzz from structlog import get_logger -from apps.utils.map_datasets import get_grapher_changes -from etl import config +from apps.wizard.utils.io import get_steps_df from etl.db import get_connection -from etl.git_helpers import get_changed_files from etl.grapher_io import get_all_datasets, get_dataset_charts, get_variables_in_dataset from etl.match_variables import find_mapping_suggestions, preliminary_mapping -from etl.version_tracker import VersionTracker # Logger log = get_logger() @st.spinner("Retrieving datasets...") -def get_datasets(archived) -> pd.DataFrame: - steps_df_grapher, grapher_changes = get_datasets_from_version_tracker() - - # Combine with datasets from database that are not present in ETL - # Get datasets from Database - try: - datasets_db = get_all_datasets(archived=archived) - except OperationalError as e: - raise OperationalError( - f"Could not retrieve datasets. Try reloading the page. If the error persists, please report an issue. 
Error: {e}" - ) - - # TODO: replace concat with merge - # steps_df_grapher = pd.concat([steps_df_grapher, datasets_db], ignore_index=True) - # steps_df_grapher = steps_df_grapher.drop_duplicates(subset="id").drop(columns="updatedAt").astype({"id": int}) - - # Get table with all datasets (ETL + DB) - steps_df_grapher = ( - steps_df_grapher.merge(datasets_db, on="id", how="outer", suffixes=("_etl", "_db")) - .sort_values(by="id", ascending=False) - .drop(columns="updatedAt") - .astype({"id": int}) - ) - columns = ["namespace", "name"] - for col in columns: - steps_df_grapher[col] = steps_df_grapher[f"{col}_etl"].fillna(steps_df_grapher[f"{col}_db"]) - steps_df_grapher = steps_df_grapher.drop(columns=[f"{col}_etl", f"{col}_db"]) - - assert steps_df_grapher["name"].notna().all(), "NaNs found in `name`" - assert steps_df_grapher["namespace"].notna().all(), "NaNs found in `namespace`" +def get_datasets(archived: bool) -> pd.DataFrame: + # Get steps_df and grapher_changes + steps_df_grapher, grapher_changes = get_steps_df(archived=archived) # Add column marking migrations steps_df_grapher["migration_new"] = False @@ -104,37 +74,6 @@ def get_datasets(archived) -> pd.DataFrame: return steps_df_grapher -@st.cache_data(show_spinner=False) -def get_datasets_from_version_tracker() -> Tuple[pd.DataFrame, List[Dict[str, Any]]]: - # Get steps_df - vt = VersionTracker() - assert vt.connect_to_db, "Can't connnect to database! You need to be connected to run indicator upgrader" - steps_df = vt.steps_df - - # Get file changes -> Infer dataset migrations - files_changed = get_changed_files() - grapher_changes = get_grapher_changes(files_changed, steps_df) - - # Only keep grapher steps - steps_df_grapher = steps_df.loc[ - steps_df["channel"] == "grapher", ["namespace", "identifier", "step", "db_dataset_name", "db_dataset_id"] - ] - # Remove unneded text from 'step' (e.g. '*/grapher/'), no need for fuzzymatch! 
- steps_df_grapher["step_reduced"] = steps_df_grapher["step"].str.split("grapher/").str[-1] - - ## Keep only those that are in DB (we need them to be in DB, otherwise indicator upgrade won't work since charts wouldn't be able to reference to non-db-existing indicator IDs) - steps_df_grapher = steps_df_grapher.dropna(subset="db_dataset_id") - assert steps_df_grapher.isna().sum().sum() == 0 - ## Column rename - steps_df_grapher = steps_df_grapher.rename( - columns={ - "db_dataset_name": "name", - "db_dataset_id": "id", - } - ) - return steps_df_grapher, grapher_changes - - def get_datasets_from_db() -> pd.DataFrame: """Load datasets.""" try: @@ -165,56 +104,6 @@ def get_indicators_from_datasets( return old_indictors, new_indictors -def _check_env() -> bool: - """Check if environment indicators are set correctly.""" - ok = True - for env_name in ("GRAPHER_USER_ID", "DB_USER", "DB_NAME", "DB_HOST"): - if getattr(config, env_name) is None: - ok = False - st.warning(st.markdown(f"Environment variable `{env_name}` not found, do you have it in your `.env` file?")) - - if ok: - st.success("`.env` configured correctly") - return ok - - -def _show_environment() -> None: - """Show environment indicators (streamlit).""" - # show indicators (from .env) - st.info( - f""" - * **GRAPHER_USER_ID**: `{config.GRAPHER_USER_ID}` - * **DB_USER**: `{config.DB_USER}` - * **DB_NAME**: `{config.DB_NAME}` - * **DB_HOST**: `{config.DB_HOST}` - """ - ) - - -@st.cache_resource -def _check_env_and_environment() -> None: - """Check if environment indicators are set correctly.""" - ok = _check_env() - if ok: - # check that you can connect to DB - try: - with st.spinner(): - _ = get_connection() - except OperationalError as e: - st.error( - "We could not connect to the database. If connecting to a remote database, remember to" - f" ssh-tunel into it using the appropriate ports and then try again.\n\nError:\n{e}" - ) - ok = False - except Exception as e: - raise e - else: - msg = "Connection to the Grapher database was successfull!" 
- st.success(msg) - st.subheader("Environment") - _show_environment() - - @st.cache_data(show_spinner=False) def preliminary_mapping_cached( old_indicators, new_indicators, match_identical diff --git a/apps/wizard/cli.py b/apps/wizard/cli.py index c492d707ad0..aedf01a6f21 100644 --- a/apps/wizard/cli.py +++ b/apps/wizard/cli.py @@ -15,7 +15,9 @@ from apps.utils.style import set_rich_click_style from apps.wizard.config import WIZARD_PHASES from apps.wizard.utils import CURRENT_DIR +from apps.wizard.utils.paths import WIZARD_ANOMALIES from etl.config import WIZARD_PORT +from etl.files import create_folder # Disable streamlit cache data API logging # ref: @kajarenc from https://github.com/streamlit/streamlit/issues/6620#issuecomment-1564735996 @@ -73,6 +75,10 @@ def cli( """ script_path = CURRENT_DIR / "app.py" + # Create folder for anomalies + # TODO: this should be created elsewhere + create_folder(WIZARD_ANOMALIES) + # Define command with arguments args = [ "streamlit", diff --git a/apps/wizard/config/config.yml b/apps/wizard/config/config.yml index 5c4cf554c43..0ec6f7414fd 100644 --- a/apps/wizard/config/config.yml +++ b/apps/wizard/config/config.yml @@ -99,6 +99,15 @@ sections: image_url: "https://superheroetc.wordpress.com/wp-content/uploads/2017/05/bulbasaur-line.jpg" disable: production: True + - title: "Anomalist" + alias: anomalist + entrypoint: app_pages/anomalist/app.py + description: List anomalies in data + maintainer: "@lucas" + icon: ":material/planner_review:" + image_url: "https://i0.pickpik.com/photos/87/645/315/halloween-ghosts-happy-halloween-ghost-preview.jpg" + disable: + production: True - title: "Chart Diff" alias: chart-diff entrypoint: app_pages/chart_diff/app.py @@ -108,15 +117,6 @@ sections: image_url: "https://static.wikia.nocookie.net/dragonball/images/6/60/FusionDanceFinaleGotenTrunksBuuSaga.png" disable: production: True - # - title: "Anomalist" - # alias: anomalist - # entrypoint: app_pages/anomalist_2.py - # description: List anomalies in data - # maintainer: "@lucas" - # icon: ":material/planner_review:" - # image_url: "https://superheroetc.wordpress.com/wp-content/uploads/2017/05/bulbasaur-line.jpg" - # disable: - # production: True - title: "Harmonizer" alias: harmonizer description: "Harmonize a column of a table" @@ -124,20 +124,6 @@ sections: entrypoint: app_pages/harmonizer.py icon: ":material/music_note:" image_url: "https://upload.wikimedia.org/wikipedia/commons/thumb/c/c1/C_triad.svg/2560px-C_triad.svg.png" - - title: "Map Bracketer" - alias: map_brackets - entrypoint: app_pages/map_brackets.py - description: Create optimal map brackets - maintainer: "@pablo" - icon: ":material/map:" - image_url: "https://upload.wikimedia.org/wikipedia/en/8/8c/Human_Language_Families_Map_%28Wikipedia_Colors_.PNG" - - title: "Explorer editor" - alias: explorer_editor - entrypoint: app_pages/explorer_edit.py - description: Edit explorer config - maintainer: "@lucas" - icon: ":material/explore:" - image_url: "https://upload.wikimedia.org/wikipedia/en/1/18/Dora_the_Explorer_2004_album_cover.jpg" - title: "Monitoring" description: |- @@ -158,6 +144,25 @@ sections: icon: ":material/search:" image_url: "https://upload.wikimedia.org/wikipedia/commons/c/c3/NGC_4414_%28NASA-med%29.jpg" + - title: "Explorers" + description: |- + Explorer tools. 
+ apps: + - title: "Map Bracketer" + alias: map_brackets + entrypoint: app_pages/map_brackets.py + description: Create optimal map brackets + maintainer: "@pablo" + icon: ":material/map:" + image_url: "https://upload.wikimedia.org/wikipedia/en/8/8c/Human_Language_Families_Map_%28Wikipedia_Colors_.PNG" + - title: "ID to Path" + alias: explorer_editor + entrypoint: app_pages/explorer_edit.py + description: Migrate id-based explorers + maintainer: "@lucas" + icon: ":material/explore:" + image_url: "https://upload.wikimedia.org/wikipedia/en/1/18/Dora_the_Explorer_2004_album_cover.jpg" + - title: "Research" description: |- Research tools. diff --git a/apps/wizard/utils/__init__.py b/apps/wizard/utils/__init__.py index fdc7ec6f3d1..b6726fa6770 100644 --- a/apps/wizard/utils/__init__.py +++ b/apps/wizard/utils/__init__.py @@ -24,7 +24,6 @@ import numpy as np import streamlit as st from owid.catalog import Dataset -from pymysql import OperationalError from sqlalchemy.orm import Session from structlog import get_logger from typing_extensions import Self @@ -33,7 +32,7 @@ from apps.wizard.utils.defaults import load_wizard_defaults, update_wizard_defaults_from_form from apps.wizard.utils.step_form import StepForm from etl.config import OWID_ENV, enable_bugsnag -from etl.db import get_connection, read_sql +from etl.db import read_sql from etl.files import ruamel_dump, ruamel_load from etl.metadata_export import main as metadata_export from etl.paths import ( @@ -405,8 +404,6 @@ def st_selectbox_responsive( key=f"{key}_custom", default_last=default_value, ) - # else: - # st.session_state[f"{self.step}.{key}_custom"] = "nana" @classproperty def args(cls: "AppState") -> argparse.Namespace: @@ -477,22 +474,6 @@ def _check_env() -> bool: return ok -def _check_db() -> bool: - try: - with st.spinner(): - _ = get_connection() - except OperationalError as e: - st.error( - "We could not connect to the database. If connecting to a remote database, remember to" - f" ssh-tunel into it using the appropriate ports and then try again.\n\nError:\n{e}" - ) - return False - except Exception as e: - raise e - st.success("Connection to the Grapher database was successfull!") - return True - - def _show_environment(): """Show environment variables.""" st.info( @@ -635,9 +616,6 @@ def st_page_link(alias: str, border: bool = False, **kwargs) -> None: st.page_link(**kwargs) -st.cache_data - - def metadata_export_basic(dataset_path: str | None = None, dataset: Dataset | None = None, output: str = "") -> str: """Export metadata of a dataset. @@ -746,3 +724,59 @@ def as_list(s): except (ValueError, SyntaxError): return s return s + + +def update_query_params(key): + def _update_query_params(): + value = st.session_state[key] + if value: + st.query_params.update({key: value}) + else: + st.query_params.pop(key, None) + + return _update_query_params + + +def url_persist(component: Any) -> Any: + """Wrapper around streamlit components that persist values in the URL query string. + + Usage: + url_persist(st.multiselect)( + key="abc", + ... 
+ ) + """ + # Component uses list of values + if component == st.multiselect: + repeated = True + else: + repeated = False + + def _persist(*args, **kwargs): + assert "key" in kwargs, "key should be passed to persist" + # TODO: we could wrap on_change too to make it work + assert "on_change" not in kwargs, "on_change should not be passed to persist" + + key = kwargs["key"] + + if not st.session_state.get(key): + if repeated: + params = st.query_params.get_all(key) + # convert to int if digit + params = [int(q) if q.isdigit() else q for q in params] + else: + params = st.query_params.get(key) + if params and params.isdigit(): + params = int(params) + + # Use `value` from the component as a default value if available + if not params and "value" in kwargs: + params = kwargs.pop("value") + + st.session_state[key] = params + + kwargs["on_change"] = update_query_params(key) + + return component(*args, **kwargs) + + return _persist diff --git a/apps/wizard/utils/cached.py b/apps/wizard/utils/cached.py index 9ae59fa6817..88f4f649572 100644 --- a/apps/wizard/utils/cached.py +++ b/apps/wizard/utils/cached.py @@ -1,25 +1,78 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple import pandas as pd import streamlit as st +from sqlalchemy.orm import Session -from etl import grapher_io as io +from apps.utils.map_datasets import get_grapher_changes +from etl import grapher_io as gio from etl.config import OWID_ENV, OWIDEnv -from etl.grapher_model import Variable +from etl.git_helpers import get_changed_files +from etl.grapher_model import Anomaly, Variable +from etl.version_tracker import VersionTracker + + +@st.cache_data +def load_entity_ids(entity_ids: Optional[List[int]] = None): + return gio.load_entity_mapping(entity_ids) + + +@st.cache_data +def load_variables_display_in_dataset( + dataset_uri: Optional[List[str]] = None, + dataset_id: Optional[List[int]] = None, + only_slug: Optional[bool] = False, + _owid_env: OWIDEnv = OWID_ENV, +) -> Dict[int, str]: + """Load Variable objects that belong to a dataset with URI `dataset_uri`.""" + indicators = gio.load_variables_in_dataset( + dataset_uri=dataset_uri, + dataset_id=dataset_id, + owid_env=_owid_env, + ) + + def _display_slug(o) -> str: + p = o.catalogPath + if only_slug: + return p.rsplit("/", 1)[-1] if isinstance(p, str) else "" + return p + + indicators_display = {i.id: _display_slug(i) for i in indicators} + + return indicators_display + + +@st.cache_data +def get_variable_uris(indicators: List[Variable], only_slug: Optional[bool] = False) -> List[str]: + options = [o.catalogPath for o in indicators] + if only_slug: + options = [o.rsplit("/", 1)[-1] if isinstance(o, str) else "" for o in options] + return options # type: ignore + + +@st.cache_data +def load_dataset_uris_new_in_server() -> List[str]: + """Load URIs of datasets that are new in staging server.""" + return gio.load_dataset_uris() @st.cache_data def load_dataset_uris() -> List[str]: - return load_dataset_uris() + return gio.load_dataset_uris() @st.cache_data def load_variables_in_dataset( - dataset_uri: List[str], + dataset_uri: Optional[List[str]] = None, + dataset_id: Optional[List[int]] = None, _owid_env: OWIDEnv = OWID_ENV, ) -> List[Variable]: """Load Variable objects that belong to a dataset with URI `dataset_uri`.""" - return load_variables_in_dataset(dataset_uri, _owid_env) + return gio.load_variables_in_dataset( + dataset_uri=dataset_uri, + dataset_id=dataset_id, + owid_env=_owid_env, + ) @st.cache_data @@ -29,7 +82,7 @@ def 
load_variable_metadata( variable: Optional[Variable] = None, _owid_env: OWIDEnv = OWID_ENV, ) -> Dict[str, Any]: - return io.load_variable_metadata( + return gio.load_variable_metadata( catalog_path=catalog_path, variable_id=variable_id, variable=variable, @@ -44,9 +97,51 @@ def load_variable_data( variable: Optional[Variable] = None, _owid_env: OWIDEnv = OWID_ENV, ) -> pd.DataFrame: - return io.load_variable_data( + return gio.load_variable_data( catalog_path=catalog_path, variable_id=variable_id, variable=variable, owid_env=_owid_env, ) + + +@st.cache_data +def load_anomalies_in_dataset( + dataset_ids: List[int], + _owid_env: OWIDEnv = OWID_ENV, +) -> List[Anomaly]: + """Load Anomaly objects that belong to a dataset with URI `dataset_uri`.""" + with Session(_owid_env.engine) as session: + return Anomaly.load_anomalies(session, dataset_ids) + + +@st.cache_data(show_spinner=False) +def get_datasets_from_version_tracker() -> Tuple[pd.DataFrame, List[Dict[str, Any]]]: + """Get dataset info from version tracker (ETL).""" + # Get steps_df + vt = VersionTracker() + assert vt.connect_to_db, "Can't connect to database! You need to be connected to run this tool." + steps_df = vt.steps_df + + # Get file changes -> Infer dataset migrations + files_changed = get_changed_files() + grapher_changes = get_grapher_changes(files_changed, steps_df) + + # Only keep grapher steps + steps_df_grapher = steps_df.loc[ + steps_df["channel"] == "grapher", ["namespace", "identifier", "step", "db_dataset_name", "db_dataset_id"] + ] + # Remove unneeded text from 'step' (e.g. '*/grapher/'), no need for fuzzymatch! + steps_df_grapher["step_reduced"] = steps_df_grapher["step"].str.split("grapher/").str[-1] + + # Keep only those that are in DB (we need them to be in DB, otherwise indicator upgrade won't work since charts wouldn't be able to reference to non-db-existing indicator IDs) + steps_df_grapher = steps_df_grapher.dropna(subset="db_dataset_id") + assert steps_df_grapher.isna().sum().sum() == 0 + # Column rename + steps_df_grapher = steps_df_grapher.rename( + columns={ + "db_dataset_name": "name", + "db_dataset_id": "id", + } + ) + return steps_df_grapher, grapher_changes diff --git a/apps/wizard/utils/chart_config.py b/apps/wizard/utils/chart_config.py new file mode 100644 index 00000000000..e6120255bb5 --- /dev/null +++ b/apps/wizard/utils/chart_config.py @@ -0,0 +1,123 @@ +"""Tools to generate chart configs.""" +from copy import deepcopy +from typing import Any, Dict, List, Optional + +import numpy as np + +from etl.config import OWID_ENV, OWIDEnv +from etl.grapher_io import ensure_load_variable +from etl.grapher_model import Variable + +CONFIG_BASE = { + # "title": "Placeholder", + # "subtitle": "Placeholder.", + # "originUrl": "placeholder", + # "slug": "placeholder", + # "selectedEntityNames": ["placeholder"], + "entityType": "entity", + "entityTypePlural": "entities", + "facettingLabelByYVariables": "metric", + "invertColorScheme": False, + "yAxis": { + "canChangeScaleType": False, + "min": 0, + "max": "auto", + "facetDomain": "shared", + "removePointsOutsideDomain": False, + "scaleType": "linear", + }, + "hideTotalValueLabel": False, + "hideTimeline": False, + "hideLegend": False, + "tab": "chart", + "logo": "owid", + "$schema": "https://files.ourworldindata.org/schemas/grapher-schema.005.json", + "showYearLabels": False, + "id": 807, + "selectedFacetStrategy": "none", + "stackMode": "absolute", + "minTime": "earliest", + "compareEndPointsOnly": False, + "version": 14, + "sortOrder": "desc", + "maxTime": 
"latest", + "type": "LineChart", + "hideRelativeToggle": True, + "addCountryMode": "add-country", + "hideAnnotationFieldsInTitle": {"entity": False, "changeInPrefix": False, "time": False}, + "matchingEntitiesOnly": False, + "showNoDataArea": True, + "scatterPointLabelStrategy": "year", + "hideLogo": False, + "xAxis": { + "canChangeScaleType": False, + "min": "auto", + "max": "auto", + "facetDomain": "shared", + "removePointsOutsideDomain": False, + "scaleType": "linear", + }, + "hideConnectedScatterLines": False, + "zoomToSelection": False, + "hideFacetControl": True, + "hasMapTab": True, + "hideScatterLabels": False, + "missingDataStrategy": "auto", + "isPublished": False, + "timelineMinTime": "earliest", + "hasChartTab": True, + "timelineMaxTime": "latest", + "sortBy": "total", +} + + +def bake_chart_config( + catalog_path: Optional[str] = None, + variable_id: Optional[int | List[int]] = None, + variable: Optional[Variable | List[Variable]] = None, + selected_entities: Optional[list] = None, + included_entities: Optional[list] = None, + display: Optional[List[Any]] = None, + owid_env: OWIDEnv = OWID_ENV, +) -> Dict[str, Any]: + """Bake a Grapher chart configuration. + + Bakes a very basic config, which will be enough most of the times. If you want a more complex config, use this as a baseline to adjust to your needs. + + Note: You can find more details on our Grapher API at https://files.ourworldindata.org/schemas/grapher-schema.005.json. + + """ + # Define chart config + chart_config = deepcopy(CONFIG_BASE) + + # Tweak config + if isinstance(variable_id, (int, np.integer)): + chart_config["dimensions"] = [{"property": "y", "variableId": variable_id}] + elif isinstance(variable_id, list): + chart_config["dimensions"] = [{"property": "y", "variableId": v} for v in variable_id] + elif isinstance(catalog_path, str): + variable = ensure_load_variable(catalog_path=catalog_path, owid_env=owid_env) + chart_config["dimensions"] = [{"property": "y", "variableId": variable.id}] + elif isinstance(variable, Variable): + chart_config["dimensions"] = [{"property": "y", "variableId": variable.id}] + elif isinstance(variable, list): + chart_config["dimensions"] = [{"property": "y", "variableId": v.id} for v in variable] + else: + variable = ensure_load_variable(catalog_path, variable_id, variable, owid_env) + chart_config["dimensions"] = [{"property": "y", "variableId": variable.id}] + + if display is not None: + assert len(display) == len(chart_config["dimensions"]) + for i, d in enumerate(display): + chart_config["dimensions"][i]["display"] = d + + ## Selected entities? 
+ if selected_entities is not None: + chart_config["selectedEntityNames"] = selected_entities + + # Included entities + if included_entities is not None: + included_entities = [str(entity) for entity in included_entities] + chart_config["includedEntities"] = included_entities + + return chart_config diff --git a/apps/wizard/utils/components.py b/apps/wizard/utils/components.py index 93ed4d6d818..258bdb1e3d5 100644 --- a/apps/wizard/utils/components.py +++ b/apps/wizard/utils/components.py @@ -1,15 +1,14 @@ import json from contextlib import contextmanager from copy import deepcopy -from random import sample -from typing import Any, Callable, Dict, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional import numpy as np import streamlit as st import streamlit.components.v1 as components +from apps.wizard.utils.chart_config import bake_chart_config from etl.config import OWID_ENV, OWIDEnv -from etl.grapher_io import load_variable_data from etl.grapher_model import Variable HORIZONTAL_STYLE = """