Skip to content

Commit

Permalink
Merge pull request #63 from e2nIEE/fix/variant_creation
Browse files Browse the repository at this point in the history
[BREAKING / Fix] more robust variant creation, careful is a breaking change!
  • Loading branch information
vogt31337 authored Nov 7, 2024
2 parents b4a2b20 + 7e8addf commit c394abc
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Change Log

## [0.3.10]

- BREAKING: changed signatures of create_variant function and route

## [0.3.9]

- BREAKING: inserting multiple networks with the same name does not represent an error anymore, networks are only unique by their net_id (_id field of the collection)
Expand Down
20 changes: 15 additions & 5 deletions pandahub/api/routers/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,23 @@ def get_variants(data: GetVariantsModel, ph=Depends(pandahub)):

class CreateVariantModel(BaseModel):
project_id: str
variant_data: dict
net_id: int
name: str | None = None
default_name: str | None = None


class CreateVariantResponseModel(BaseModel):
net_id: int
index: int
name: str | None = None
date_created: int
date_changed: int


@router.post("/create_variant")
def create_variant(data: CreateVariantModel, ph=Depends(pandahub)):
project_id = data.project_id
ph.set_active_project_by_id(project_id)
return ph.create_variant(data.variant_data)
def create_variant(data: CreateVariantModel, ph=Depends(pandahub)) -> CreateVariantResponseModel:
ph.set_active_project_by_id(data.project_id)
return ph.create_variant(net_id=data.net_id, name=data.name, default_name=data.default_name)

class DeleteVariantModel(BaseModel):
project_id: str
Expand Down
70 changes: 45 additions & 25 deletions pandahub/lib/PandaHub.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import builtins
import json
import logging
import time
import warnings
from inspect import signature, _empty
from collections.abc import Callable
Expand Down Expand Up @@ -91,6 +92,7 @@ def __init__(
user_id=None,
datatypes=DATATYPES,
mongodb_indexes=MONGODB_INDEXES,
elements_without_vars = None,
):
mongo_client_args = {
"host": connection_url,
Expand All @@ -115,6 +117,7 @@ def __init__(
{"var_type": np.nan},
]
}
self._elements_without_vars = ["variant"] if elements_without_vars is None else elements_without_vars
if check_server_available:
self.server_is_available()

Expand Down Expand Up @@ -1150,10 +1153,13 @@ def write_network_to_db(
if element_data.empty:
continue
element_data = element_data.copy(deep=True)
if "var_type" in element_data:
element_data["var_type"] = element_data["var_type"].fillna("base")
else:
element_data["var_type"] = "base"
if element not in self._elements_without_vars:
if "var_type" in element_data:
element_data["var_type"] = element_data["var_type"].fillna("base")
else:
element_data["var_type"] = "base"
element_data["not_in_var"] = np.empty((len(element_data.index), 0)).tolist()
element_data["variant"] = None
element_data = convert_element_to_dict(element_data, net_id, self._datatypes.get(element))
self._write_element_to_db(db, element, element_data)

Expand Down Expand Up @@ -1228,9 +1234,12 @@ def _get_net_id_from_name(self, name, db):
def _network_with_name_exists(self, name, db):
return self._get_net_id_from_name(name, db) is not None

def _get_net_collections(self, db, with_areas=True):
return self.get_net_collections(db, with_areas)

def get_net_collections(self, db=None, with_areas=True):
if db is None:
db = self._get_project_database()
db = self.get_project_database()
if with_areas:
collection_filter = {"name": {"$regex": "^net_"}}
else:
Expand Down Expand Up @@ -1402,7 +1411,7 @@ def delete_elements(
"""
if not isinstance(element_indexes, list):
raise TypeError("Parameter element_indexes must be a list of ints!")
validate_variant_type(variant)
self.validate_variant(variant, element_type)
if project_id:
self.set_active_project_by_id(project_id)

Expand Down Expand Up @@ -1446,7 +1455,7 @@ def set_net_value_in_db(
project_id=None,
**kwargs,
):
validate_variant_type(variant)
self.validate_variant(variant, element_type)
if project_id:
self.set_active_project_by_id(project_id)
self.check_permission("write")
Expand Down Expand Up @@ -1514,7 +1523,7 @@ def set_object_attribute(
variant=None,
project_id=None,
):
validate_variant_type(variant)
self.validate_variant(variant, element_type)
if project_id:
self.set_active_project_by_id(project_id)
self.check_permission("write")
Expand Down Expand Up @@ -1636,7 +1645,7 @@ def create_elements(
list
A list of the created elements (elements_data with added _id fields)
"""
validate_variant_type(variant)
self.validate_variant(variant, element_type)
if project_id:
self.set_active_project_by_id(project_id)
self.check_permission("write")
Expand Down Expand Up @@ -1771,24 +1780,30 @@ def _get_int_index(self, collection: str, index_field: str = "_id", query_filter
# Variants
# -------------------------

def create_variant(self, data, index: Optional[int] = None):
def create_variant(
self,
net_id: int,
name: str | None = None,
default_name: str = "Variant",
index: int | None = None,
) -> dict:
db = self._get_project_database()
net_id = int(data["net_id"])
if index is None:
index = self._get_int_index("variant", index_field="index", query_filter={"net_id": net_id})
data["index"] = index
if data.get("default_name") is not None and data.get("name") is None:
data["name"] = data.pop("default_name") + " " + str(index)
db["variant"].insert_one(data)
del data["_id"]

if index == 1:
for coll in self.get_net_collections(db):
db[coll].update_many(
{"$or": [{"var_type": None}, {"var_type": np.nan}]},
{"$set": {"var_type": "base", "not_in_var": [], "variant": None}},
)
return data

if name is None:
name = f"{default_name} {index + 1}"
now = int(time.time())
variant = {
"net_id": net_id,
"index": index,
"name": name,
"date_created": now,
"date_changed": now,
}
db["variant"].insert_one(variant)
del variant["_id"]
return variant

def delete_variant(self, net_id, index):
db = self._get_project_database()
Expand Down Expand Up @@ -1828,12 +1843,17 @@ def get_variant_filter(self, variant: int | None) -> dict:
dict
mongodb query filter for the given variant
"""
validate_variant_type(variant)
self.validate_variant(variant)
if variant is None:
return self.base_variant_filter
return {"$or": [{"var_type": "base", "not_in_var": {"$ne": variant}},
{"var_type": {"$in": ["change", "addition"]}, "variant": variant}, ]}

def validate_variant(self, variant: int | None, element_type: str | None = None):
"""Raise a ValueError if variant is not int | None or element_type does not support variants."""
validate_variant_type(variant)
if element_type is not None and variant is not None and element_type in self._elements_without_vars:
raise ValueError(f"{element_type} does not support variants")

# -------------------------
# Bulk operations
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "pandahub"
version = "0.3.9" # File format version '__format_version__' is tracked in __init__.py
version = "0.3.10" # File format version '__format_version__' is tracked in __init__.py
authors=[
{ name = "Jan Ulffers", email = "[email protected]" },
{ name = "Leon Thurner", email = "[email protected]" },
Expand Down

0 comments on commit c394abc

Please sign in to comment.