From d0afe396f42bade9e9bc93eaf37cdaba09fdcb4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucas=20Rod=C3=A9s-Guirao?= Date: Wed, 9 Oct 2024 13:05:37 +0200 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20wizard:=20anomalist=20(first=20draf?= =?UTF-8?q?t)=20(#3363)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ wizard: anomalies * wip * bump streamlit * wip * wip: chart * wip * todo * plot indicator * re-structure * wip: loading indicators * fix API grapher_chart * deprecate chart_html * chart_html -> grapher_chart * clean * ci/cd * wip * wip * changed module name * custom components module * add methods to get uris * new alias * get dataset uris * update import * update gpt pricing * update import * wip * provide entity-context for anomaly * wip: anomalist v2 * wip * wip * lock * ✨ anomalist: improve utils (#3385) * wip * db -> db_utils * io -> db * move things db_utils -> db * db -> grapher_io * db -> grapher_io, db_utils -> db * docstring * db_utils -> db * wip * remove indicator * add overloads * ci/cd * wip * cicd * wip * deprecation warnings * missing import * hide anomalist in wizard --- api/v1/__init__.py | 2 +- apps/backport/backport.py | 2 +- apps/backport/datasync/data_metadata.py | 134 +-- apps/explorer_update/cli.py | 2 +- apps/metadata_migrate/cli.py | 2 +- apps/utils/gpt.py | 18 +- apps/wizard/app_pages/anomalist.py | 314 +++++++ apps/wizard/app_pages/anomalist_2.py | 110 +++ .../app_pages/chart_diff/chart_diff_show.py | 14 +- .../indicator_upgrade/indicator_mapping.py | 2 +- .../app_pages/indicator_upgrade/utils.py | 4 +- apps/wizard/app_pages/map_brackets.py | 6 +- apps/wizard/app_pages/metaplay.py | 2 +- apps/wizard/config/config.yml | 9 + apps/wizard/utils/__init__.py | 56 +- apps/wizard/utils/cached.py | 52 ++ apps/wizard/utils/components.py | 213 +++++ etl/compare.py | 2 +- etl/config.py | 6 +- etl/db.py | 438 +--------- etl/explorer.py | 2 +- etl/explorer_helpers.py | 2 +- etl/grapher_helpers.py | 4 +- etl/grapher_import.py | 2 +- etl/grapher_io.py | 798 ++++++++++++++++++ etl/grapher_model.py | 215 ++++- etl/match_variables.py | 17 +- .../archive/migrate_to_new_metadata.py | 6 +- etl/scripts/faostat/create_chart_revisions.py | 13 +- etl/version_tracker.py | 3 +- pyproject.toml | 3 +- snapshots/wb/2023-07-10/education.py | 2 +- tests/api/v1.py | 8 +- tests/backport/datasync/test_data_metadata.py | 4 +- tests/test_grapher_helpers.py | 2 +- uv.lock | 10 +- 36 files changed, 1794 insertions(+), 685 deletions(-) create mode 100644 apps/wizard/app_pages/anomalist.py create mode 100644 apps/wizard/app_pages/anomalist_2.py create mode 100644 apps/wizard/utils/cached.py create mode 100644 apps/wizard/utils/components.py create mode 100644 etl/grapher_io.py diff --git a/api/v1/__init__.py b/api/v1/__init__.py index a18eb995cee..e5affca1454 100644 --- a/api/v1/__init__.py +++ b/api/v1/__init__.py @@ -102,7 +102,7 @@ def _load_and_validate_indicator(catalog_path: str) -> gm.Variable: # update YAML file with Session(engine) as session: try: - db_indicator = gm.Variable.load_from_catalog_path(session, catalog_path) + db_indicator = gm.Variable.from_id_or_path(session, catalog_path) except NoResultFound: raise HTTPException( 404, diff --git a/apps/backport/backport.py b/apps/backport/backport.py index ae5ed9a7061..a91614e9457 100644 --- a/apps/backport/backport.py +++ b/apps/backport/backport.py @@ -14,7 +14,6 @@ from apps.backport.datasync.data_metadata import ( _variable_metadata, variable_data, - variable_data_df_from_s3, ) from 
apps.backport.datasync.datasync import upload_gzip_dict from etl import config, paths @@ -22,6 +21,7 @@ from etl.backport_helpers import GrapherConfig from etl.db import get_engine, read_sql from etl.files import checksum_str +from etl.grapher_io import variable_data_df_from_s3 from etl.snapshot import Snapshot, SnapshotMeta from . import utils diff --git a/apps/backport/datasync/data_metadata.py b/apps/backport/datasync/data_metadata.py index 1ba45bcbf8b..6b744c18219 100644 --- a/apps/backport/datasync/data_metadata.py +++ b/apps/backport/datasync/data_metadata.py @@ -1,148 +1,16 @@ -import concurrent.futures import json from copy import deepcopy -from http.client import RemoteDisconnected -from typing import Any, Dict, List, Union, cast -from urllib.error import HTTPError, URLError +from typing import Any, Dict, List, Union import numpy as np import pandas as pd -import requests from sqlalchemy import text -from sqlalchemy.engine import Engine from sqlalchemy.orm import Session from structlog import get_logger -from tenacity import Retrying -from tenacity.retry import retry_if_exception_type -from tenacity.stop import stop_after_attempt -from tenacity.wait import wait_fixed - -from etl import config -from etl.config import OWIDEnv -from etl.db import read_sql log = get_logger() -def _fetch_data_df_from_s3(variable_id: int): - try: - # Cloudflare limits us to 600 requests per minute, retry in case we hit the limit - # NOTE: increase wait time or attempts if we hit the limit too often - for attempt in Retrying( - wait=wait_fixed(2), - stop=stop_after_attempt(3), - retry=retry_if_exception_type((URLError, RemoteDisconnected)), - ): - with attempt: - return ( - pd.read_json(config.variable_data_url(variable_id)) - .rename( - columns={ - "entities": "entityId", - "values": "value", - "years": "year", - } - ) - .assign(variableId=variable_id) - ) - # no data on S3 - except HTTPError: - return pd.DataFrame(columns=["variableId", "entityId", "year", "value"]) - - -def variable_data_df_from_s3( - engine: Engine, - variable_ids: List[int] = [], - workers: int = 1, - value_as_str: bool = True, -) -> pd.DataFrame: - """Fetch data from S3 and add entity code and name from DB.""" - with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: - results = list(executor.map(_fetch_data_df_from_s3, variable_ids)) - - if isinstance(results, list) and all(isinstance(df, pd.DataFrame) for df in results): - df = pd.concat(cast(List[pd.DataFrame], results)) - else: - raise TypeError(f"results must be a list of pd.DataFrame, got {type(results)}") - - # we work with strings and convert to specific types later - if value_as_str: - df["value"] = df["value"].astype("string") - - with Session(engine) as session: - res = add_entity_code_and_name(session, df) - return res - - -def _fetch_metadata_from_s3(variable_id: int, env: OWIDEnv | None = None) -> Dict[str, Any] | None: - try: - # Cloudflare limits us to 600 requests per minute, retry in case we hit the limit - # NOTE: increase wait time or attempts if we hit the limit too often - for attempt in Retrying( - wait=wait_fixed(2), - stop=stop_after_attempt(3), - retry=retry_if_exception_type((URLError, RemoteDisconnected)), - ): - with attempt: - if env is not None: - url = env.indicator_metadata_url(variable_id) - else: - url = config.variable_metadata_url(variable_id) - return requests.get(url).json() - # no data on S3 - except HTTPError: - return {} - - -def variable_metadata_df_from_s3( - variable_ids: List[int] = [], - workers: int = 1, - env: 
OWIDEnv | None = None, -) -> List[Dict[str, Any]]: - """Fetch data from S3 and add entity code and name from DB.""" - args = [variable_ids] - if env: - args += [[env for _ in range(len(variable_ids))]] - - with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: - results = list(executor.map(_fetch_metadata_from_s3, *args)) - - if not (isinstance(results, list) and all(isinstance(res, dict) for res in results)): - raise TypeError(f"results must be a list of dictionaries, got {type(results)}") - - return results # type: ignore - - -def _fetch_entities(session: Session, entity_ids: List[int]) -> pd.DataFrame: - # Query entities from the database - q = """ - SELECT - id AS entityId, - name AS entityName, - code AS entityCode - FROM entities - WHERE id in %(entity_ids)s - """ - return read_sql(q, session, params={"entity_ids": entity_ids}) - - -def add_entity_code_and_name(session: Session, df: pd.DataFrame) -> pd.DataFrame: - if df.empty: - df["entityName"] = [] - df["entityCode"] = [] - return df - - unique_entities = df["entityId"].unique() - - entities = _fetch_entities(session, list(unique_entities)) - - if set(unique_entities) - set(entities.entityId): - missing_entities = set(unique_entities) - set(entities.entityId) - raise ValueError(f"Missing entities in the database: {missing_entities}") - - return pd.merge(df, entities.astype({"entityName": "category", "entityCode": "category"}), on="entityId") - - def variable_data(data_df: pd.DataFrame) -> Dict[str, Any]: data_df = data_df.rename( columns={ diff --git a/apps/explorer_update/cli.py b/apps/explorer_update/cli.py index fff4c341623..616cd86e3ad 100644 --- a/apps/explorer_update/cli.py +++ b/apps/explorer_update/cli.py @@ -6,7 +6,7 @@ from structlog import get_logger from tqdm.auto import tqdm -from etl.db import get_variables_data +from etl.grapher_io import get_variables_data from etl.paths import EXPLORERS_DIR from etl.version_tracker import VersionTracker diff --git a/apps/metadata_migrate/cli.py b/apps/metadata_migrate/cli.py index 1cd20cb0256..b096289a8c2 100644 --- a/apps/metadata_migrate/cli.py +++ b/apps/metadata_migrate/cli.py @@ -120,7 +120,7 @@ def cli( var_id = grapher_config["dimensions"][0]["variableId"] with Session(engine) as session: - variable = gm.Variable.load_variable(session, var_id) + variable = gm.Variable.from_id_or_path(session, var_id) assert variable.catalogPath, f"Variable {var_id} does not come from ETL. Migrate it there first." 
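The call-site updates above (api/v1, metadata_migrate) all switch to a single consolidated loader: gm.Variable.from_id_or_path accepts either a numeric variable id or a catalog path, replacing the older load_variable / load_from_catalog_path pair. A minimal usage sketch follows; the id and catalog path are placeholders, and the behaviour is assumed only from the call sites shown in this patch (a configured etl environment with Grapher database access is required):

    from sqlalchemy.orm import Session

    import etl.grapher_model as gm
    from etl.db import get_engine

    with Session(get_engine()) as session:
        # Load by numeric variable id (placeholder id) ...
        var_by_id = gm.Variable.from_id_or_path(session, 123456)
        # ... or by catalog path (placeholder path), as in the api/v1 change above.
        var_by_path = gm.Variable.from_id_or_path(
            session, "grapher/namespace/2024-01-01/dataset/table#indicator"
        )
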
diff --git a/apps/utils/gpt.py b/apps/utils/gpt.py index c3e2c63725c..965b19e4472 100644 --- a/apps/utils/gpt.py +++ b/apps/utils/gpt.py @@ -24,7 +24,9 @@ "gpt-3.5-turbo": "gpt-3.5-turbo-0125", "gpt-4-turbo-preview": "gpt-4-0125-preview", "gpt-4-turbo": "gpt-4-turbo-2024-04-09", - "gpt-4o": "gpt-4o-2024-05-13", + "gpt-4o": "gpt-4o-2024-08-06", + "o1-preview": "o1-preview-2024-09-12", + "gpt-4o-mini": "gpt-4o-mini-2024-07-18", } MODEL_RATES_1000_TOKEN = { # GPT 3.5 @@ -59,6 +61,20 @@ "in": 5 / 1000, "out": 15 / 1000, }, + "gpt-4o-2024-08-06": { + "in": 2.5 / 1000, + "out": 10 / 1000, + }, + # GPTO 4o mini + "gpt-4o-mini-2024-07-18": { + "in": 0.150 / 1000, + "out": 0.600 / 1000, + }, + # GPT o1 + "o1-preview-2024-09-12": { + "in": 15 / 1000, + "out": 60 / 1000, + }, } MODEL_RATES_1000_TOKEN = { **MODEL_RATES_1000_TOKEN, diff --git a/apps/wizard/app_pages/anomalist.py b/apps/wizard/app_pages/anomalist.py new file mode 100644 index 00000000000..cb76e03c806 --- /dev/null +++ b/apps/wizard/app_pages/anomalist.py @@ -0,0 +1,314 @@ +from typing import cast + +import streamlit as st +from pydantic import BaseModel, Field, ValidationError + +from apps.utils.gpt import OpenAIWrapper, get_cost_and_tokens +from apps.wizard.utils import cached +from apps.wizard.utils.components import grapher_chart, st_horizontal +from etl.config import OWID_ENV +from etl.grapher_io import load_variables_in_dataset + +# PAGE CONFIG +st.set_page_config( + page_title="Wizard: Anomalist", + page_icon="🪄", +) + +# SESSION STATE +st.session_state.register = st.session_state.get("register", {"by_dataset": {}}) +st.session_state.datasets_selected = st.session_state.get("datasets_selected", []) +st.session_state.anomaly_revision = st.session_state.get("anomaly_revision", {}) + + +# GPT +MODEL = "gpt-4o" +api = OpenAIWrapper() + +# PAGE TITLE +st.title(":material/planner_review: Anomalist") +# st.markdown("Detect anomalies in your data!") + + +# SELECT DATASETS +st.markdown( + """ + """, + unsafe_allow_html=True, +) +st.session_state.datasets_selected = st.multiselect( + "Select datasets", + options=cached.load_dataset_uris(), + max_selections=3, +) + +for i in st.session_state: + if i.startswith("check_anomaly_resolved_"): + st.write(i, st.session_state[i]) + + +# GET INDICATORS +if len(st.session_state.datasets_selected) > 0: + # Get indicator uris for all selected datasets + indicators = load_variables_in_dataset(st.session_state.datasets_selected) + + for indicator in indicators: + catalog_path = cast(str, indicator.catalogPath) + dataset_uri, indicator_slug = catalog_path.rsplit("/", 1) + if dataset_uri not in st.session_state.register["by_dataset"]: + st.session_state.register["by_dataset"][dataset_uri] = {} + if indicator_slug in st.session_state.register["by_dataset"][dataset_uri]: + continue + st.session_state.register["by_dataset"][dataset_uri][indicator_slug] = { + "anomalies": [], + "id": indicator.id, + } + +################################################ +# FUNCTIONS / CLASSES +################################################ + +NUM_MAX_ENTITIES = 10 + + +@st.dialog("Vizualize the indicator", width="large") +def show_indicator(indicator_uri, indicator_id, selected_entities=None): + """Plot the indicator in a modal window.""" + # Modal title + st.markdown(f"[{indicator_slug}]({OWID_ENV.indicator_admin_site(indicator_id)})") + + # Plot indicator + if (selected_entities is not None) and (len(selected_entities) > NUM_MAX_ENTITIES): + st.warning(f"Too many entities. 
Showing only the first {NUM_MAX_ENTITIES}.") + grapher_chart(catalog_path=indicator_uri, selected_entities=selected_entities, owid_env=OWID_ENV) + + +def show_anomaly(title, description, entities, indicator_id, indicator_uri, key): + # check_value = st.session_state.register["by_dataset"][dataset_name][indicator_slug].get("resolved", False) + check_value = st.session_state.get(key, False) + + if check_value: + icon = "✅" + else: + icon = "⏳" + + with st.expander(title, icon=icon): + st.checkbox( + "Mark as resolved", + value=check_value, + key=f"check_anomaly_resolved_{key}", + ) + st.write(description) + # st.write(entities) + + if (entities is not None) & isinstance(entities, list) & (len(entities) > 0): + if st.button( + "Inspect anomaly", + icon=":material/show_chart:", + # use_container_width=True, + key=f"btn_plot_{key}", + ): + show_indicator(indicator_uri, indicator_id, entities) + + +def get_anomaly_gpt(indicator_id: str, indicator_uri: str, dataset_name: str, indicator_slug: str): + # Open AI (do first to catch possible errors in ENV) + # Prepare messages for Insighter + + data = cached.load_variable_data(variable_id=int(indicator_id)) + data_1 = data.pivot(index="years", columns="entity", values="values") # .head(20) + data_1 = data_1.dropna(axis=1, how="all") + data_1_str = cast(str, data_1.to_csv()).replace(".0,", ",") + + num_anomalies = 3 + + messages = [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": f"Provide {num_anomalies} anomalies in for the given time series.", + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "text", + "text": data_1_str, + }, + ], + }, + ] + kwargs = { + "api": api, + "model": MODEL, + "messages": messages, + "max_tokens": 3000, + "response_format": AnomaliesModel, + } + latest_json = "" + for anomaly_count, (anomaly, _, latest_json) in enumerate(openai_structured_outputs_stream(**kwargs)): + # Show anomaly + key = f"{indicator_id}_{anomaly_count}" + show_anomaly(anomaly.title, anomaly.description, anomaly.entities, indicator_id, indicator_uri, key) + + # Save anomaly + st.session_state.register["by_dataset"][dataset_name][indicator_slug]["anomalies"].append(anomaly.model_dump()) + + # Get cost and tokens + text_in = [mm["text"] for m in messages for mm in m["content"] if mm["type"] == "text"] + text_in = "\n".join(text_in) + cost, num_tokens = get_cost_and_tokens(text_in, latest_json, cast(str, MODEL)) + cost_msg = f"**Cost**: ≥{cost} USD.\n\n **Tokens**: ≥{num_tokens}." 
+ st.info(cost_msg) + + +@st.fragment +def show_indicator_block(dataset_name, indicator_slug): + indicator_uri = f"{dataset_name}/{indicator_slug}" + + indicator_props = st.session_state.register["by_dataset"][dataset_name][indicator_slug] + indicator_id = indicator_props["id"] + indicator_anomalies = indicator_props["anomalies"] + + with st.container(border=True): + # Title + st.markdown(f"[{indicator_slug}]({OWID_ENV.indicator_admin_site(indicator_id)})") + + # Buttons + with st_horizontal(): + # Find anomalies button + btn_gpt = st.button( + "Find anomalies", + icon=":material/robot:", + # use_container_width=True, + type="primary", + help="Use GPT to find anomalies in the indicator.", + key=f"btn_gpt_{indicator_id}", + # on_click=lambda: st.rerun(scope="fragment"), + ) + # 'Plot indicator' button + if st.button( + "Plot indicator", + icon=":material/show_chart:", + # use_container_width=True, + key=f"btn_plot_{indicator_id}", + ): + show_indicator(indicator_uri, indicator_id) + + # Show anomalies + if btn_gpt: + with st.spinner("Querying GPT for anomalies..."): + get_anomaly_gpt(indicator_id, indicator_uri, dataset_name, indicator_slug) + else: + for anomaly_count, anomaly in enumerate(indicator_anomalies): + key = f"resolved_{indicator_id}_{anomaly_count}" + show_anomaly( + anomaly["title"], anomaly["description"], anomaly["entities"], indicator_id, indicator_uri, key + ) + + +class AnomalyModel(BaseModel): + title: str = Field(description="Title of the anomaly.") + description: str = Field(description="Short description of the anomaly.") + entities: list[str] = Field( + description="List of entities affected by the anomaly. Entities are given as columns in the input data (excluding column 'years' or 'date')." + ) + finished: bool = Field(description="True if the obtention of a particular anomaly has been finalized.") + + +class AnomaliesModel(BaseModel): + anomalies: list[AnomalyModel] = Field(description="List of anomalies detected in the data.") + + +def openai_structured_outputs_stream(api, **kwargs): + """Stream structured outputs from OpenAI API. + + References: + - https://community.openai.com/t/streaming-using-structured-outputs/925799/13 + """ + parsed_latest = None + with api.beta.chat.completions.stream(**kwargs, stream_options={"include_usage": True}) as stream: + # Check each chunk in stream (new chunk appears whenever a new character is added to the completion) + for chunk in stream: + # Only consider those of type "chunk" + if chunk.type == "chunk": + # Get latest snapshot + latest_snapshot = chunk.to_dict()["snapshot"] + + # Get latest choice + choice = latest_snapshot["choices"][0] + parsed_cumulative = choice["message"].get("parsed", {}) + + # Note that usage is not available until the final chunk + latest_usage = latest_snapshot.get("usage", {}) + latest_json = choice["message"]["content"] + + # Checks: + # 1. Check if "anomalies" is in the returned object + # 2. Check if "anomalies" is a list + # 3. Check if "anomalies" is not empty + # 4. Check if the latest parsed object is different from the previous one + if "anomalies" in parsed_cumulative: + anomalies = parsed_cumulative["anomalies"] + if isinstance(anomalies, list) & (len(anomalies) > 0): + parsed_latest_ = anomalies[-1] + + # Check if parsed_latest_ is a valid AnomalyModel (i.e. if it is a complete object!) 
+ try: + anomaly = AnomalyModel(**parsed_latest_) + if (parsed_latest is None) | (parsed_latest != parsed_latest_): + # st.write(choice) + parsed_latest = parsed_latest_ + yield anomaly, latest_usage, latest_json + except ValidationError as _: + continue + # yield latest_parsed["anomalies"], latest_usage, latest_json + + +# SHOW INDICATORS +if len(st.session_state.datasets_selected) > 0: + num_tabs = len(st.session_state.datasets_selected) + tabs = st.tabs(st.session_state.datasets_selected) + + # Block per dataset + for dataset_name, tab in zip(st.session_state.datasets_selected, tabs): + with tab: + # Block per indicator in dataset + for indicator_slug in st.session_state.register["by_dataset"][dataset_name].keys(): + # Indicator block + show_indicator_block(dataset_name, indicator_slug) + + # my_fragment(indicator_uri) + # Anomalies detected + # anomalies = indicator["anomalies"] + # st.markdown(f"{len(anomalies)} anomalies detected.") + + # for anomaly_index, a in enumerate(anomalies): + # # Review icon + # if a["resolved"]: + # icon = "✅" + # else: + # icon = "⏳" + + # # Anomaly explained (expander) + # with st.expander(f'{anomaly_index+1}/ {a["title"]}', expanded=False, icon=icon): + # # Check if resolved + # key = f"resolved_{dataset_index}_{indicator_index}_{anomaly_index}" + + # # Checkbox (if resolved) + # st.checkbox( + # "Mark as resolved", + # value=a["resolved"], + # key=key, + # ) + + # # Anomaly description + # st.markdown(a["description"]) diff --git a/apps/wizard/app_pages/anomalist_2.py b/apps/wizard/app_pages/anomalist_2.py new file mode 100644 index 00000000000..61369b2880a --- /dev/null +++ b/apps/wizard/app_pages/anomalist_2.py @@ -0,0 +1,110 @@ +import pandas as pd +import streamlit as st + +from apps.wizard.utils import cached +from apps.wizard.utils.components import grapher_chart, st_horizontal + +# PAGE CONFIG +st.set_page_config( + page_title="Wizard: Anomalist", + page_icon="🪄", +) +# OTHER CONFIG +ANOMALY_TYPES = [ + "Upgrade", + "Abrupt change", + "Context change", +] + +# SESSION STATE +st.session_state.datasets_selected = st.session_state.get("datasets_selected", []) +st.session_state.filter_indicators = st.session_state.get("filter_indicators", []) +st.session_state.indicators = st.session_state.get("indicators", []) + +# PAGE TITLE +st.title(":material/planner_review: Anomalist") + + +# DATASET SEARCH +st.markdown( + """ + """, + unsafe_allow_html=True, +) +with st.form(key="dataset_search"): + st.session_state.datasets_selected = st.multiselect( + "Select datasets", + options=cached.load_dataset_uris(), + max_selections=1, + ) + + st.form_submit_button("Detect anomalies", type="primary") + + +# FILTER PARAMS +with st.container(border=True): + st.markdown("##### Filter Parameters") + options = [] + if len(st.session_state.datasets_selected) > 0: + st.session_state.indicators = cached.load_variables_in_dataset(st.session_state.datasets_selected) + options = [o.catalogPath for o in st.session_state.indicators] + + st.session_state.filter_indicators = st.multiselect( + label="Indicator", + options=options, + ) + + with st_horizontal(): + st.session_state.filter_indicators = st.multiselect( + label="Indicator type", + options=["New indicator", "Indicator upgrade"], + ) + st.session_state.filter_indicators = st.multiselect( + label="Anomaly type", + options=ANOMALY_TYPES, + ) + + # st.multiselect("Anomaly type", min_value=0.0, max_value=1.0, value=0.5, step=0.01) + st.number_input("Minimum score", min_value=0.0, max_value=1.0, value=0.5, step=0.01) + +# SHOW 
ANOMALIES +data = { + "anomaly": ["Anomaly 1", "Anomaly 2", "Anomaly 3"], + "description": ["Description 1", "Description 2", "Description 3"], +} + + +# SHOW ANOMALIES +def show_anomaly(df: pd.DataFrame): + if len(st.session_state.anomalies["selection"]["rows"]) > 0: + # Get selected row number + row_num = st.session_state.anomalies["selection"]["rows"][0] + # Get indicator id + indicator_id = df.index[row_num] + action(indicator_id) + + +@st.dialog("Show anomaly", width="large") +def action(indicator_id): + grapher_chart(variable_id=indicator_id) + + +if len(st.session_state.indicators) > 0: + df = pd.DataFrame( + { + "indicator_id": [i.id for i in st.session_state.indicators], + "reviewed": [False for i in st.session_state.indicators], + }, + ).set_index("indicator_id") + + st.dataframe( + df, + key="anomalies", + selection_mode="single-row", + on_select=lambda df=df: show_anomaly(df), + use_container_width=True, + ) diff --git a/apps/wizard/app_pages/chart_diff/chart_diff_show.py b/apps/wizard/app_pages/chart_diff/chart_diff_show.py index c7ba8b98c8d..3dfa4c1d0ef 100644 --- a/apps/wizard/app_pages/chart_diff/chart_diff_show.py +++ b/apps/wizard/app_pages/chart_diff/chart_diff_show.py @@ -14,14 +14,14 @@ import etl.grapher_model as gm from apps.backport.datasync.data_metadata import ( filter_out_fields_in_metadata_for_checksum, - variable_metadata_df_from_s3, ) from apps.utils.gpt import OpenAIWrapper, get_cost_and_tokens from apps.wizard.app_pages.chart_diff.chart_diff import ChartDiff, ChartDiffsLoader from apps.wizard.app_pages.chart_diff.conflict_resolver import ChartDiffConflictResolver from apps.wizard.app_pages.chart_diff.utils import SOURCE, TARGET, prettify_date -from apps.wizard.utils import chart_html +from apps.wizard.utils.components import grapher_chart from etl.config import OWID_ENV +from etl.grapher_io import variable_metadata_df_from_s3 # How to display the various chart review statuses DISPLAY_STATE_OPTIONS = { @@ -403,11 +403,11 @@ def _show_charts_comparison_v(): else: st.markdown(self._header_production_chart) assert self.diff.target_chart is not None - chart_html(self.diff.target_chart.config, owid_env=TARGET) + grapher_chart(chart_config=self.diff.target_chart.config, owid_env=TARGET) # Chart staging st.markdown(self._header_staging_chart) - chart_html(self.diff.source_chart.config, owid_env=SOURCE) + grapher_chart(chart_config=self.diff.source_chart.config, owid_env=SOURCE) def _show_charts_comparison_h(): """Show charts next to each other.""" @@ -421,15 +421,15 @@ def _show_charts_comparison_h(): else: st.markdown(self._header_production_chart) assert self.diff.target_chart is not None - chart_html(self.diff.target_chart.config, owid_env=TARGET) + grapher_chart(chart_config=self.diff.target_chart.config, owid_env=TARGET) with col2: st.markdown(self._header_staging_chart) - chart_html(self.diff.source_chart.config, owid_env=SOURCE) + grapher_chart(chart_config=self.diff.source_chart.config, owid_env=SOURCE) # Only one chart: new chart if self.diff.target_chart is None: st.markdown(f"New version ┃ _{prettify_date(self.diff.source_chart)}_") - chart_html(self.diff.source_chart.config, owid_env=SOURCE) + grapher_chart(chart_config=self.diff.source_chart.config, owid_env=SOURCE) # Two charts, actual diff else: # Detect arrangement type diff --git a/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py b/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py index 270c20666ed..b11324505e9 100644 --- 
a/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py +++ b/apps/wizard/app_pages/indicator_upgrade/indicator_mapping.py @@ -7,7 +7,6 @@ import streamlit as st from structlog import get_logger -from apps.backport.datasync.data_metadata import variable_data_df_from_s3 from apps.wizard.app_pages.indicator_upgrade.explore_mode import st_explore_indicator from apps.wizard.app_pages.indicator_upgrade.utils import ( find_mapping_suggestions_cached, @@ -17,6 +16,7 @@ from apps.wizard.utils import Pagination, set_states from etl.config import OWID_ENV from etl.db import get_engine, read_sql +from etl.grapher_io import variable_data_df_from_s3 # Logger log = get_logger() diff --git a/apps/wizard/app_pages/indicator_upgrade/utils.py b/apps/wizard/app_pages/indicator_upgrade/utils.py index a2ac426fba1..e737bf534b8 100644 --- a/apps/wizard/app_pages/indicator_upgrade/utils.py +++ b/apps/wizard/app_pages/indicator_upgrade/utils.py @@ -8,8 +8,10 @@ from structlog import get_logger from apps.utils.map_datasets import get_grapher_changes -from etl.db import config, get_all_datasets, get_connection, get_dataset_charts, get_variables_in_dataset +from etl import config +from etl.db import get_connection from etl.git_helpers import get_changed_files +from etl.grapher_io import get_all_datasets, get_dataset_charts, get_variables_in_dataset from etl.match_variables import find_mapping_suggestions, preliminary_mapping from etl.version_tracker import VersionTracker diff --git a/apps/wizard/app_pages/map_brackets.py b/apps/wizard/app_pages/map_brackets.py index 57db7425ec9..2a48432ca18 100644 --- a/apps/wizard/app_pages/map_brackets.py +++ b/apps/wizard/app_pages/map_brackets.py @@ -15,7 +15,7 @@ from sqlalchemy.orm import Session from structlog import get_logger -from apps.wizard.utils import chart_html +from apps.wizard.utils.components import grapher_chart from etl.config import OWID_ENV from etl.data_helpers.misc import round_to_nearest_power_of_ten, round_to_shifted_power_of_ten, round_to_sig_figs from etl.explorer_helpers import Explorer @@ -113,7 +113,7 @@ def load_variable_from_id(variable_id: int): @st.cache_data def load_variable_from_catalog_path(catalog_path: str): with Session(OWID_ENV.engine) as session: - variable = Variable.load_from_catalog_path(session=session, catalog_path=catalog_path) + variable = Variable.from_catalog_path(session=session, catalog_path=catalog_path) return variable @@ -1040,7 +1040,7 @@ def _create_maximum_instances_message(mb: MapBracketer) -> str: st.info(_create_maximum_instances_message(mb)) # Display the chart. - chart_html(chart_config=mb.chart_config, owid_env=OWID_ENV, height=540) + grapher_chart(chart_config=mb.chart_config, owid_env=OWID_ENV, height=540) with st.sidebar: if edit_brackets and st.button("Save brackets in explorer file", type="primary"): diff --git a/apps/wizard/app_pages/metaplay.py b/apps/wizard/app_pages/metaplay.py index 10be3435245..c045a0aa95b 100644 --- a/apps/wizard/app_pages/metaplay.py +++ b/apps/wizard/app_pages/metaplay.py @@ -71,7 +71,7 @@ def get_data_page_url() -> str: # The following port is defined in one of owid-grapher's config files. 
HOST = "localhost:3030" with get_session() as session: - VARIABLE_ID = gm.Variable.load_from_catalog_path(session, CATALOG_PATH).id + VARIABLE_ID = gm.Variable.from_catalog_path(session, CATALOG_PATH).id url = f"http://{HOST}/admin/datapage-preview/{VARIABLE_ID}" return url diff --git a/apps/wizard/config/config.yml b/apps/wizard/config/config.yml index 4d3d823bfd7..5c4cf554c43 100644 --- a/apps/wizard/config/config.yml +++ b/apps/wizard/config/config.yml @@ -108,6 +108,15 @@ sections: image_url: "https://static.wikia.nocookie.net/dragonball/images/6/60/FusionDanceFinaleGotenTrunksBuuSaga.png" disable: production: True + # - title: "Anomalist" + # alias: anomalist + # entrypoint: app_pages/anomalist_2.py + # description: List anomalies in data + # maintainer: "@lucas" + # icon: ":material/planner_review:" + # image_url: "https://superheroetc.wordpress.com/wp-content/uploads/2017/05/bulbasaur-line.jpg" + # disable: + # production: True - title: "Harmonizer" alias: harmonizer description: "Harmonize a column of a table" diff --git a/apps/wizard/utils/__init__.py b/apps/wizard/utils/__init__.py index c812059dd9c..2f7c2b079df 100644 --- a/apps/wizard/utils/__init__.py +++ b/apps/wizard/utils/__init__.py @@ -16,7 +16,6 @@ import os import re import sys -from copy import deepcopy from datetime import date from pathlib import Path from typing import Any, Callable, Dict, List, Literal, Optional, cast @@ -24,7 +23,6 @@ import bugsnag import numpy as np import streamlit as st -import streamlit.components.v1 as components from owid.catalog import Dataset from pymysql import OperationalError from sqlalchemy.orm import Session @@ -34,7 +32,7 @@ from apps.wizard.config import PAGES_BY_ALIAS from apps.wizard.utils.defaults import load_wizard_defaults, update_wizard_defaults_from_form from apps.wizard.utils.step_form import StepForm -from etl.config import OWID_ENV, OWIDEnv, enable_bugsnag +from etl.config import OWID_ENV, enable_bugsnag from etl.db import get_connection, read_sql from etl.files import ruamel_dump, ruamel_load from etl.metadata_export import main as metadata_export @@ -678,58 +676,6 @@ def bugsnag_handler(exception: Exception) -> None: error_util.handle_uncaught_app_exception = bugsnag_handler # type: ignore -def chart_html(chart_config: Dict[str, Any], owid_env: OWIDEnv, height=600, **kwargs): - chart_config_tmp = deepcopy(chart_config) - - chart_config_tmp["bakedGrapherURL"] = f"{owid_env.base_site}/grapher" - chart_config_tmp["adminBaseUrl"] = owid_env.base_site - chart_config_tmp["dataApiUrl"] = f"{owid_env.indicators_url}/" - - # HTML = f""" - # - # - # - # - # - # - # - # - #
- #
- #
- #
- # - # - # - # - # - # """ - - HTML = f""" - - -
-
-
-
- - - -
- """ - - components.html(HTML, height=height, **kwargs) - - class Pagination: def __init__(self, items: list[Any], items_per_page: int, pagination_key: str, on_click: Optional[Callable] = None): self.items = items diff --git a/apps/wizard/utils/cached.py b/apps/wizard/utils/cached.py new file mode 100644 index 00000000000..9ae59fa6817 --- /dev/null +++ b/apps/wizard/utils/cached.py @@ -0,0 +1,52 @@ +from typing import Any, Dict, List, Optional + +import pandas as pd +import streamlit as st + +from etl import grapher_io as io +from etl.config import OWID_ENV, OWIDEnv +from etl.grapher_model import Variable + + +@st.cache_data +def load_dataset_uris() -> List[str]: + return load_dataset_uris() + + +@st.cache_data +def load_variables_in_dataset( + dataset_uri: List[str], + _owid_env: OWIDEnv = OWID_ENV, +) -> List[Variable]: + """Load Variable objects that belong to a dataset with URI `dataset_uri`.""" + return load_variables_in_dataset(dataset_uri, _owid_env) + + +@st.cache_data +def load_variable_metadata( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + _owid_env: OWIDEnv = OWID_ENV, +) -> Dict[str, Any]: + return io.load_variable_metadata( + catalog_path=catalog_path, + variable_id=variable_id, + variable=variable, + owid_env=_owid_env, + ) + + +@st.cache_data +def load_variable_data( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + _owid_env: OWIDEnv = OWID_ENV, +) -> pd.DataFrame: + return io.load_variable_data( + catalog_path=catalog_path, + variable_id=variable_id, + variable=variable, + owid_env=_owid_env, + ) diff --git a/apps/wizard/utils/components.py b/apps/wizard/utils/components.py new file mode 100644 index 00000000000..52b5cc445c0 --- /dev/null +++ b/apps/wizard/utils/components.py @@ -0,0 +1,213 @@ +import json +from contextlib import contextmanager +from copy import deepcopy +from random import sample +from typing import Any, Dict, Optional + +import numpy as np +import streamlit as st +import streamlit.components.v1 as components + +from etl.config import OWID_ENV, OWIDEnv +from etl.grapher_io import load_variable_data +from etl.grapher_model import Variable + +HORIZONTAL_STYLE = """ +""" + + +@contextmanager +def st_horizontal(): + st.markdown(HORIZONTAL_STYLE, unsafe_allow_html=True) + with st.container(): + st.markdown('', unsafe_allow_html=True) + yield + + +CONFIG_BASE = { + # "title": "Placeholder", + # "subtitle": "Placeholder.", + # "originUrl": "placeholder", + # "slug": "placeholder", + # "selectedEntityNames": ["placeholder"], + "entityType": "entity", + "entityTypePlural": "entities", + "facettingLabelByYVariables": "metric", + "invertColorScheme": False, + "yAxis": { + "canChangeScaleType": False, + "min": 0, + "max": "auto", + "facetDomain": "shared", + "removePointsOutsideDomain": False, + "scaleType": "linear", + }, + "hideTotalValueLabel": False, + "hideTimeline": False, + "hideLegend": False, + "tab": "chart", + "logo": "owid", + "$schema": "https://files.ourworldindata.org/schemas/grapher-schema.005.json", + "showYearLabels": False, + "id": 807, + "selectedFacetStrategy": "none", + "stackMode": "absolute", + "minTime": "earliest", + "compareEndPointsOnly": False, + "version": 14, + "sortOrder": "desc", + "maxTime": "latest", + "type": "LineChart", + "hideRelativeToggle": True, + "addCountryMode": "add-country", + "hideAnnotationFieldsInTitle": {"entity": False, "changeInPrefix": False, "time": False}, + 
"matchingEntitiesOnly": False, + "showNoDataArea": True, + "scatterPointLabelStrategy": "year", + "hideLogo": False, + "xAxis": { + "canChangeScaleType": False, + "min": "auto", + "max": "auto", + "facetDomain": "shared", + "removePointsOutsideDomain": False, + "scaleType": "linear", + }, + "hideConnectedScatterLines": False, + "zoomToSelection": False, + "hideFacetControl": True, + "hasMapTab": True, + "hideScatterLabels": False, + "missingDataStrategy": "auto", + "isPublished": False, + "timelineMinTime": "earliest", + "hasChartTab": True, + "timelineMaxTime": "latest", + "sortBy": "total", +} + + +def default_converter(o): + if isinstance(o, np.integer): # ignore + return int(o) + else: + raise TypeError(f"Object of type {o.__class__.__name__} is not JSON serializable") + + +def grapher_chart( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + chart_config: Optional[Dict[str, Any]] = None, + owid_env: OWIDEnv = OWID_ENV, + selected_entities: Optional[list] = None, + num_sample_selected_entities: int = 5, + height=600, + **kwargs, +): + """Plot a Grapher chart using the Grapher API. + + You can either plot a given chart config (using chart_config) or plot an indicator with its default metadata using either catalog_path, variable_id or variable. + + Parameters + ---------- + catalog_path : Optional[str], optional + Path to the catalog file, by default None + variable_id : Optional[int], optional + Variable ID, by default None + variable : Optional[Variable], optional + Variable object, by default None + chart_config : Optional[Dict[str, Any]], optional + Configuration of the chart, by default None + owid_env : OWIDEnv, optional + Environment configuration, by default OWID_ENV + selected_entities : Optional[list], optional + List of entities to plot, by default None. If None, a random sample of num_sample_selected_entities will be plotted. + num_sample_selected_entities : int, optional + Number of entities to sample if selected_entities is None, by default 5. If there are less entities than this number, all will be plotted. + height : int, optional + Height of the chart, by default 600 + """ + # Check we have all needed to plot the chart + if (catalog_path is None) and (variable_id is None) and (variable is None) and (chart_config is None): + raise ValueError("Either catalog_path, variable_id, variable or chart_config must be provided") + + # Get data / metadata if no chart config is provided + if chart_config is None: + # Get variable data + df = load_variable_data( + catalog_path=catalog_path, variable_id=variable_id, variable=variable, owid_env=owid_env + ) + + # Define chart config + chart_config = deepcopy(CONFIG_BASE) + chart_config["dimensions"] = [{"property": "y", "variableId": variable_id}] + + ## Selected entities? + if selected_entities is not None: + chart_config["selectedEntityNames"] = selected_entities + else: + entities = list(df["entity"].unique()) + chart_config["selectedEntityNames"] = sample(entities, min(len(entities), num_sample_selected_entities)) + + _chart_html(chart_config, owid_env, height=height, **kwargs) + + +def _chart_html(chart_config: Dict[str, Any], owid_env: OWIDEnv, height=600, **kwargs): + """Plot a Grapher chart using the Grapher API. + + Parameters + ---------- + chart_config : Dict[str, Any] + Configuration of the chart. + owid_env : OWIDEnv + Environment configuration. This is needed to access the correct API (changes between servers). 
+ """ + chart_config_tmp = deepcopy(chart_config) + + chart_config_tmp["bakedGrapherURL"] = f"{owid_env.base_site}/grapher" + chart_config_tmp["adminBaseUrl"] = owid_env.base_site + chart_config_tmp["dataApiUrl"] = f"{owid_env.indicators_url}/" + + HTML = f""" + + +
+
+
+
+ + + +
+ """ + + components.html(HTML, height=height, **kwargs) diff --git a/etl/compare.py b/etl/compare.py index 2cb11f19de0..c0bd086c662 100644 --- a/etl/compare.py +++ b/etl/compare.py @@ -16,9 +16,9 @@ from rich_click.rich_command import RichCommand from rich_click.rich_group import RichGroup -from apps.backport.datasync.data_metadata import variable_data_df_from_s3 from etl import tempcompare from etl.db import get_engine, read_sql +from etl.grapher_io import variable_data_df_from_s3 @click.group(name="compare", cls=RichGroup) diff --git a/etl/config.py b/etl/config.py index 7f1ba9df971..dd7d1e9f66d 100644 --- a/etl/config.py +++ b/etl/config.py @@ -478,9 +478,13 @@ def dataset_admin_site(self, dataset_id: str | int) -> str: """Get dataset admin url.""" return f"{self.admin_site}/datasets/{dataset_id}/" + def indicator_admin_site(self, variable_id: str | int) -> str: + """Get indicator admin url.""" + return f"{self.admin_site}/variables/{variable_id}/" + def variable_admin_site(self, variable_id: str | int) -> str: """Get variable admin url.""" - return f"{self.admin_site}/variables/{variable_id}/" + return self.indicator_admin_site(variable_id) def chart_admin_site(self, chart_id: str | int) -> str: """Get chart admin url.""" diff --git a/etl/db.py b/etl/db.py index a25130ccf42..690e0ce7524 100644 --- a/etl/db.py +++ b/etl/db.py @@ -1,13 +1,13 @@ import functools import os import warnings -from typing import Any, Dict, List, Optional +from typing import Any, Dict, Optional from urllib.parse import quote import pandas as pd import pymysql import structlog -import validators +from deprecated import deprecated from sqlalchemy import create_engine from sqlalchemy.engine import Engine from sqlalchemy.orm import Session @@ -25,6 +25,7 @@ def can_connect(conf: Optional[Dict[str, Any]] = None) -> bool: return False +@deprecated("This function is deprecated. Instead, look at using etl.db.read_sql function.") def get_connection(conf: Optional[Dict[str, Any]] = None) -> pymysql.Connection: "Connect to the Grapher database." cf: Any = dict_to_object(conf) if conf else config @@ -60,388 +61,10 @@ def get_engine(conf: Optional[Dict[str, Any]] = None) -> Engine: return _get_engine_cached(cf, pid) -def get_dataset_id( - dataset_name: str, db_conn: Optional[pymysql.Connection] = None, version: Optional[str] = None -) -> Any: - """Get the dataset ID of a specific dataset name from database. - - If more than one dataset is found for the same name, or if no dataset is found, an error is raised. - - Parameters - ---------- - dataset_name : str - Dataset name. - db_conn : pymysql.Connection - Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). - version : str - ETL version of the dataset. This is necessary when multiple datasets have the same title. In such a case, if - version is not given, the function will raise an error. - - Returns - ------- - dataset_id : int - Dataset ID. 
- - """ - if db_conn is None: - db_conn = get_connection() - - query = f""" - SELECT id - FROM datasets - WHERE name = '{dataset_name}' - """ - - if version: - query += f" AND version = '{version}'" - - with db_conn.cursor() as cursor: - cursor.execute(query) - result = cursor.fetchall() - - assert len(result) == 1, f"Ambiguous or unknown dataset name '{dataset_name}'" - dataset_id = result[0][0] - return dataset_id - - -def get_variables_in_dataset( - dataset_id: int, only_used_in_charts: bool = False, db_conn: Optional[pymysql.Connection] = None -) -> Any: - """Get all variables data for a specific dataset ID from database. - - Parameters - ---------- - dataset_id : int - Dataset ID. - only_used_in_charts : bool - True to select variables only if they have been used in at least one chart. False to select all variables. - db_conn : pymysql.Connection - Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). - - Returns - ------- - variables_data : pd.DataFrame - Variables data for considered dataset. - - """ - if db_conn is None: - db_conn = get_connection() - - query = f""" - SELECT * - FROM variables - WHERE datasetId = {dataset_id} - """ - if only_used_in_charts: - query += """ - AND id IN ( - SELECT DISTINCT variableId - FROM chart_dimensions - ) - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - variables_data = pd.read_sql(query, con=db_conn) - return variables_data - - -def _get_variables_data_with_filter( - field_name: Optional[str] = None, - field_values: Optional[List[Any]] = None, - db_conn: Optional[pymysql.Connection] = None, -) -> Any: - if db_conn is None: - db_conn = get_connection() - - if field_values is None: - field_values = [] - - # Construct the SQL query with a placeholder for each value in the list. - query = "SELECT * FROM variables" - - if (field_name is not None) and (len(field_values) > 0): - query += f"\nWHERE {field_name} IN ({', '.join(['%s'] * len(field_values))});" - - # Execute the query. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - variables_data = pd.read_sql(query, con=db_conn, params=field_values) - - assert set(variables_data[field_name]) <= set(field_values), f"Unexpected values for {field_name}." - - # Warn about values that were not found. - missing_values = set(field_values) - set(variables_data[field_name]) - if len(missing_values) > 0: - log.warning(f"Values of {field_name} not found in database: {missing_values}") - - return variables_data - - -def get_variables_data( - filter: Optional[Dict[str, Any]] = None, - condition: Optional[str] = "OR", - db_conn: Optional[pymysql.Connection] = None, -) -> pd.DataFrame: - """Get data from variables table, given a certain condition. - - Parameters - ---------- - filter : Optional[Dict[str, Any]], optional - Filter to apply to the data, which must contain a field name and a list of field values, - e.g. {"id": [123456, 234567, 345678]}. - In principle, multiple filters can be given. - condition : Optional[str], optional - In case multiple filters are given, this parameter specifies whether the output filters should be the union - ("OR") or the intersection ("AND"). - db_conn : pymysql.Connection - Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). - - Returns - ------- - df : pd.DataFrame - Variables data. - - """ - # NOTE: This function should be optimized. 
Instead of fetching data for each filter, their conditions should be - # combined with OR or AND before executing the query. - - # Initialize an empty dataframe. - if filter is not None: - df = pd.DataFrame({"id": []}).astype({"id": int}) - for field_name, field_values in filter.items(): - _df = _get_variables_data_with_filter(field_name=field_name, field_values=field_values, db_conn=db_conn) - if condition == "OR": - df = pd.concat([df, _df], axis=0) - elif condition == "AND": - df = pd.merge(df, _df, on="id", how="inner") - else: - raise ValueError(f"Invalid condition: {condition}") - else: - # Fetch data for all variables. - df = _get_variables_data_with_filter(db_conn=db_conn) - - return df - - -def get_all_datasets(archived: bool = True, db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: - """Get all datasets in database. - - Parameters - ---------- - db_conn : pymysql.connections.Connection - Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). - - Returns - ------- - datasets : pd.DataFrame - All datasets in database. Table with three columns: dataset ID, dataset name, dataset namespace. - """ - if db_conn is None: - db_conn = get_connection() - - query = " SELECT namespace, name, id, updatedAt, isArchived FROM datasets" - if not archived: - query += " WHERE isArchived = 0" - datasets = pd.read_sql(query, con=db_conn) - return datasets.sort_values(["name", "namespace"]) - - def dict_to_object(d): return type("DynamicObject", (object,), d)() -def get_charts_slugs(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: - if db_conn is None: - db_conn = get_connection() - - # Get a dataframe chart_id,char_slug, for all charts that have variables with an ETL path. - query = """\ - SELECT - c.id AS chart_id, - cc.slug AS chart_slug - FROM charts c - JOIN chart_configs cc ON c.configId = cc.id - LEFT JOIN chart_dimensions cd ON c.id = cd.chartId - LEFT JOIN variables v ON cd.variableId = v.id - WHERE - v.catalogPath IS NOT NULL - ORDER BY - c.id ASC; - """ - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - df = pd.read_sql(query, con=db_conn) - - # Remove duplicated rows. - df = df.drop_duplicates().reset_index(drop=True) - - if len(df[df.duplicated(subset="chart_id")]) > 0: - log.warning("There are duplicated chart ids in the chart_ids and slugs table.") - - return df - - -def get_charts_views(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: - if db_conn is None: - db_conn = get_connection() - - # Assumed base url for all charts. - base_url = "https://ourworldindata.org/grapher/" - - # Note that for now we extract data for all dates. - # It seems that the table only has data for the last day. - query = f"""\ - SELECT - url, - views_7d, - views_14d, - views_365d - FROM - analytics_pageviews - WHERE - url LIKE '{base_url}%'; - """ - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - df = pd.read_sql(query, con=db_conn) - - # For some reason, there are spurious urls, clean some of them. - # Note that validators.url() returns a ValidationError object (instead of False) when the url has spaces. - is_url_invalid = [(validators.url(url) is False) or (" " in url) for url in df["url"]] - df = df.drop(df[is_url_invalid].index).reset_index(drop=True) - - # Note that some of the returned urls may still be invalid, for example "https://ourworldindata.org/grapher/132". - - # Add chart slug. 
- df["slug"] = [url.replace(base_url, "") for url in df["url"]] - - # Remove url. - df = df.drop(columns=["url"], errors="raise") - - if len(df[df.duplicated(subset="slug")]) > 0: - log.warning("There are duplicated slugs in the chart analytics table.") - - return df - - -def get_info_for_etl_datasets(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: - if db_conn is None: - db_conn = get_connection() - - # First, increase the GROUP_CONCAT limit, to avoid the list of chart ids to be truncated. - GROUP_CONCAT_MAX_LEN = 4096 - cursor = db_conn.cursor() - cursor.execute(f"SET SESSION group_concat_max_len = {GROUP_CONCAT_MAX_LEN};") - db_conn.commit() - - query = """\ - SELECT - q1.datasetId AS dataset_id, - d.name AS dataset_name, - q1.etlPath AS etl_path, - d.isArchived AS is_archived, - d.isPrivate AS is_private, - q2.chartIds AS chart_ids, - q2.updatePeriodDays AS update_period_days - FROM - (SELECT - datasetId, - MIN(catalogPath) AS etlPath - FROM - variables - WHERE - catalogPath IS NOT NULL - GROUP BY - datasetId) q1 - LEFT JOIN - (SELECT - d.id AS datasetId, - d.isArchived, - d.isPrivate, - d.updatePeriodDays, - GROUP_CONCAT(DISTINCT c.id) AS chartIds - FROM - datasets d - JOIN variables v ON v.datasetId = d.id - JOIN chart_dimensions cd ON cd.variableId = v.id - JOIN charts c ON c.id = cd.chartId - JOIN chart_configs cc ON c.configId = cc.id - WHERE - json_extract(cc.full, "$.isPublished") = TRUE - GROUP BY - d.id) q2 - ON q1.datasetId = q2.datasetId - JOIN - datasets d ON q1.datasetId = d.id - ORDER BY - q1.datasetId ASC; - - """ - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - df = pd.read_sql(query, con=db_conn) - - if max([len(row) for row in df["chart_ids"] if row is not None]) == GROUP_CONCAT_MAX_LEN: - log.error( - f"The value of group_concat_max_len (set to {GROUP_CONCAT_MAX_LEN}) has been exceeded." - "This means that the list of chart ids will be incomplete in some cases. Consider increasing it." - ) - - # Get mapping of chart ids to slugs. - chart_id_to_slug = get_charts_slugs(db_conn=db_conn).set_index("chart_id")["chart_slug"].to_dict() - - # Instead of having a string of chart ids, make chart_ids a column with lists of integers. - df["chart_ids"] = [ - [int(chart_id) for chart_id in chart_ids.split(",")] if chart_ids else [] for chart_ids in df["chart_ids"] - ] - # Add a column with lists of chart slugs. - # For each row, it will be a list of tuples (chart_id, chart_slug), - # e.g. [(123, "chart-slug"), (234, "another-chart-slug"), ...]. - df["chart_slugs"] = [ - [(chart_id, chart_id_to_slug[chart_id]) for chart_id in chart_ids] if chart_ids else [] - for chart_ids in df["chart_ids"] - ] - - # Add chart analytics. - views_df = get_charts_views(db_conn=db_conn).set_index("slug") - # Create a column for each of the views metrics. - # For each row, it will be a list of tuples (chart_id, views), - # e.g. [(123, 1000), (234, 2000), ...]. - for metric in views_df.columns: - df[metric] = [ - [ - (chart_id, views_df[metric][chart_id_to_slug[chart_id]]) - for chart_id in chart_ids - if chart_id_to_slug[chart_id] in views_df.index - ] - if chart_ids - else [] - for chart_ids in df["chart_ids"] - ] - - # Make is_archived and is_private boolean columns. - df["is_archived"] = df["is_archived"].astype(bool) - df["is_private"] = df["is_private"].astype(bool) - - # Sanity check. 
- unknown_channels = set([etl_path.split("/")[0] for etl_path in set(df["etl_path"])]) - {"grapher"} - if len(unknown_channels) > 0: - log.error( - "Variables in grapher DB are expected to come only from ETL grapher channel, " - f"but other channels were found: {unknown_channels}" - ) - - # Create a column with the step name. - # First assume all steps are public (hence starting with "data://"). - # Then edit private steps so they start with "data-private://". - df["step"] = ["data://" + "/".join(etl_path.split("#")[0].split("/")[:-1]) for etl_path in df["etl_path"]] - df.loc[df["is_private"], "step"] = df[df["is_private"]]["step"].str.replace("data://", "data-private://") - - return df - - def read_sql(sql: str, engine: Optional[Engine | Session] = None, *args, **kwargs) -> pd.DataFrame: """Wrapper around pd.read_sql that creates a connection and closes it after reading the data. This adds overhead, so if you need performance, reuse the same connection and cursor. @@ -456,58 +79,3 @@ def read_sql(sql: str, engine: Optional[Engine | Session] = None, *args, **kwarg return pd.read_sql(sql, engine.bind, *args, **kwargs) else: raise ValueError(f"Unsupported engine type {type(engine)}") - - -def get_dataset_charts(dataset_ids: List[str], db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: - if db_conn is None: - db_conn = get_connection() - - dataset_ids_str = ", ".join(map(str, dataset_ids)) - - query = f""" - SELECT - d.id AS dataset_id, - d.name AS dataset_name, - q2.chartIds AS chart_ids - FROM - (SELECT - d.id, - d.name - FROM - datasets d - WHERE - d.id IN ({dataset_ids_str})) d - LEFT JOIN - (SELECT - v.datasetId, - GROUP_CONCAT(DISTINCT c.id) AS chartIds - FROM - variables v - JOIN chart_dimensions cd ON cd.variableId = v.id - JOIN charts c ON c.id = cd.chartId - WHERE - v.datasetId IN ({dataset_ids_str}) - GROUP BY - v.datasetId) q2 - ON d.id = q2.datasetId - ORDER BY - d.id ASC; - """ - - # First, increase the GROUP_CONCAT limit, to avoid the list of chart ids to be truncated. - with db_conn.cursor() as cursor: - cursor.execute("SET SESSION group_concat_max_len = 10000;") - - if len(dataset_ids) == 0: - return pd.DataFrame({"dataset_id": [], "dataset_name": [], "chart_ids": []}) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore", UserWarning) - df = pd.read_sql(query, con=db_conn) - - # Instead of having a string of chart ids, make chart_ids a column with lists of integers. - df["chart_ids"] = [ - [int(chart_id) for chart_id in chart_ids.split(",")] if chart_ids else [] for chart_ids in df["chart_ids"] - ] - - return df diff --git a/etl/explorer.py b/etl/explorer.py index 003d3f5f0af..b6ae5c20fce 100644 --- a/etl/explorer.py +++ b/etl/explorer.py @@ -15,8 +15,8 @@ from structlog import get_logger from etl import config -from etl.db import get_variables_data from etl.files import upload_file_to_server +from etl.grapher_io import get_variables_data from etl.paths import EXPLORERS_DIR # Initialize logger. diff --git a/etl/explorer_helpers.py b/etl/explorer_helpers.py index 0a9f684c8b7..f61b35ecb71 100644 --- a/etl/explorer_helpers.py +++ b/etl/explorer_helpers.py @@ -5,8 +5,8 @@ from structlog import get_logger from etl import config -from etl.db import get_variables_data from etl.files import upload_file_to_server +from etl.grapher_io import get_variables_data from etl.paths import EXPLORERS_DIR # Initialize logger. 
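For callers, the slimming of etl/db.py above is an import move rather than an API change: the Grapher read helpers now live in etl.grapher_io, while etl.db keeps the engine and read_sql plumbing. A short sketch of the migrated imports, assuming the moved functions keep the signatures shown in the deleted code (the ids below are placeholders):

    # These helpers were previously imported from etl.db.
    from etl.grapher_io import get_all_datasets, get_variables_data, get_variables_in_dataset
    # Engine and SQL helpers stay in etl.db.
    from etl.db import get_engine, read_sql

    datasets = get_all_datasets(archived=False)  # non-archived grapher datasets only
    variables = get_variables_data(filter={"id": [123456]})  # placeholder variable id
    charted = get_variables_in_dataset(dataset_id=1234, only_used_in_charts=True)  # placeholder dataset id
    df = read_sql("SELECT id, name FROM datasets LIMIT 5", get_engine())
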
diff --git a/etl/grapher_helpers.py b/etl/grapher_helpers.py index f0ec944a1be..058598ee9f2 100644 --- a/etl/grapher_helpers.py +++ b/etl/grapher_helpers.py @@ -18,9 +18,9 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import Session -from apps.backport.datasync import data_metadata as dm from etl.db import get_engine, read_sql from etl.files import checksum_str +from etl.grapher_io import add_entity_code_and_name log = structlog.get_logger() @@ -549,7 +549,7 @@ def _adapt_table_for_grapher(table: catalog.Table, engine: Engine) -> catalog.Ta # Add entity code and name with Session(engine) as session: - table = dm.add_entity_code_and_name(session, table).copy_metadata(table) + table = add_entity_code_and_name(session, table).copy_metadata(table) table = table.set_index(["entityId", "entityCode", "entityName", "year"] + dim_names) diff --git a/etl/grapher_import.py b/etl/grapher_import.py index 3e83d78dd5e..e6a97d4df48 100644 --- a/etl/grapher_import.py +++ b/etl/grapher_import.py @@ -259,7 +259,7 @@ def upsert_table( with Session(engine) as session: # compare checksums try: - db_variable = gm.Variable.load_from_catalog_path(session, catalog_path) + db_variable = gm.Variable.from_catalog_path(session, catalog_path) except NoResultFound: db_variable = None diff --git a/etl/grapher_io.py b/etl/grapher_io.py new file mode 100644 index 00000000000..23189454ab1 --- /dev/null +++ b/etl/grapher_io.py @@ -0,0 +1,798 @@ +import concurrent.futures +import warnings +from http.client import RemoteDisconnected +from typing import Any, Dict, List, Optional, cast +from urllib.error import HTTPError, URLError + +import pandas as pd +import pymysql +import requests +import structlog +import validators +from deprecated import deprecated +from sqlalchemy.engine import Engine +from sqlalchemy.orm import Session +from tenacity import Retrying +from tenacity.retry import retry_if_exception_type +from tenacity.stop import stop_after_attempt +from tenacity.wait import wait_fixed + +from etl import config +from etl.config import OWID_ENV, OWIDEnv +from etl.db import get_connection, read_sql +from etl.grapher_model import Dataset, Variable + +log = structlog.get_logger() + + +############################################################################################## +# Load from DB +############################################################################################## + + +def load_dataset_uris( + owid_env: OWIDEnv = OWID_ENV, +) -> List[str]: + """Get list of dataset URIs from the database.""" + with Session(owid_env.engine) as session: + datasets = Dataset.load_datasets_uri(session) + + return list(datasets["dataset_uri"]) + + +def load_variables_in_dataset( + dataset_uri: List[str], + owid_env: OWIDEnv = OWID_ENV, +) -> List[Variable]: + """Load Variable objects that belong to a dataset with URI `dataset_uri`.""" + with Session(owid_env.engine) as session: + indicators = Variable.load_variables_in_datasets(session, dataset_uri) + + return indicators + + +# Load variable object +def load_variable( + id_or_path: str | int, + owid_env: OWIDEnv = OWID_ENV, +) -> Variable: + """Load variable""" + with Session(owid_env.engine) as session: + variable = Variable.from_id_or_path( + session=session, + id_or_path=id_or_path, + ) + + return variable + + +############################################################################################## +# Load data/metadata (API) +############################################################################################## + + +# SINGLE INDICATOR +# Load 
variable metadata +def load_variable_metadata( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + owid_env: OWIDEnv = OWID_ENV, +) -> Dict[str, Any]: + """Get metadata for an indicator based on its catalog path or variable id. + + Parameters + ---------- + catalog_path : str, optional + The path to the indicator in the catalog. + variable_id : int, optional + The ID of the indicator. + variable : Variable, optional + The indicator object. + """ + # Get variable + variable = ensure_load_variable(catalog_path, variable_id, variable, owid_env) + + # Get metadata + metadata = variable.get_metadata() + + return metadata + + +def load_variable_data( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + owid_env: OWIDEnv = OWID_ENV, +) -> pd.DataFrame: + """Get data for an indicator based on its catalog path or variable id. + + Parameters + ---------- + cataslog_path : str, optional + The path to the indicator in the catalog. + variable_id : int, optional + The ID of the indicator. + variable : Variable, optional + The indicator object. + + """ + + # Get variable + variable = ensure_load_variable(catalog_path, variable_id, variable, owid_env) + + # Get data + with Session(owid_env.engine) as session: + df = variable.get_data(session=session) + + return df + + +def ensure_load_variable( + catalog_path: Optional[str] = None, + variable_id: Optional[int] = None, + variable: Optional[Variable] = None, + owid_env: OWIDEnv = OWID_ENV, +) -> Variable: + if variable is None: + if catalog_path is not None: + variable = load_variable(id_or_path=catalog_path, owid_env=owid_env) + elif variable_id is not None: + variable = load_variable(id_or_path=variable_id, owid_env=owid_env) + else: + raise ValueError("Either catalog_path, variable_id or variable must be provided") + return variable + + +############################################################################################## +# More optimized API access +# Most useful for bulk operations +# from apps.backport.datasync.data_metadata +############################################################################################## + + +def load_variables_data( + catalog_paths: Optional[List[str]] = None, + variable_ids: Optional[List[int]] = None, + variables: Optional[List[Variable]] = None, + owid_env: OWIDEnv = OWID_ENV, + workers: int = 1, + value_as_str: bool = True, +) -> pd.DataFrame: + """Get data for a list of indicators based on their catalog path or variable id. + + Priority: catalog_paths > variable_ids > variables + + Parameters + ---------- + cataslog_path : str, optional + The path to the indicator in the catalog. + variable_id : int, optional + The ID of the indicator. + variable : Variable, optional + The indicator object. + + """ + # Get variable IDs + variable_ids = _ensure_variable_ids(owid_env.engine, catalog_paths, variable_ids, variables) + + # Get variable + df = variable_data_df_from_s3( + owid_env.engine, + variable_ids=variable_ids, + workers=workers, + value_as_str=value_as_str, + ) + + return df + + +def load_variables_metadata( + catalog_paths: Optional[List[str]] = None, + variable_ids: Optional[List[int]] = None, + variables: Optional[List[Variable]] = None, + owid_env: OWIDEnv = OWID_ENV, + workers: int = 1, +) -> List[Dict[str, Any]]: + """Get metadata for a list of indicators based on their catalog path or variable id. 
+ + Priority: catalog_paths > variable_ids > variables + + Parameters + ---------- + catalog_path : str, optional + The path to the indicator in the catalog. + variable_id : int, optional + The ID of the indicator. + variable : Variable, optional + The indicator object. + """ + + # Get variable IDs + variable_ids = _ensure_variable_ids(owid_env.engine, catalog_paths, variable_ids, variables) + + metadata = variable_metadata_df_from_s3( + variable_ids=variable_ids, + workers=workers, + env=owid_env, + ) + + return metadata + + +def _ensure_variable_ids( + engine: Engine, + catalog_paths: Optional[List[str]] = None, + variable_ids: Optional[List[int]] = None, + variables: Optional[List[Variable]] = None, +) -> List[int]: + if catalog_paths is not None: + with Session(engine) as session: + mapping = Variable.catalog_paths_to_variable_ids(session, catalog_paths=catalog_paths) + variable_ids = [int(i) for i in mapping.values()] + elif (variable_ids is None) and (variables is not None): + variable_ids = [variable.id for variable in variables] + + if variable_ids is None: + raise ValueError("Either catalog_paths, variable_ids or variables must be provided") + + return variable_ids + + +def variable_data_df_from_s3( + engine: Engine, + variable_ids: List[int] = [], + workers: int = 1, + value_as_str: bool = True, +) -> pd.DataFrame: + """Fetch data from S3 and add entity code and name from DB.""" + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + results = list(executor.map(_fetch_data_df_from_s3, variable_ids)) + + if isinstance(results, list) and all(isinstance(df, pd.DataFrame) for df in results): + df = pd.concat(cast(List[pd.DataFrame], results)) + else: + raise TypeError(f"results must be a list of pd.DataFrame, got {type(results)}") + + # we work with strings and convert to specific types later + if value_as_str: + df["value"] = df["value"].astype("string") + + with Session(engine) as session: + res = add_entity_code_and_name(session, df) + return res + + +def _fetch_data_df_from_s3(variable_id: int): + try: + # Cloudflare limits us to 600 requests per minute, retry in case we hit the limit + # NOTE: increase wait time or attempts if we hit the limit too often + for attempt in Retrying( + wait=wait_fixed(2), + stop=stop_after_attempt(3), + retry=retry_if_exception_type((URLError, RemoteDisconnected)), + ): + with attempt: + return ( + pd.read_json(config.variable_data_url(variable_id)) + .rename( + columns={ + "entities": "entityId", + "values": "value", + "years": "year", + } + ) + .assign(variableId=variable_id) + ) + # no data on S3 + except HTTPError: + return pd.DataFrame(columns=["variableId", "entityId", "year", "value"]) + + +def add_entity_code_and_name(session: Session, df: pd.DataFrame) -> pd.DataFrame: + if df.empty: + df["entityName"] = [] + df["entityCode"] = [] + return df + + unique_entities = df["entityId"].unique() + + entities = _fetch_entities(session, list(unique_entities)) + + if set(unique_entities) - set(entities.entityId): + missing_entities = set(unique_entities) - set(entities.entityId) + raise ValueError(f"Missing entities in the database: {missing_entities}") + + return pd.merge(df, entities.astype({"entityName": "category", "entityCode": "category"}), on="entityId") + + +def _fetch_entities(session: Session, entity_ids: List[int]) -> pd.DataFrame: + # Query entities from the database + q = """ + SELECT + id AS entityId, + name AS entityName, + code AS entityCode + FROM entities + WHERE id in %(entity_ids)s + """ + return 
read_sql(q, session, params={"entity_ids": entity_ids}) + + +def variable_metadata_df_from_s3( + variable_ids: List[int] = [], + workers: int = 1, + env: OWIDEnv | None = None, +) -> List[Dict[str, Any]]: + """Fetch data from S3 and add entity code and name from DB.""" + args = [variable_ids] + if env: + args += [[env for _ in range(len(variable_ids))]] + + with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: + results = list(executor.map(_fetch_metadata_from_s3, *args)) + + if not (isinstance(results, list) and all(isinstance(res, dict) for res in results)): + raise TypeError(f"results must be a list of dictionaries, got {type(results)}") + + return results # type: ignore + + +def _fetch_metadata_from_s3(variable_id: int, env: OWIDEnv | None = None) -> Dict[str, Any] | None: + try: + # Cloudflare limits us to 600 requests per minute, retry in case we hit the limit + # NOTE: increase wait time or attempts if we hit the limit too often + for attempt in Retrying( + wait=wait_fixed(2), + stop=stop_after_attempt(3), + retry=retry_if_exception_type((URLError, RemoteDisconnected)), + ): + with attempt: + if env is not None: + url = env.indicator_metadata_url(variable_id) + else: + url = config.variable_metadata_url(variable_id) + return requests.get(url).json() + # no data on S3 + except HTTPError: + return {} + + +############################################################################################## +# TO BE REVIEWED: +# This is code that could be deprecated / removed? +############################################################################################## + + +def get_dataset_id( + dataset_name: str, db_conn: Optional[pymysql.Connection] = None, version: Optional[str] = None +) -> Any: + """Get the dataset ID of a specific dataset name from database. + + If more than one dataset is found for the same name, or if no dataset is found, an error is raised. + + Parameters + ---------- + dataset_name : str + Dataset name. + db_conn : pymysql.Connection + Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). + version : str + ETL version of the dataset. This is necessary when multiple datasets have the same title. In such a case, if + version is not given, the function will raise an error. + + Returns + ------- + dataset_id : int + Dataset ID. + + """ + if db_conn is None: + db_conn = get_connection() + + query = f""" + SELECT id + FROM datasets + WHERE name = '{dataset_name}' + """ + + if version: + query += f" AND version = '{version}'" + + with db_conn.cursor() as cursor: + cursor.execute(query) + result = cursor.fetchall() + + assert len(result) == 1, f"Ambiguous or unknown dataset name '{dataset_name}'" + dataset_id = result[0][0] + return dataset_id + + +@deprecated("This function is deprecated. Its logic will be soon moved to etl.grapher_model.Dataset.") +def get_variables_in_dataset( + dataset_id: int, only_used_in_charts: bool = False, db_conn: Optional[pymysql.Connection] = None +) -> Any: + """Get all variables data for a specific dataset ID from database. + + Parameters + ---------- + dataset_id : int + Dataset ID. + only_used_in_charts : bool + True to select variables only if they have been used in at least one chart. False to select all variables. + db_conn : pymysql.Connection + Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). + + Returns + ------- + variables_data : pd.DataFrame + Variables data for considered dataset. 
+ + """ + if db_conn is None: + db_conn = get_connection() + + query = f""" + SELECT * + FROM variables + WHERE datasetId = {dataset_id} + """ + if only_used_in_charts: + query += """ + AND id IN ( + SELECT DISTINCT variableId + FROM chart_dimensions + ) + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + variables_data = pd.read_sql(query, con=db_conn) + return variables_data + + +def get_all_datasets(archived: bool = True, db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: + """Get all datasets in database. + + Parameters + ---------- + db_conn : pymysql.connections.Connection + Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). + + Returns + ------- + datasets : pd.DataFrame + All datasets in database. Table with three columns: dataset ID, dataset name, dataset namespace. + """ + if db_conn is None: + db_conn = get_connection() + + query = " SELECT namespace, name, id, updatedAt, isArchived FROM datasets" + if not archived: + query += " WHERE isArchived = 0" + datasets = pd.read_sql(query, con=db_conn) + return datasets.sort_values(["name", "namespace"]) + + +def get_info_for_etl_datasets(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: + if db_conn is None: + db_conn = get_connection() + + # First, increase the GROUP_CONCAT limit, to avoid the list of chart ids to be truncated. + GROUP_CONCAT_MAX_LEN = 4096 + cursor = db_conn.cursor() + cursor.execute(f"SET SESSION group_concat_max_len = {GROUP_CONCAT_MAX_LEN};") + db_conn.commit() + + query = """\ + SELECT + q1.datasetId AS dataset_id, + d.name AS dataset_name, + q1.etlPath AS etl_path, + d.isArchived AS is_archived, + d.isPrivate AS is_private, + q2.chartIds AS chart_ids, + q2.updatePeriodDays AS update_period_days + FROM + (SELECT + datasetId, + MIN(catalogPath) AS etlPath + FROM + variables + WHERE + catalogPath IS NOT NULL + GROUP BY + datasetId) q1 + LEFT JOIN + (SELECT + d.id AS datasetId, + d.isArchived, + d.isPrivate, + d.updatePeriodDays, + GROUP_CONCAT(DISTINCT c.id) AS chartIds + FROM + datasets d + JOIN variables v ON v.datasetId = d.id + JOIN chart_dimensions cd ON cd.variableId = v.id + JOIN charts c ON c.id = cd.chartId + JOIN chart_configs cc ON c.configId = cc.id + WHERE + json_extract(cc.full, "$.isPublished") = TRUE + GROUP BY + d.id) q2 + ON q1.datasetId = q2.datasetId + JOIN + datasets d ON q1.datasetId = d.id + ORDER BY + q1.datasetId ASC; + + """ + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + df = pd.read_sql(query, con=db_conn) + + if max([len(row) for row in df["chart_ids"] if row is not None]) == GROUP_CONCAT_MAX_LEN: + log.error( + f"The value of group_concat_max_len (set to {GROUP_CONCAT_MAX_LEN}) has been exceeded." + "This means that the list of chart ids will be incomplete in some cases. Consider increasing it." + ) + + # Get mapping of chart ids to slugs. + chart_id_to_slug = get_charts_slugs(db_conn=db_conn).set_index("chart_id")["chart_slug"].to_dict() + + # Instead of having a string of chart ids, make chart_ids a column with lists of integers. + df["chart_ids"] = [ + [int(chart_id) for chart_id in chart_ids.split(",")] if chart_ids else [] for chart_ids in df["chart_ids"] + ] + # Add a column with lists of chart slugs. + # For each row, it will be a list of tuples (chart_id, chart_slug), + # e.g. [(123, "chart-slug"), (234, "another-chart-slug"), ...]. 
+ df["chart_slugs"] = [ + [(chart_id, chart_id_to_slug[chart_id]) for chart_id in chart_ids] if chart_ids else [] + for chart_ids in df["chart_ids"] + ] + + # Add chart analytics. + views_df = get_charts_views(db_conn=db_conn).set_index("slug") + # Create a column for each of the views metrics. + # For each row, it will be a list of tuples (chart_id, views), + # e.g. [(123, 1000), (234, 2000), ...]. + for metric in views_df.columns: + df[metric] = [ + [ + (chart_id, views_df[metric][chart_id_to_slug[chart_id]]) + for chart_id in chart_ids + if chart_id_to_slug[chart_id] in views_df.index + ] + if chart_ids + else [] + for chart_ids in df["chart_ids"] + ] + + # Make is_archived and is_private boolean columns. + df["is_archived"] = df["is_archived"].astype(bool) + df["is_private"] = df["is_private"].astype(bool) + + # Sanity check. + unknown_channels = set([etl_path.split("/")[0] for etl_path in set(df["etl_path"])]) - {"grapher"} + if len(unknown_channels) > 0: + log.error( + "Variables in grapher DB are expected to come only from ETL grapher channel, " + f"but other channels were found: {unknown_channels}" + ) + + # Create a column with the step name. + # First assume all steps are public (hence starting with "data://"). + # Then edit private steps so they start with "data-private://". + df["step"] = ["data://" + "/".join(etl_path.split("#")[0].split("/")[:-1]) for etl_path in df["etl_path"]] + df.loc[df["is_private"], "step"] = df[df["is_private"]]["step"].str.replace("data://", "data-private://") + + return df + + +def get_charts_slugs(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: + if db_conn is None: + db_conn = get_connection() + + # Get a dataframe chart_id,char_slug, for all charts that have variables with an ETL path. + query = """\ + SELECT + c.id AS chart_id, + cc.slug AS chart_slug + FROM charts c + JOIN chart_configs cc ON c.configId = cc.id + LEFT JOIN chart_dimensions cd ON c.id = cd.chartId + LEFT JOIN variables v ON cd.variableId = v.id + WHERE + v.catalogPath IS NOT NULL + ORDER BY + c.id ASC; + """ + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + df = pd.read_sql(query, con=db_conn) + + # Remove duplicated rows. + df = df.drop_duplicates().reset_index(drop=True) + + if len(df[df.duplicated(subset="chart_id")]) > 0: + log.warning("There are duplicated chart ids in the chart_ids and slugs table.") + + return df + + +def get_charts_views(db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: + if db_conn is None: + db_conn = get_connection() + + # Assumed base url for all charts. + base_url = "https://ourworldindata.org/grapher/" + + # Note that for now we extract data for all dates. + # It seems that the table only has data for the last day. + query = f"""\ + SELECT + url, + views_7d, + views_14d, + views_365d + FROM + analytics_pageviews + WHERE + url LIKE '{base_url}%'; + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + df = pd.read_sql(query, con=db_conn) + + # For some reason, there are spurious urls, clean some of them. + # Note that validators.url() returns a ValidationError object (instead of False) when the url has spaces. + is_url_invalid = [(validators.url(url) is False) or (" " in url) for url in df["url"]] + df = df.drop(df[is_url_invalid].index).reset_index(drop=True) + + # Note that some of the returned urls may still be invalid, for example "https://ourworldindata.org/grapher/132". + + # Add chart slug. 
+ df["slug"] = [url.replace(base_url, "") for url in df["url"]] + + # Remove url. + df = df.drop(columns=["url"], errors="raise") + + if len(df[df.duplicated(subset="slug")]) > 0: + log.warning("There are duplicated slugs in the chart analytics table.") + + return df + + +def get_dataset_charts(dataset_ids: List[str], db_conn: Optional[pymysql.Connection] = None) -> pd.DataFrame: + if db_conn is None: + db_conn = get_connection() + + dataset_ids_str = ", ".join(map(str, dataset_ids)) + + query = f""" + SELECT + d.id AS dataset_id, + d.name AS dataset_name, + q2.chartIds AS chart_ids + FROM + (SELECT + d.id, + d.name + FROM + datasets d + WHERE + d.id IN ({dataset_ids_str})) d + LEFT JOIN + (SELECT + v.datasetId, + GROUP_CONCAT(DISTINCT c.id) AS chartIds + FROM + variables v + JOIN chart_dimensions cd ON cd.variableId = v.id + JOIN charts c ON c.id = cd.chartId + WHERE + v.datasetId IN ({dataset_ids_str}) + GROUP BY + v.datasetId) q2 + ON d.id = q2.datasetId + ORDER BY + d.id ASC; + """ + + # First, increase the GROUP_CONCAT limit, to avoid the list of chart ids to be truncated. + with db_conn.cursor() as cursor: + cursor.execute("SET SESSION group_concat_max_len = 10000;") + + if len(dataset_ids) == 0: + return pd.DataFrame({"dataset_id": [], "dataset_name": [], "chart_ids": []}) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + df = pd.read_sql(query, con=db_conn) + + # Instead of having a string of chart ids, make chart_ids a column with lists of integers. + df["chart_ids"] = [ + [int(chart_id) for chart_id in chart_ids.split(",")] if chart_ids else [] for chart_ids in df["chart_ids"] + ] + + return df + + +def get_variables_data( + filter: Optional[Dict[str, Any]] = None, + condition: Optional[str] = "OR", + db_conn: Optional[pymysql.Connection] = None, +) -> pd.DataFrame: + """Get data from variables table, given a certain condition. + + Parameters + ---------- + filter : Optional[Dict[str, Any]], optional + Filter to apply to the data, which must contain a field name and a list of field values, + e.g. {"id": [123456, 234567, 345678]}. + In principle, multiple filters can be given. + condition : Optional[str], optional + In case multiple filters are given, this parameter specifies whether the output filters should be the union + ("OR") or the intersection ("AND"). + db_conn : pymysql.Connection + Connection to database. Defaults to None, in which case a default connection is created (uses etl.config). + + Returns + ------- + df : pd.DataFrame + Variables data. + + """ + # NOTE: This function should be optimized. Instead of fetching data for each filter, their conditions should be + # combined with OR or AND before executing the query. + + # Initialize an empty dataframe. + if filter is not None: + df = pd.DataFrame({"id": []}).astype({"id": int}) + for field_name, field_values in filter.items(): + _df = _get_variables_data_with_filter(field_name=field_name, field_values=field_values, db_conn=db_conn) + if condition == "OR": + df = pd.concat([df, _df], axis=0) + elif condition == "AND": + df = pd.merge(df, _df, on="id", how="inner") + else: + raise ValueError(f"Invalid condition: {condition}") + else: + # Fetch data for all variables. 
+ df = _get_variables_data_with_filter(db_conn=db_conn) + + return df + + +def _get_variables_data_with_filter( + field_name: Optional[str] = None, + field_values: Optional[List[Any]] = None, + db_conn: Optional[pymysql.Connection] = None, +) -> Any: + if db_conn is None: + db_conn = get_connection() + + if field_values is None: + field_values = [] + + # Construct the SQL query with a placeholder for each value in the list. + query = "SELECT * FROM variables" + + if (field_name is not None) and (len(field_values) > 0): + query += f"\nWHERE {field_name} IN ({', '.join(['%s'] * len(field_values))});" + + # Execute the query. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", UserWarning) + variables_data = pd.read_sql(query, con=db_conn, params=field_values) + + assert set(variables_data[field_name]) <= set(field_values), f"Unexpected values for {field_name}." + + # Warn about values that were not found. + missing_values = set(field_values) - set(variables_data[field_name]) + if len(missing_values) > 0: + log.warning(f"Values of {field_name} not found in database: {missing_values}") + + return variables_data diff --git a/etl/grapher_model.py b/etl/grapher_model.py index 927a5177ff9..4d0e117f38c 100644 --- a/etl/grapher_model.py +++ b/etl/grapher_model.py @@ -20,11 +20,13 @@ from datetime import date, datetime from enum import Enum from pathlib import Path -from typing import Any, Dict, List, Literal, Optional, Union, get_args +from typing import Any, Dict, List, Literal, Optional, Union, get_args, overload import humps import pandas as pd +import requests import structlog +from deprecated import deprecated from owid import catalog from owid.catalog.meta import VARIABLE_TYPE from sqlalchemy import ( @@ -427,7 +429,7 @@ def migrate_config(self, source_session: Session, target_session: Session) -> Di for source_var_id, source_var in source_variables.items(): if source_var.catalogPath: try: - target_var = Variable.load_from_catalog_path(target_session, source_var.catalogPath) + target_var = Variable.from_catalog_path(target_session, source_var.catalogPath) except NoResultFound: raise ValueError(f"variables.catalogPath not found in target: {source_var.catalogPath}") # old style variable, match it on name and dataset id @@ -622,6 +624,24 @@ def load_variables_for_dataset(cls, session: Session, dataset_id: int) -> list[" assert vars, f"Dataset {dataset_id} has no variables" return list(vars) + @classmethod + def load_datasets_uri(cls, session: Session): + query = """SELECT dataset_uri, createdAt + FROM ( + SELECT + namespace, + version, + shortName, + createdAt, + CONCAT('grapher/', namespace, '/', version, '/', shortName) AS dataset_uri + FROM + datasets d + ) AS derived + WHERE dataset_uri IS NOT NULL + ORDER BY createdAt DESC; + """ + return read_sql(query, session) + class SourceDescription(TypedDict, total=False): link: Optional[str] @@ -1162,17 +1182,116 @@ def from_variable_metadata( ) @classmethod + def load_variables_in_datasets(cls, session: Session, dataset_uris: List[str]) -> List["Variable"]: + conditions = [cls.catalogPath.startswith(uri) for uri in dataset_uris] + query = select(cls).where(or_(*conditions)) + results = session.scalars(query).all() + return list(results) + + @classmethod + @deprecated("Use from_id_or_path instead") def load_variable(cls, session: Session, variable_id: int) -> "Variable": + """D""" return session.scalars(select(cls).where(cls.id == variable_id)).one() @classmethod + @deprecated("Use from_id_or_path instead") def load_variables(cls, 
session: Session, variables_id: List[int]) -> List["Variable"]: return session.scalars(select(cls).where(cls.id.in_(variables_id))).all() # type: ignore + @overload + @classmethod + def from_id_or_path( + cls, + session: Session, + id_or_path: str | int, + ) -> "Variable": + ... + + @overload + @classmethod + def from_id_or_path( + cls, + session: Session, + id_or_path: List[str | int], + ) -> List["Variable"]: + ... + @classmethod - def load_from_catalog_path(cls, session: Session, catalog_path: str) -> "Variable": + def from_id_or_path( + cls, + session: Session, + id_or_path: int | str | List[str | int], + ) -> "Variable" | List["Variable"]: + """Load a variable from the database by its catalog path or variable ID.""" + # Single id + if isinstance(id_or_path, int): + return cls.from_id(session=session, variable_id=id_or_path) + # Single path + elif isinstance(id_or_path, str): + return cls.from_catalog_path(session=session, catalog_path=id_or_path) + + # Multiple path or id + elif isinstance(id_or_path, list): + # Filter the list to ensure only integers are passed + int_ids = [i for i in id_or_path if isinstance(i, int)] + str_ids = [i for i in id_or_path if isinstance(i, str)] + # Multiple IDs + if len(int_ids) == len(id_or_path): + return cls.from_id(session=session, variable_id=int_ids) + # Multiple paths + elif len(str_ids) == len(id_or_path): + return cls.from_catalog_path(session=session, catalog_path=str_ids) + else: + raise TypeError("All elements in the list must be integers") + + # # Ensure mutual exclusivity of catalog_path and variable_id + # if (catalog_path is not None) and (variable_id is not None): + # raise ValueError("Only one of catalog_path or variable_id can be provided") + + # if (catalog_path is not None) & isinstance(catalog_path, (str, list)): + # return cls.from_catalog_path(session=session, catalog_path=catalog_path) + # elif isinstance(catalog_path, (int, list)): + # return cls.from_id(session=session, variable_id=variable_id) + # else: + # raise ValueError("Either catalog_path or variable_id must be provided") + + @overload + @classmethod + def from_catalog_path(cls, session: Session, catalog_path: str) -> "Variable": + ... + + @overload + @classmethod + def from_catalog_path(cls, session: Session, catalog_path: List[str]) -> List["Variable"]: + ... + + @classmethod + def from_catalog_path(cls, session: Session, catalog_path: str | List[str]) -> "Variable" | List["Variable"]: + """Load a variable from the DB by its catalog path.""" assert "#" in catalog_path, "catalog_path should end with #indicator_short_name" - return session.scalars(select(cls).where(cls.catalogPath == catalog_path)).one() + if isinstance(catalog_path, str): + return session.scalars(select(cls).where(cls.catalogPath == catalog_path)).one() + elif isinstance(catalog_path, list): + return session.scalars(select(cls).where(cls.catalogPath.in_(catalog_path))).all() # type: ignore + + @overload + @classmethod + def from_id(cls, session: Session, variable_id: int) -> "Variable": + ... + + @overload + @classmethod + def from_id(cls, session: Session, variable_id: List[int]) -> List["Variable"]: + ... 
+ + @classmethod + def from_id(cls, session: Session, variable_id: int | List[int]) -> "Variable" | List["Variable"]: + """Load a variable (or list of variables) from the DB by its ID path.""" + if isinstance(variable_id, int): + return session.scalars(select(cls).where(cls.id == variable_id)).one() + elif isinstance(variable_id, list): + return session.scalars(select(cls).where(cls.id.in_(variable_id))).all() # type: ignore @classmethod def catalog_paths_to_variable_ids(cls, session: Session, catalog_paths: List[str]) -> Dict[str, int]: @@ -1251,6 +1370,24 @@ def override_yaml_path(self) -> Path: """Return path to indicator YAML file.""" return self.step_path.with_suffix(".meta.override.yml") + def get_data(self, session: Optional[Session] = None) -> pd.DataFrame: + """Get variable data from S3. + + If session is given, entity codes are replaced with entity names. + """ + data = requests.get(self.s3_data_path(typ="http")).json() + df = pd.DataFrame(data) + + if session is not None: + df = add_entity_name(session=session, df=df, col_id="entities", col_name="entity") + + return df + + def get_metadata(self) -> Dict[str, Any]: + metadata = requests.get(self.s3_metadata_path(typ="http")).json() + + return metadata + class ChartDimensions(Base): __tablename__ = "chart_dimensions" @@ -1629,3 +1766,73 @@ def _is_float(x): return False else: return True + + +def add_entity_name( + session: Session, + df: pd.DataFrame, + col_id: str, + col_name: str = "entity", + col_code: Optional[str] = None, + remove_id: bool = True, +) -> pd.DataFrame: + # Initialize + if df.empty: + df[col_name] = [] + if col_code is not None: + df[col_code] = [] + return df + + # Get entity names + unique_entities = df[col_id].unique() + entities = _fetch_entities(session, list(unique_entities), col_id, col_name, col_code) + + # Sanity check + if set(unique_entities) - set(entities[col_id]): + missing_entities = set(unique_entities) - set(entities[col_id]) + raise ValueError(f"Missing entities in the database: {missing_entities}") + + # Set dtypes + dtypes = {col_name: "category"} + if col_code is not None: + dtypes[col_code] = "category" + df = pd.merge(df, entities.astype(dtypes), on=col_id) + + # Remove entity IDs + if remove_id: + df = df.drop(columns=[col_id]) + + return df + + +def _fetch_entities( + session: Session, + entity_ids: List[int], + col_id: Optional[str] = None, + col_name: Optional[str] = None, + col_code: Optional[str] = None, +) -> pd.DataFrame: + # Query entities from the database + q = """ + SELECT + id AS entityId, + name AS entityName, + code AS entityCode + FROM entities + WHERE id in %(entity_ids)s + """ + df = read_sql(q, session, params={"entity_ids": entity_ids}) + + # Rename columns + column_renames = {} + if col_id is not None: + column_renames["entityId"] = col_id + if col_name is not None: + column_renames["entityName"] = col_name + if col_code is not None: + column_renames["entityCode"] = col_code + else: + df = df.drop(columns=["entityCode"]) + + df = df.rename(columns=column_renames) + return df diff --git a/etl/match_variables.py b/etl/match_variables.py index d791270e203..e124054a7d0 100644 --- a/etl/match_variables.py +++ b/etl/match_variables.py @@ -9,7 +9,8 @@ from rapidfuzz import fuzz from structlog import get_logger -from etl import db +from etl.db import get_connection +from etl.grapher_io import get_dataset_id, get_variables_in_dataset # If True, identical variables will be matched automatically (by string comparison). 
# If False, variables with identical names will appear in comparison. @@ -124,19 +125,15 @@ def main( if Path(output_file).suffix != ".json": raise ValueError(f"`output_file` ({output_file}) should point to a JSON file ('*.json')!") - with db.get_connection() as db_conn: + with get_connection() as db_conn: # Get old and new dataset ids. - old_dataset_id = db.get_dataset_id(db_conn=db_conn, dataset_name=old_dataset_name) - new_dataset_id = db.get_dataset_id(db_conn=db_conn, dataset_name=new_dataset_name) + old_dataset_id = get_dataset_id(db_conn=db_conn, dataset_name=old_dataset_name) + new_dataset_id = get_dataset_id(db_conn=db_conn, dataset_name=new_dataset_name) # Get variables from old dataset that have been used in at least one chart. - old_indicators = db.get_variables_in_dataset( - db_conn=db_conn, dataset_id=old_dataset_id, only_used_in_charts=True - ) + old_indicators = get_variables_in_dataset(db_conn=db_conn, dataset_id=old_dataset_id, only_used_in_charts=True) # Get all variables from new dataset. - new_indicators = db.get_variables_in_dataset( - db_conn=db_conn, dataset_id=new_dataset_id, only_used_in_charts=False - ) + new_indicators = get_variables_in_dataset(db_conn=db_conn, dataset_id=new_dataset_id, only_used_in_charts=False) # Manually map old variable names to new variable names. mapping = map_old_and_new_indicators( diff --git a/etl/scripts/faostat/archive/migrate_to_new_metadata.py b/etl/scripts/faostat/archive/migrate_to_new_metadata.py index d750d9a3850..9e5349c81fe 100644 --- a/etl/scripts/faostat/archive/migrate_to_new_metadata.py +++ b/etl/scripts/faostat/archive/migrate_to_new_metadata.py @@ -11,8 +11,8 @@ from owid import catalog from structlog import get_logger -from etl import db from etl.files import yaml_dump +from etl.grapher_io import get_dataset_id, get_variables_in_dataset from etl.paths import DATA_DIR, STEP_DIR # Initialize logger. @@ -44,14 +44,14 @@ def main(): ) continue try: - dataset_id = db.get_dataset_id(dataset_name=dataset_name, version=VERSION) # type: ignore[reportArgumentType] + dataset_id = get_dataset_id(dataset_name=dataset_name, version=VERSION) # type: ignore[reportArgumentType] except AssertionError: log.error( f"Grapher dataset for {domain} could not be found in the database. " f"Run `etl run {domain} --grapher` and try again." 
) continue - variables = db.get_variables_in_dataset(dataset_id=dataset_id, only_used_in_charts=True) + variables = get_variables_in_dataset(dataset_id=dataset_id, only_used_in_charts=True) if len(variables) > 0: # variables_in_charts[domain] = variables.set_index("shortName").T.to_dict() variables_dict = {} diff --git a/etl/scripts/faostat/create_chart_revisions.py b/etl/scripts/faostat/create_chart_revisions.py index 80c9b779846..3e6ef3c6e6a 100644 --- a/etl/scripts/faostat/create_chart_revisions.py +++ b/etl/scripts/faostat/create_chart_revisions.py @@ -20,7 +20,8 @@ from owid.datautils.dataframes import map_series from structlog import get_logger -from etl import db +from etl.db import get_connection +from etl.grapher_io import get_dataset_id, get_variables_in_dataset # from etl.chart_revision.v2.core import create_and_submit_charts_revisions from etl.paths import DATA_DIR @@ -93,10 +94,10 @@ def map_old_to_new_variable_names(variables_old: List[str], variables_new: List[ def get_grapher_data_for_old_and_new_variables( dataset_old: Dataset, dataset_new: Dataset ) -> Tuple[Optional[pd.DataFrame], Optional[pd.DataFrame]]: - with db.get_connection() as db_conn: + with get_connection() as db_conn: try: # Get old and new dataset ids. - dataset_id_old = db.get_dataset_id( + dataset_id_old = get_dataset_id( db_conn=db_conn, dataset_name=dataset_old.metadata.title, # type: ignore version=dataset_old.metadata.version, # type: ignore[reportArgumentType] @@ -105,7 +106,7 @@ def get_grapher_data_for_old_and_new_variables( log.error(f"Dataset {dataset_old.metadata.title} not found in grapher DB.") return None, None try: - dataset_id_new = db.get_dataset_id( + dataset_id_new = get_dataset_id( db_conn=db_conn, dataset_name=dataset_new.metadata.title, # type: ignore version=dataset_new.metadata.version, # type: ignore[reportArgumentType] @@ -115,11 +116,11 @@ def get_grapher_data_for_old_and_new_variables( return None, None # Get variables from old dataset that have been used in at least one chart. - grapher_variables_old = db.get_variables_in_dataset( + grapher_variables_old = get_variables_in_dataset( db_conn=db_conn, dataset_id=dataset_id_old, only_used_in_charts=True ) # Get all variables from new dataset. 
- grapher_variables_new = db.get_variables_in_dataset( + grapher_variables_new = get_variables_in_dataset( db_conn=db_conn, dataset_id=dataset_id_new, only_used_in_charts=False ) diff --git a/etl/version_tracker.py b/etl/version_tracker.py index 8e37d6b8bf7..2617527369d 100644 --- a/etl/version_tracker.py +++ b/etl/version_tracker.py @@ -12,7 +12,8 @@ from etl import paths from etl.config import ADMIN_HOST -from etl.db import can_connect, get_info_for_etl_datasets +from etl.db import can_connect +from etl.grapher_io import get_info_for_etl_datasets from etl.steps import extract_step_attributes, load_dag, reverse_graph log = structlog.get_logger() diff --git a/pyproject.toml b/pyproject.toml index 0dd43b9ed69..68db3cc4a00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ "owid-datautils", "owid-repack", "walden", + "deprecated>=1.2.14", "scikit-learn>=1.5.2", ] @@ -116,7 +117,7 @@ api = [ "joblib>=1.3.2", ] wizard = [ - "streamlit>=1.38.0", + "streamlit>=1.39.0", "streamlit-aggrid>=0.3.4.post3", "streamlit-ace>=0.1.1", "streamlit-extras>=0.3.6", diff --git a/snapshots/wb/2023-07-10/education.py b/snapshots/wb/2023-07-10/education.py index 5cbde7de205..d298574ccd8 100644 --- a/snapshots/wb/2023-07-10/education.py +++ b/snapshots/wb/2023-07-10/education.py @@ -9,7 +9,7 @@ from owid.datautils.io import df_to_file from tqdm import tqdm -from etl.db import get_engine +from etl.db_utils import get_engine from etl.snapshot import Snapshot # Version for current snapshot dataset. diff --git a/tests/api/v1.py b/tests/api/v1.py index 1e28d9c1472..8c726886ef6 100644 --- a/tests/api/v1.py +++ b/tests/api/v1.py @@ -15,9 +15,9 @@ def test_health(): assert response.json() == {"status": "ok"} -@patch("etl.grapher_model.Variable.load_from_catalog_path") -def test_update_indicator(mock_load_from_catalog_path): - mock_load_from_catalog_path.return_value = gm.Variable( +@patch("etl.grapher_model.Variable.from_catalog_path") +def test_update_indicator(mock_from_catalog_path): + mock_from_catalog_path.return_value = gm.Variable( datasetId=1, description="", timespan="", @@ -29,7 +29,7 @@ def test_update_indicator(mock_load_from_catalog_path): dimensions=None, sourceId=None, ) - mock_load_from_catalog_path.id = 1 + mock_from_catalog_path.id = 1 response = client.put( "/api/v1/indicators", json={ diff --git a/tests/backport/datasync/test_data_metadata.py b/tests/backport/datasync/test_data_metadata.py index 2218a027f90..04ed691c15b 100644 --- a/tests/backport/datasync/test_data_metadata.py +++ b/tests/backport/datasync/test_data_metadata.py @@ -8,10 +8,10 @@ from apps.backport.datasync.data_metadata import ( _convert_strings_to_numeric, variable_data, - variable_data_df_from_s3, variable_metadata, ) from etl.db import get_engine +from etl.grapher_io import variable_data_df_from_s3 from etl.grapher_model import _infer_variable_type @@ -206,7 +206,7 @@ def test_variable_data_df_from_s3(): ) s3_data = pd.DataFrame({"entities": [1, 1], "values": ["a", 2], "years": [2000, 2001]}) - with mock.patch("apps.backport.datasync.data_metadata._fetch_entities", return_value=entities): + with mock.patch("etl.grapher_io._fetch_entities", return_value=entities): with mock.patch("pandas.read_json", return_value=s3_data): df = variable_data_df_from_s3(engine, [123]) diff --git a/tests/test_grapher_helpers.py b/tests/test_grapher_helpers.py index 5dcb9d4e423..b54ba3d2f6d 100644 --- a/tests/test_grapher_helpers.py +++ b/tests/test_grapher_helpers.py @@ -185,7 +185,7 @@ def _sample_table() -> Table: def 
test_adapt_table_for_grapher_multiindex(): with mock.patch("etl.grapher_helpers._get_entities_from_db") as mock_get_entities_from_db: - with mock.patch("apps.backport.datasync.data_metadata._fetch_entities") as mock_fetch_entities: + with mock.patch("etl.grapher_io._fetch_entities") as mock_fetch_entities: mock_get_entities_from_db.return_value = {"Poland": 1, "France": 2} mock_fetch_entities.return_value = pd.DataFrame( { diff --git a/uv.lock b/uv.lock index 1aa830b9a8b..d302b334863 100644 --- a/uv.lock +++ b/uv.lock @@ -777,6 +777,7 @@ dependencies = [ { name = "bugsnag" }, { name = "cdsapi" }, { name = "click" }, + { name = "deprecated" }, { name = "earthengine-api" }, { name = "fasteners" }, { name = "frictionless", extra = ["pandas"] }, @@ -887,6 +888,7 @@ requires-dist = [ { name = "bugsnag", specifier = ">=4.2.1" }, { name = "cdsapi", specifier = ">=0.7.0" }, { name = "click", specifier = ">=8.0.1" }, + { name = "deprecated", specifier = ">=1.2.14" }, { name = "earthengine-api", specifier = ">=0.1.411" }, { name = "fastapi", marker = "extra == 'api'", specifier = ">=0.109.0" }, { name = "fasteners", specifier = ">=0.19" }, @@ -935,7 +937,7 @@ requires-dist = [ { name = "slack-sdk", marker = "extra == 'api'", specifier = ">=3.26.2" }, { name = "sparqlwrapper", specifier = ">=1.8.5" }, { name = "sqlalchemy", specifier = ">=2.0.30" }, - { name = "streamlit", marker = "extra == 'wizard'", specifier = ">=1.38.0" }, + { name = "streamlit", marker = "extra == 'wizard'", specifier = ">=1.39.0" }, { name = "streamlit-ace", marker = "extra == 'wizard'", specifier = ">=0.1.1" }, { name = "streamlit-aggrid", marker = "extra == 'wizard'", specifier = ">=0.3.4.post3" }, { name = "streamlit-agraph", marker = "extra == 'wizard'", specifier = ">=0.0.45" }, @@ -4923,7 +4925,7 @@ wheels = [ [[package]] name = "streamlit" -version = "1.38.0" +version = "1.39.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "altair" }, @@ -4946,9 +4948,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "watchdog", marker = "platform_system != 'Darwin'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/de/dc/c4e77f5855538e11d0675b8f5e8922cb7d4928d8f877988ba1c7c63f02f6/streamlit-1.38.0.tar.gz", hash = "sha256:c4bf36b3ef871499ed4594574834583113f93f077dd3035d516d295786f2ad63", size = 8360969 } +sdist = { url = "https://files.pythonhosted.org/packages/d5/21/3740871ad79ee35f442f11bafec5010a3ec1916c7c9eb43ef866da641f31/streamlit-1.39.0.tar.gz", hash = "sha256:fef9de7983c4ee65c08e85607d7ffccb56b00482b1041fa62f90e4815d39df3a", size = 8360694 } wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/2e/60e624bbe16f4baa45cb6e48a1ee05edd48a0a14cceec4d7eec9258755ac/streamlit-1.38.0-py2.py3-none-any.whl", hash = "sha256:0653ecfe86fef0f1608e3e082aef7eb335d8713f6f31e9c3b19486d1c67d7c41", size = 8741278 }, + { url = "https://files.pythonhosted.org/packages/ef/e1/f9c479f9dbe0bb702ea5ca6608f10e91a708b438f7fb4572a2642718c6e3/streamlit-1.39.0-py2.py3-none-any.whl", hash = "sha256:a359fc54ed568b35b055ff1d453c320735539ad12e264365a36458aef55a5fba", size = 8741335 }, ] [[package]]
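
Taken together, the new etl.grapher_io helpers and the Variable.from_id_or_path / get_data / get_metadata additions above provide a single entry point for loading an indicator by id or catalog path. A minimal usage sketch, assuming a configured OWID_ENV with database and S3 access; the catalog path below is hypothetical and only illustrates the expected path#indicator_short_name shape:

from sqlalchemy.orm import Session

from etl.config import OWID_ENV
from etl.grapher_io import load_variable, load_variable_data, load_variable_metadata
from etl.grapher_model import Variable

# Hypothetical catalog path; a real one must exist in the grapher database
# and contain "#<indicator_short_name>".
CATALOG_PATH = "grapher/demography/2024-07-15/population/population#population"

# High-level helpers: resolve the indicator, then pull its data and metadata from S3.
indicator = load_variable(CATALOG_PATH)
data = load_variable_data(variable=indicator)
metadata = load_variable_metadata(variable=indicator)

# Lower-level access with an explicit session; passing the session also adds
# entity names to the returned data.
with Session(OWID_ENV.engine) as session:
    same_indicator = Variable.from_id_or_path(session, CATALOG_PATH)
    df = same_indicator.get_data(session=session)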