diff --git a/.gitignore b/.gitignore index 1b6675b..d126c3c 100644 --- a/.gitignore +++ b/.gitignore @@ -107,6 +107,7 @@ ENV/ # downloaded files /bblocks/.raw_data/*.csv /bblocks/.raw_data/*.feather +/bblocks/.raw_data/*.json # Sphinx documentation docs/_build/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ab9b1e..76b618f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,11 @@ Changelog ========= +[1.2.0] - 2023-07-20 +-------------------- +- Added new feature: `world_bank_projects` module in `import_tools` with an object + to extract data from the World Bank Projects database. + [1.1.1] - 2023-07-06 -------------------- - Updated requirements diff --git a/bblocks/import_tools/world_bank_projects.py b/bblocks/import_tools/world_bank_projects.py new file mode 100644 index 0000000..45a88d5 --- /dev/null +++ b/bblocks/import_tools/world_bank_projects.py @@ -0,0 +1,562 @@ +"""World Bank Projects Database Importer""" + +import pandas as pd +import requests +import json +from dataclasses import dataclass +import re + +from bblocks.logger import logger +from bblocks.import_tools.common import ImportData +from bblocks.config import BBPaths +from bblocks.cleaning_tools import clean + + +class EmptyDataException(Exception): + """Exception raised when the API response does not contain any data.""" + + pass + + +BASE_API_URL = "https://search.worldbank.org/api/v2/projects" + + +class QueryAPI: + """Helper class for querying the World Bank Projects API""" + + def __init__( + self, + max_rows_per_response: int = 500, + start_date: str | None = None, + end_date: str | None = None, + fields: list[str] | str = "*", + ): + """Initialize QueryAPI object + + Args: + max_rows_per_response: maximum number of rows to return per request. + Must be less than or equal to 1000. + start_date: start date of projects to return. Format: YYYY-MM-DD + end_date: end date of projects to return. Format: YYYY-MM-DD + fields: fields to return. Can be a list of strings or a single string. 
+ By default, all fields are returned ('*'). + """ + + self.max_rows_per_response = max_rows_per_response + self.start_date = start_date + self.end_date = end_date + self.fields = list(set(fields)) if isinstance(fields, list) else [fields] + + self._params = { + "format": "json", + "rows": self.max_rows_per_response, + # 'os': 0, # offset + "strdate": self.start_date, + "enddate": self.end_date, + "fl": self.fields, + } + + self._check_params() + + self.response_data = {} # initialize response_data as empty dict + + def _check_params(self) -> None: + """Check parameters""" + + # if end_date is before start_date, raise error. + if self._params["strdate"] is not None and self._params["enddate"] is not None: + if self._params["enddate"] < self._params["strdate"]: + raise ValueError("end date must be after start date") + + # if max_rows is greater than 1000, raise error + if self._params["rows"] > 1000: + raise ValueError("max_rows must be less than or equal to 1000") + + # if dates are None, drop them from params + if self._params["strdate"] is None: + # drop start_date from params + self._params.pop("strdate") + + if self._params["enddate"] is None: + # drop end_date from params + self._params.pop("enddate") + + def _request(self) -> dict: + """Single request to API. Returns the response json.""" + + try: + response = requests.get(BASE_API_URL, params=self._params) + response.raise_for_status() + data = response.json()["projects"] # keep only the projects data + + return data + + except Exception as e: + raise Exception(f"Failed to get data: {e}") + + def request_data(self) -> "QueryAPI": + """Request data from API + + This method will request all the data from the API + and store it in the response_data attribute. + It will automatically determine the request to make + based on the offset and number of rows parameters. 
+ + Returns: + 'QueryAPI' to allow chaining of methods + """ + + self._params["os"] = 0 # reset offset to 0 + + while True: + # request data + data = self._request() + + # if there are no more projects, break + if len(data) == 0: + break + + # add data to response_data + self.response_data.update(data) + + # update offset + self._params["os"] += self._params["rows"] + + # Log if no data was returned from API + if len(self.response_data) == 0: + raise EmptyDataException("No data was returned from API") + + logger.info(f"Retrieved {len(self.response_data)} projects from API") + return self + + def get_data(self) -> dict[dict]: + """Get the data, or request it if it hasn't been requested yet.""" + + if len(self.response_data) == 0: + self.request_data() + + return self.response_data + + +def _append_theme_to_list( + proj_id: str, theme_list: list[dict], theme_names: list[str], theme: dict +) -> None: + """Appends a theme to the theme_list. + + Args: + proj_id: The project ID. + theme_list: The list of theme dictionaries to append to. + theme_names): The names of the parent themes. + theme: The theme to append. + """ + new_theme = { + "project ID": proj_id, + **{f"theme{idx + 1}": name for idx, name in enumerate(theme_names)}, + "percent": clean.clean_number(theme["percent"]), + } + theme_list.append(new_theme) + + +def _parse_themes( + proj_id: str, + theme_list: list[dict], + theme_names: list[str], + theme: dict, + theme_level: int, +) -> None: + """Recursive function to handle nested themes. + + Args: + proj_id (str): The project ID. + theme_list (list[dict]): The list of theme dictionaries to append to. + theme_names (list[str]): The names of the parent themes. + theme (dict): The current theme. + theme_level (int): The current level of theme nesting. 
+ """ + # Append the current theme to the list + _append_theme_to_list( + proj_id=proj_id, theme_list=theme_list, theme_names=theme_names, theme=theme + ) + + # Recursively call this function for each nested theme + nested_theme_key = f"theme{theme_level + 1}" + for nested_theme in theme.get(nested_theme_key, []): + _parse_themes( + proj_id=proj_id, + theme_list=theme_list, + theme_names=theme_names + [nested_theme["name"]], + theme=nested_theme, + theme_level=theme_level + 1, + ) + + +def clean_theme(data: dict) -> list[dict] | list: + """Clean theme data from a nested list to a list of dictionaries with theme names and + percentages. + If there are no themes, an empty list will be returned. + + Args: + data: data from API + + Returns: + list of dictionaries with theme names and percentages + """ + + # if there are no themes, return an empty list + if "theme_list" not in data: + # return [{'project ID': proj_id}] + return [] + + theme_list = [] + proj_id = data["id"] + for theme1 in data["theme_list"]: + _parse_themes(proj_id, theme_list, [theme1["name"]], theme1, 1) + + return theme_list + + +def _get_sector_data(d: dict) -> dict: + """Get sector percentages from a project dictionary + + the function first finds all available sectors + It then finds all fields from the json starting with 'sector' and ending with a number + and gets a dictionary of the sector name and percentage + If there are any sectors missing from the dict and the total percentage is less than 100 + the missing sector is added with the remaining percentage. + If the total is still not 100, it will raise an error to indicate a + problem with the data. 
+ + args: + d: project dictionary + """ + + sectors_dict = {} # empty dict to store sector data as {sector_name: percent} + sector_names = [v["Name"] for v in d["sector"]] # get list of sector names + + # get sectors fields which should contain percentages + sectors = {key: value for key, value in d.items() if re.search(r"^sector\d+$", key)} + + # get available sector percentages + for _, v in sectors.items(): + if isinstance(v, dict): + sectors_dict[v["Name"]] = v["Percent"] + + # check if there are missing sectors from the dict + if (len(sectors_dict) == len(sectors) - 1) and (sum(sectors_dict.values()) < 100): + # loop through all the available sectors + for s in sector_names: + # if a sectors has not been picked up it must be the missing sector + if s not in sectors_dict: + sectors_dict[s] = 100 - sum(sectors_dict.values()) + + if sum(sectors_dict.values()) != 100: + raise ValueError("Sector percentages don't add up to 100%") + + return sectors_dict + + +GENERAL_FIELDS = { # general info + "id": "project ID", + "project_name": "project name", + "countryshortname": "country", + "regionname": "region name", + "url": "url", + "teamleadname": "team leader", + "status": "status", + "last_stage_reached_name": "last stage reached", + "pdo": "project development objective", + "cons_serv_reqd_ind": "consulting services required", + "envassesmentcategorycode": "environmental assesment category", + "esrc_ovrl_risk_rate": "environmental and social risk", + "transactiontype:": "transaction type", + "financier_loan": "financier loan", + "interestandcharges": "interest and charges", + # dates + "approvalfy": "fiscal year", + "boardapprovaldate": "board approval date", + "closingdate": "closing date", + "p2a_updated_date": "update date", + # lending + "lendinginstr": "lending instrument", + "projectfinancialtype": "financing type", + "loantype": "loan type", + "loantypedesc": "loan type description", + "borrower": "borrower", + "impagency": "implementing agency", + 
"lendprojectcost": "project cost", + "totalcommamt": "total commitment", + "grantamt": "grant amount", + "idacommamt": "IDA commitment amount", + "ibrdcommamt": "IBRD commitment amount", + "curr_project_cost": "current project cost", + "curr_total_commitment": "current total IBRD and IDA commitment", + "curr_ibrd_commitment": "current IBRD commitment", + "curr_ida_commitment": "current IDA commitment", + "repayment": "repayment", +} + +OTHER_FIELDS = { + "projectstatusdisplay": "project status display", + "sector1": "sector1", + "sector2": "sector2", + "sector3": "sector3", + "sector4": "sector4", + "sector5": "sector5", + "sector6": "sector6", + "sector7": "sector7", + "sector8": "sector8", + "sector": "sector", + "theme1": "theme1", + "theme2": "theme2", + "theme3": "theme3", + "theme4": "theme4", + "theme5": "theme5", + "fiscal_year": "fiscal year", + "fiscalyear": "fiscal year", + "fiscalyear_budget": "fiscal year budget", + "project_abstract": "project abstract", + "sectorlist": "sectorlist", + "theme_list": "theme_list", +} + + +@dataclass +class WorldBankProjects(ImportData): + """World Bank Projects Database Importer + + This object will import the World Bank Projects database from the World Bank API. + To use, create an instance of the class. Optionally, you can specify the start and end dates + of the data to import. If no dates are specified, all data will be imported. + To import the data, call the load_data method. If the data has already downloaded, it will + be loaded to the object from disk, otherwise it will be downloaded from the API. + To retrieve the data, call the get_data method. You can specify the type of data to retrieve, + either 'general' or 'theme'. If no type is specified, 'general' data will be returned. + To update the data, call the update_data method. This will download the data from the API. + If 'reload' is set to True, the data will be reloaded to the object. 
@dataclass
class WorldBankProjects(ImportData):
    """World Bank Projects Database Importer

    This object will import the World Bank Projects database from the World Bank API.
    To use, create an instance of the class. Optionally, you can specify the start and end dates
    of the data to import. If no dates are specified, all data will be imported.
    To import the data, call the load_data method. If the data has already been downloaded, it
    will be loaded to the object from disk, otherwise it will be downloaded from the API.
    To retrieve the data, call the get_data method. You can specify the type of data to retrieve,
    either 'general', 'theme' or 'sector'. If no type is specified, 'general' data will be
    returned.
    To update the data, call the update_data method. This will download the data from the API.
    If 'reload' is set to True, the data will be reloaded to the object.

    Parameters:
        start_date: start date of data to import, in YYYY-MM-DD format
        end_date: end date of data to import, in YYYY-MM-DD format.
    """

    start_date: str | None = None
    end_date: str | None = None

    @property
    def _path(self):
        """Path of the json file on disk. The file name includes the dates, if set."""

        start_date = f"_{self.start_date}" if self.start_date is not None else ""
        end_date = f"_{self.end_date}" if self.end_date is not None else ""

        return BBPaths.raw_data / f"world_bank_projects{start_date}{end_date}.json"

    @staticmethod
    def _as_field_list(additional_fields: str | list | None) -> list:
        """Return the additional fields as a (new) list: None -> [], 'a' -> ['a']."""

        if additional_fields is None:
            return []
        if isinstance(additional_fields, str):
            return [additional_fields]
        return list(additional_fields)

    def _format_general_data(self, additional_fields: list | None = None) -> None:
        """Clean and format general data and store it in _data attribute with key 'general_data'

        Args:
            additional_fields: fields to keep in addition to the ones in GENERAL_FIELDS.
                They keep their original names.
        """

        additional_fields = self._as_field_list(additional_fields)

        numeric_cols = [
            "lendprojectcost",
            "totalcommamt",
            "grantamt",
            "idacommamt",
            "ibrdcommamt",
            "curr_total_commitment",
            "curr_ibrd_commitment",
            "curr_ida_commitment",
            "curr_project_cost",
        ]

        self._data["general_data"] = (
            pd.DataFrame.from_dict(self._raw_data, orient="index")
            .reset_index(drop=True)
            # keep only the requested fields (fields not in the data are ignored)
            .filter(list(GENERAL_FIELDS) + additional_fields, axis=1)
            # change the fiscal year to int
            .assign(
                approvalfy=lambda d: clean.clean_numeric_series(d["approvalfy"], to=int)
            )
            # change numeric columns to float
            .pipe(clean.clean_numeric_series, series_columns=numeric_cols)
            .assign(  # format dates
                boardapprovaldate=lambda d: clean.to_date_column(
                    d["boardapprovaldate"]
                ),
                closingdate=lambda d: clean.to_date_column(d["closingdate"]),
                p2a_updated_date=lambda d: clean.to_date_column(d["p2a_updated_date"]),
            )
            # rename columns
            .rename(columns=GENERAL_FIELDS)
        )

    def _format_theme_data(self) -> None:
        """Format theme data and store it as a dataframe in _data attribute with key 'theme_data'"""

        theme_data = []
        for proj_data in self._raw_data.values():
            theme_data.extend(clean_theme(proj_data))

        frame = pd.DataFrame(theme_data)
        if frame.empty:
            # keep the identifying columns so that get_data can still filter by project
            frame = pd.DataFrame(columns=["project ID", "percent"])

        self._data["theme_data"] = frame.filter(
            ["project ID", "theme1", "theme2", "theme3", "theme4", "theme5", "percent"],
            axis=1,
        )

    def _format_sector_data(self) -> None:
        """Format sector data and store it as a dataframe in _data attribute
        with key 'sector_data'"""

        sector_data = []
        for proj_data in self._raw_data.values():
            if "sector" in proj_data:
                proj_id = proj_data["id"]

                sectors = _get_sector_data(proj_data)
                sector_data.extend(
                    {"project ID": proj_id, "sector": s, "percent": p}
                    for s, p in sectors.items()
                )

        # explicit columns so that the dataframe has them even if there is no sector data
        self._data["sector_data"] = pd.DataFrame(
            sector_data, columns=["project ID", "sector", "percent"]
        )

    def _download(self, additional_fields: str | list | None = None) -> None:
        """Download data from World Bank Projects API and save it as a json file.

        The file is only written once the data has been retrieved, so a failed
        request does not leave an empty file behind.
        """

        logger.info("Starting download of World Bank Projects")

        data = (
            QueryAPI(
                start_date=self.start_date,
                end_date=self.end_date,
                fields=list(GENERAL_FIELDS)
                + list(OTHER_FIELDS)
                + self._as_field_list(additional_fields),
            )
            .request_data()
            .get_data()
        )

        with open(self._path, "w", encoding="utf-8") as file:
            json.dump(data, file)

        logger.info("Successfully downloaded World Bank Projects")

    def _load_from_disk(self, additional_fields: list) -> ImportData:
        """Read the json file from disk and format the general, theme and sector data."""

        with open(self._path, "r", encoding="utf-8") as file:
            self._raw_data = json.load(file)

        # covers both a json 'null' and an empty object
        if not self._raw_data:
            raise EmptyDataException("No data was retrieved")

        # set data
        self._format_general_data(additional_fields=additional_fields)
        self._format_theme_data()
        self._format_sector_data()

        logger.info("Successfully loaded World Bank Projects")
        return self

    def load_data(self, *, additional_fields: str | list = None) -> ImportData:
        """Load data to the object

        This method will load the World Bank Project data to the object.
        If the data has already been downloaded, it will be loaded to the object from disk,
        otherwise it will be downloaded from the API and saved as a json file and loaded
        to the object.

        Args:
            additional_fields: additional fields to download from the API. If the data has
                already been downloaded, the additional fields may not be loaded if they do not
                exist in the downloaded file. To force download of data with additional fields,
                use the update_data method passing the additional fields as argument

        Returns:
            object with loaded data

        Raises:
            EmptyDataException: if the file on disk does not contain any data.
        """

        if self._path.exists():
            # if additional fields are set but the data is read from disk, log a warning
            if additional_fields is not None:
                logger.warning(
                    "Data already exists in disk. The additional fields might not be "
                    "loaded if they do not exist in the downloaded data. To force download "
                    "of data with additional fields, use the update_data method passing the "
                    "additional fields as argument"
                )
        else:
            # if file does not exist, download it and save it as a json file
            self._download(additional_fields=additional_fields)

        return self._load_from_disk(self._as_field_list(additional_fields))

    def update_data(
        self, reload: bool = True, *, additional_fields: str | list = None
    ) -> ImportData:
        """Force update of data

        This method will download the data from the API.
        If 'reload' is set to True, the data will be reloaded to the object.

        Args:
            reload: if True, reload data to object after downloading it.
            additional_fields: additional fields to download

        Returns:
            object with updated data
        """

        self._download(additional_fields=additional_fields)
        if reload:
            # the file has just been refreshed with the additional fields, so it is
            # read directly (load_data would warn that the fields may be missing)
            self._load_from_disk(self._as_field_list(additional_fields))

        return self

    def get_data(
        self, project_codes: str | list = "all", data_type: str = "general", **kwargs
    ) -> pd.DataFrame:
        """Get the data as a dataframe

        Get the general, theme or sector data for World Bank Projects as a dataframe.
        Optionally, you can specify the project codes to retrieve data for. If no project codes
        are specified, data for all projects will be returned.

        Args:
            project_codes: project codes to retrieve data for. If 'all', data for all projects
                will be returned
            data_type: type of data to retrieve. Either 'general', 'sector' or 'theme'
            **kwargs: not used.

        Returns:
            dataframe with the requested data

        Raises:
            EmptyDataException: if no data has been loaded to the object.
            ValueError: if data_type is not one of the available types.
        """

        # check if data has been loaded
        if not self._data:
            raise EmptyDataException("Data has not been loaded. Run load_data() first.")

        data_keys = {
            "general": "general_data",
            "theme": "theme_data",
            "sector": "sector_data",
        }
        if data_type not in data_keys:
            raise ValueError("data_type must be either 'general', 'theme' or 'sector'")

        df = self._data[data_keys[data_type]]

        if isinstance(project_codes, str):
            if project_codes == "all":
                return df
            project_codes = [project_codes]

        return df[df["project ID"].isin(project_codes)]

    def get_json(self) -> dict:
        """Return the raw data as a dictionary"""

        return self._raw_data
class TestQueryAPI:
    """Test QueryAPI class."""

    # `requests.get` is patched through the module under test in every test
    REQUESTS_GET = "bblocks.import_tools.world_bank_projects.requests.get"

    def test_init(self):
        """Test initialization of QueryAPI object."""

        # test that error is raised if end_date is before start_date
        with pytest.raises(ValueError):
            world_bank_projects.QueryAPI(start_date="2020-01-01", end_date="2019-01-01")

        # test that error is raised if max_rows_per_response is greater than 1000
        with pytest.raises(ValueError):
            world_bank_projects.QueryAPI(max_rows_per_response=1001)

        # test that the start date parameter is dropped if start_date is None
        assert (
            "strdate"
            not in world_bank_projects.QueryAPI(
                end_date="2020-01-01", start_date=None
            )._params
        )

        # test that the end date parameter is dropped if end_date is None
        assert (
            "enddate"
            not in world_bank_projects.QueryAPI(
                start_date="2020-01-01", end_date=None
            )._params
        )

    def test_request(self):
        """Test that a successful request returns only the projects data."""

        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "projects": {"P1234": {"name": "Test Project"}}
        }

        with patch(self.REQUESTS_GET, return_value=mock_response) as mock_get:
            assert world_bank_projects.QueryAPI()._request() == {
                "P1234": {"name": "Test Project"}
            }

        mock_get.assert_called_once()

    def test_request_error(self):
        """Test that error is raised if request fails."""

        with patch(self.REQUESTS_GET) as mock_get:
            mock_get.return_value.raise_for_status.side_effect = (
                requests.exceptions.HTTPError
            )
            mock_get.return_value.status_code = 404
            # the payload is set on the response (not on `get` itself): the
            # request must fail even though valid data would be available
            mock_get.return_value.json.return_value = {
                "projects": {"P1234": {"name": "Test Project"}}
            }

            with pytest.raises(Exception):
                world_bank_projects.QueryAPI()._request()

    def test_request_data_no_data(self):
        """Test that request_data raises an error if the API returns no projects."""

        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "projects": {}
        }  # test that empty response is handled

        with patch(self.REQUESTS_GET, return_value=mock_response):
            obj = world_bank_projects.QueryAPI()
            with pytest.raises(world_bank_projects.EmptyDataException):
                obj.request_data()

    def test_request_data(self):
        """Test that request_data merges all the pages returned by the API."""

        # three consecutive responses: two pages with data and an empty one,
        # which signals that there are no more projects
        pages = [
            {
                "projects": {
                    "P1": {"name": "Test Project 1"},
                    "P2": {"name": "Test Project 2"},
                }
            },
            {"projects": {"P3": {"name": "Test Project 3"}}},
            {"projects": {}},
        ]
        mocked_get = MagicMock(
            side_effect=[Mock(json=MagicMock(return_value=page)) for page in pages]
        )

        with patch(self.REQUESTS_GET, mocked_get):
            obj = world_bank_projects.QueryAPI()
            obj.request_data()

        assert mocked_get.call_count == len(pages)
        assert obj.response_data == {
            "P1": {"name": "Test Project 1"},
            "P2": {"name": "Test Project 2"},
            "P3": {"name": "Test Project 3"},
        }
def test_clean_theme():
    """Test that clean_theme flattens nested themes depth first."""

    def make_theme(name, code, seqno, percent, **nested):
        """Build a theme as returned by the API; nested themes go in `nested`."""
        return {
            "name": name,
            "code": code,
            "seqno": seqno,
            "percent": percent,
            **nested,
        }

    root = "Environment and Natural Resource Management"
    health = "Environmental Health and Pollution Management"

    project = {
        "id": "P1234",
        "theme_list": [
            make_theme(
                root,
                "80",
                "14",
                "34",
                theme2=[
                    make_theme(
                        "Energy",
                        "86",
                        "18",
                        "13",
                        theme3=[
                            make_theme("Energy Efficiency", "861", "34", "13"),
                            make_theme("Energy Policies & Reform", "862", "35", "13"),
                        ],
                    ),
                    make_theme(
                        "Environmental policies and institutions", "84", "17", "13"
                    ),
                    make_theme(
                        health,
                        "82",
                        "16",
                        "13",
                        theme3=[
                            make_theme("Air quality management", "821", "33", "13")
                        ],
                    ),
                    make_theme(
                        "Climate change",
                        "81",
                        "15",
                        "34",
                        theme3=[
                            make_theme("Adaptation", "812", "32", "8"),
                            make_theme("Mitigation", "811", "31", "26"),
                        ],
                    ),
                ],
            )
        ],
    }

    # (theme names from level 1 downwards, percent), in the expected output order
    expected_rows = [
        ((root,), 34),
        ((root, "Energy"), 13),
        ((root, "Energy", "Energy Efficiency"), 13),
        ((root, "Energy", "Energy Policies & Reform"), 13),
        ((root, "Environmental policies and institutions"), 13),
        ((root, health), 13),
        ((root, health, "Air quality management"), 13),
        ((root, "Climate change"), 34),
        ((root, "Climate change", "Adaptation"), 8),
        ((root, "Climate change", "Mitigation"), 26),
    ]
    expected = [
        {
            "project ID": "P1234",
            **{f"theme{level}": name for level, name in enumerate(names, start=1)},
            "percent": percent,
        }
        for names, percent in expected_rows
    ]

    assert world_bank_projects.clean_theme(project) == expected


def test_clean_theme_no_theme():
    """Test that clean_theme returns an empty list for a project without themes."""

    assert world_bank_projects.clean_theme({"id": "P1234"}) == []


def test_get_sector_data():
    """Test _get_sector_data when every sector has a percentage."""

    agriculture = "Agriculture, fishing, and forestry"
    extension = "Agricultural extension and research"

    project = {
        "id": "P1",
        "sector": [
            {"Name": agriculture, "code": "BX"},
            {"Name": extension, "code": "AX"},
        ],
        "sector1": {"Name": agriculture, "Percent": 50},
        "sector2": {"Name": extension, "Percent": 50},
    }

    result = world_bank_projects._get_sector_data(project)

    assert result == {agriculture: 50, extension: 50}


def test_get_sector_data_missing_sector():
    """Test that a sector without percentage gets the remaining percentage."""

    agriculture = "Agriculture, fishing, and forestry"
    extension = "Agricultural extension and research"

    project = {
        "id": "P2",
        "sector": [
            {"Name": agriculture, "code": "BX"},
            {"Name": extension, "code": "AX"},
            {"Name": "Missing sector", "code": "XX"},
        ],
        "sector1": {"Name": agriculture, "Percent": 40},
        "sector2": {"Name": extension, "Percent": 50},
        # only the name is available for the third sector
        "sector3": "Missing sector",
    }

    result = world_bank_projects._get_sector_data(project)

    assert result == {agriculture: 40, extension: 50, "Missing sector": 10}


def test_get_data_no_data_loaded():
    """Test that get_data raises an error if no data has been loaded."""

    with pytest.raises(world_bank_projects.EmptyDataException):
        world_bank_projects.WorldBankProjects().get_data()