Skip to content

Commit

Permalink
Merge pull request #1473 from dmach/nested-models
Browse files Browse the repository at this point in the history
Support nested models + related fixes
  • Loading branch information
dmach authored Jan 23, 2024
2 parents 0e6117a + 7903ade commit 0103634
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 9 deletions.
62 changes: 53 additions & 9 deletions osc/util/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import get_type_hints

# supported types
from enum import Enum
from typing import Any
from typing import Dict
from typing import List
Expand Down Expand Up @@ -40,6 +41,7 @@ def get_origin(typ):
"Field",
"NotSet",
"FromParent",
"Enum",
"Dict",
"List",
"NewType",
Expand Down Expand Up @@ -125,9 +127,26 @@ def origin_type(self):
origin_type = get_origin(self.type) or self.type
if self.is_optional:
types = [i for i in self.type.__args__ if i != type(None)]
return types[0]
return get_origin(types[0]) or types[0]
return origin_type

@property
def inner_type(self):
if self.is_optional:
types = [i for i in self.type.__args__ if i != type(None)]
type_ = types[0]
else:
type_ = self.type

if get_origin(type_) != list:
return None

if not hasattr(type_, "__args__"):
return None

inner_type = [i for i in type_.__args__ if i != type(None)][0]
return inner_type

@property
def is_optional(self):
origin_type = get_origin(self.type) or self.type
Expand All @@ -137,6 +156,10 @@ def is_optional(self):
def is_model(self):
return inspect.isclass(self.origin_type) and issubclass(self.origin_type, BaseModel)

@property
def is_model_list(self):
return inspect.isclass(self.inner_type) and issubclass(self.inner_type, BaseModel)

def validate_type(self, value, expected_types=None):
if not expected_types and self.is_optional and value is None:
return True
Expand Down Expand Up @@ -176,6 +199,15 @@ def validate_type(self, value, expected_types=None):
valid_type = True
continue

if (
inspect.isclass(expected_type)
and issubclass(expected_type, Enum)
):
# test if the value is part of the enum
expected_type(value)
valid_type = True
continue

if not isinstance(value, origin_type):
msg = f"Field '{self.name}' has type '{self.type}'. Cannot assign a value with type '{type(value).__name__}'."
raise TypeError(msg)
Expand Down Expand Up @@ -241,9 +273,17 @@ def get(self, obj):
def set(self, obj, value):
# if this is a model field, convert dict to a model instance
if self.is_model and isinstance(value, dict):
new_value = self.origin_type() # pylint: disable=not-callable
for k, v in value.items():
setattr(new_value, k, v)
# initialize a model instance from a dictionary
klass = self.origin_type
value = klass(**value) # pylint: disable=not-callable
elif self.is_model_list and isinstance(value, list):
new_value = []
for i in value:
if isinstance(i, dict):
klass = self.inner_type
new_value.append(klass(**i))
else:
new_value.append(i)
value = new_value

self.validate_type(value)
Expand Down Expand Up @@ -311,12 +351,12 @@ def __init__(self, **kwargs):

if kwargs:
unknown_fields_str = ", ".join([f"'{i}'" for i in kwargs])
raise TypeError(f"The following kwargs do not match any field: {unknown_fields_str}")
raise TypeError(f"The following kwargs of '{self.__class__.__name__}.__init__()' do not match any field: {unknown_fields_str}")

if uninitialized_fields:
uninitialized_fields_str = ", ".join([f"'{i}'" for i in uninitialized_fields])
raise TypeError(
f"The following fields are not initialized and have no default either: {uninitialized_fields_str}"
f"The following fields of '{self.__class__.__name__}' object are not initialized and have no default either: {uninitialized_fields_str}"
)

for name, field in self.__fields__.items():
Expand All @@ -329,8 +369,12 @@ def dict(self):
for name, field in self.__fields__.items():
if field.exclude:
continue
if field.is_model:
result[name] = getattr(self, name).dict()
value = getattr(self, name)
if value is not None and field.is_model:
result[name] = value.dict()
if value is not None and field.is_model_list:
result[name] = [i.dict() for i in value]
else:
result[name] = getattr(self, name)
result[name] = value

return result
77 changes: 77 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,14 +186,91 @@ class TestModel(BaseModel):
self.assertEqual(field.is_optional, True)
self.assertEqual(field.origin_type, TestSubmodel)
self.assertEqual(m.field, None)
m.dict()

m = TestModel(field=TestSubmodel())
self.assertIsInstance(m.field, TestSubmodel)
self.assertEqual(m.field.text, "default")
m.dict()

m = TestModel(field={"text": "text"})
self.assertNotEqual(m.field, None)
self.assertEqual(m.field.text, "text")
m.dict()

def test_list_submodels(self):
class TestSubmodel(BaseModel):
text: str = Field(default="default")

class TestModel(BaseModel):
field: List[TestSubmodel] = Field(default=[])

m = TestModel()

field = m.__fields__["field"]
self.assertEqual(field.is_model, False)
self.assertEqual(field.is_model_list, True)
self.assertEqual(field.is_optional, False)
self.assertEqual(field.origin_type, list)
m.dict()

m = TestModel(field=[TestSubmodel()])
self.assertEqual(m.field[0].text, "default")
m.dict()

m = TestModel(field=[{"text": "text"}])
self.assertEqual(m.field[0].text, "text")
m.dict()

self.assertRaises(TypeError, getattr(m, "field"))

def test_optional_list_submodels(self):
class TestSubmodel(BaseModel):
text: str = Field(default="default")

class TestModel(BaseModel):
field: Optional[List[TestSubmodel]] = Field(default=[])

m = TestModel()

field = m.__fields__["field"]
self.assertEqual(field.is_model, False)
self.assertEqual(field.is_model_list, True)
self.assertEqual(field.is_optional, True)
self.assertEqual(field.origin_type, list)
m.dict()

m = TestModel(field=[TestSubmodel()])
self.assertEqual(m.field[0].text, "default")
m.dict()

m = TestModel(field=[{"text": "text"}])
self.assertEqual(m.field[0].text, "text")
m.dict()

m.field = None
self.assertEqual(m.field, None)
m.dict()

def test_enum(self):
class Numbers(Enum):
one = "one"
two = "two"

class TestModel(BaseModel):
field: Optional[Numbers] = Field(default=None)

m = TestModel()
field = m.__fields__["field"]
self.assertEqual(field.is_model, False)
self.assertEqual(field.is_optional, True)
self.assertEqual(field.origin_type, Numbers)
self.assertEqual(m.field, None)

m.field = "one"
self.assertEqual(m.field, "one")

self.assertRaises(ValueError, setattr, m, "field", "does-not-exist")

def test_parent(self):
class ParentModel(BaseModel):
Expand Down

0 comments on commit 0103634

Please sign in to comment.