From 8ed20c0780575b09cedc31e647614f5fe36d5041 Mon Sep 17 00:00:00 2001 From: Garrett Coakley Date: Mon, 5 Aug 2024 16:11:22 +0100 Subject: [PATCH] Add type hints to methods. Add tests for invalid values. --- README.md | 1 - django_countries_regions/__init__.py | 42 ++++++++++++++++++-------- tests/test_django_countries_regions.py | 24 +++++++++++++++ 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index d51e525..68fa235 100644 --- a/README.md +++ b/README.md @@ -43,7 +43,6 @@ In [8]: regions.countries_by_subregion('053') Out[8]: ['AU', 'NZ', 'NF'] ``` - ## Development To contribute to this library, first checkout the code. Then create a new virtual environment: diff --git a/django_countries_regions/__init__.py b/django_countries_regions/__init__.py index 8935048..7e038fb 100644 --- a/django_countries_regions/__init__.py +++ b/django_countries_regions/__init__.py @@ -2,7 +2,7 @@ from django_countries_regions.regions import REGIONS, SUBREGIONS -def country_region(country_code, region=True): +def country_region(country_code: str, region: bool = True) -> str | None: """Return a UN M49 region code for a country. Extends django-countries by adding a .region method to the Country field. @@ -12,10 +12,13 @@ def country_region(country_code, region=True): :return: String. UN M49 region code. """ code_to_use = "un_region_code" if region else "un_subregion_code" - return COUNTRY_REGIONS[country_code][code_to_use] + try: + return COUNTRY_REGIONS[country_code][code_to_use] + except KeyError: + return None -def country_subregion(country_code): +def country_subregion(country_code: str) -> str | None: """Return a UN M49 sub-region code for a country. Extends django-countries by adding a .subregion method to the Country field. @@ -23,7 +26,10 @@ def country_subregion(country_code): :param country_code: Two-letter ISO country code. :return: String. UN M49 sub-region code. """ - return country_region(country_code, region=False) + try: + return country_region(country_code, region=False) + except KeyError: + return None class Regions(): @@ -32,7 +38,7 @@ class Regions(): that region. """ - def countries_by_region(self, region_code, region=True): + def countries_by_region(self, region_code: str, region: bool = True) -> str | None: """Return a list of country codes found within a region. :param region_code: UN M49 region code. @@ -41,39 +47,51 @@ def countries_by_region(self, region_code, region=True): """ if region_code: if region: - return REGIONS[region_code]["countries"] + try: + return REGIONS[region_code]["countries"] + except KeyError: + return None else: return SUBREGIONS[region_code]["countries"] return None - def countries_by_subregion(self, region_code): + def countries_by_subregion(self, region_code: str) -> str | None: """Return a list of country codes found within a sub-region :param region_code: UN M49 sub-region code. :return: List of two-letter ISO country codes. """ if region_code: - return self.countries_by_region(region_code, region=False) + try: + return self.countries_by_region(region_code, region=False) + except KeyError: + return None return None - def region_name(self, region_code): + def region_name(self, region_code: str) -> str | None: """Return the region name :param region_code: UN M49 region code. :return: String """ if region_code: - return REGIONS[region_code]["name"] + try: + return REGIONS[region_code]["name"] + except KeyError: + return None return None - def subregion_name(self, region_code): + def subregion_name(self, region_code: str) -> str | None: """Return the region name :param region_code: UN M49 sub-region code. :return: String """ if region_code: - return SUBREGIONS[region_code]["name"] + try: + return SUBREGIONS[region_code]["name"] + except KeyError: + return None return None diff --git a/tests/test_django_countries_regions.py b/tests/test_django_countries_regions.py index 2df5f21..d2ccea4 100644 --- a/tests/test_django_countries_regions.py +++ b/tests/test_django_countries_regions.py @@ -20,6 +20,14 @@ def test_country_subregion(self): query = Country("AF").subregion self.assertEqual(query, "034") + def test_invalid_country_region(self): + query = Country("ZZ").region + self.assertIsNone(query) + + def test_invalid_country_subregion(self): + query = Country("ZZ").subregion + self.assertIsNone(query) + class TestRegions(TestCase): def test_countries_by_region(self): @@ -37,3 +45,19 @@ def test_region_name(self): def test_subregion_name(self): query = regions.subregion_name("053") self.assertEqual(query, "Australia and New Zealand") + + def test_invalid_countries_by_region(self): + query = regions.countries_by_region("900") + self.assertIsNone(query) + + def test_invalid_countries_by_subregion(self): + query = regions.countries_by_subregion("999") + self.assertIsNone(query) + + def test_invalid_region_name(self): + query = regions.region_name("900") + self.assertIsNone(query) + + def test_invalid_subregion_name(self): + query = regions.subregion_name("999") + self.assertIsNone(query)