Skip to content

Commit

Permalink
Add type hints to methods. Add tests for invalid values.
Browse files Browse the repository at this point in the history
  • Loading branch information
garrettc committed Aug 5, 2024
1 parent 3ca03e3 commit 8ed20c0
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 13 deletions.
1 change: 0 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ In [8]: regions.countries_by_subregion('053')
Out[8]: ['AU', 'NZ', 'NF']
```


## Development

To contribute to this library, first checkout the code. Then create a new virtual environment:
Expand Down
42 changes: 30 additions & 12 deletions django_countries_regions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django_countries_regions.regions import REGIONS, SUBREGIONS


def country_region(country_code, region=True):
def country_region(country_code: str, region: bool = True) -> str | None:
"""Return a UN M49 region code for a country.
Extends django-countries by adding a .region method to the Country field.
Expand All @@ -12,18 +12,24 @@ def country_region(country_code, region=True):
:return: String. UN M49 region code.
"""
code_to_use = "un_region_code" if region else "un_subregion_code"
return COUNTRY_REGIONS[country_code][code_to_use]
try:
return COUNTRY_REGIONS[country_code][code_to_use]
except KeyError:
return None


def country_subregion(country_code):
def country_subregion(country_code: str) -> str | None:
"""Return a UN M49 sub-region code for a country.
Extends django-countries by adding a .subregion method to the Country field.
:param country_code: Two-letter ISO country code.
:return: String. UN M49 sub-region code.
"""
return country_region(country_code, region=False)
try:
return country_region(country_code, region=False)
except KeyError:
return None


class Regions():
Expand All @@ -32,7 +38,7 @@ class Regions():
that region.
"""

def countries_by_region(self, region_code, region=True):
def countries_by_region(self, region_code: str, region: bool = True) -> str | None:
"""Return a list of country codes found within a region.
:param region_code: UN M49 region code.
Expand All @@ -41,39 +47,51 @@ def countries_by_region(self, region_code, region=True):
"""
if region_code:
if region:
return REGIONS[region_code]["countries"]
try:
return REGIONS[region_code]["countries"]
except KeyError:
return None
else:
return SUBREGIONS[region_code]["countries"]
return None

def countries_by_subregion(self, region_code):
def countries_by_subregion(self, region_code: str) -> str | None:
"""Return a list of country codes found within a sub-region
:param region_code: UN M49 sub-region code.
:return: List of two-letter ISO country codes.
"""
if region_code:
return self.countries_by_region(region_code, region=False)
try:
return self.countries_by_region(region_code, region=False)
except KeyError:
return None
return None

def region_name(self, region_code):
def region_name(self, region_code: str) -> str | None:
"""Return the region name
:param region_code: UN M49 region code.
:return: String
"""
if region_code:
return REGIONS[region_code]["name"]
try:
return REGIONS[region_code]["name"]
except KeyError:
return None
return None

def subregion_name(self, region_code):
def subregion_name(self, region_code: str) -> str | None:
"""Return the region name
:param region_code: UN M49 sub-region code.
:return: String
"""
if region_code:
return SUBREGIONS[region_code]["name"]
try:
return SUBREGIONS[region_code]["name"]
except KeyError:
return None
return None


Expand Down
24 changes: 24 additions & 0 deletions tests/test_django_countries_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def test_country_subregion(self):
query = Country("AF").subregion
self.assertEqual(query, "034")

def test_invalid_country_region(self):
query = Country("ZZ").region
self.assertIsNone(query)

def test_invalid_country_subregion(self):
query = Country("ZZ").subregion
self.assertIsNone(query)


class TestRegions(TestCase):
def test_countries_by_region(self):
Expand All @@ -37,3 +45,19 @@ def test_region_name(self):
def test_subregion_name(self):
query = regions.subregion_name("053")
self.assertEqual(query, "Australia and New Zealand")

def test_invalid_countries_by_region(self):
query = regions.countries_by_region("900")
self.assertIsNone(query)

def test_invalid_countries_by_subregion(self):
query = regions.countries_by_subregion("999")
self.assertIsNone(query)

def test_invalid_region_name(self):
query = regions.region_name("900")
self.assertIsNone(query)

def test_invalid_subregion_name(self):
query = regions.subregion_name("999")
self.assertIsNone(query)

0 comments on commit 8ed20c0

Please sign in to comment.