diff --git a/CHANGELOG.md b/CHANGELOG.md index 69f482f..79e6b66 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,9 @@ # Change Log +## [0.3.10] + + - BREAKING: changed signatures of create_variant function and route + ## [0.3.9] - BREAKING: inserting multiple networks with the same name does not represent an error anymore, networks are only unique by their net_id (_id field of the collection) diff --git a/pandahub/api/routers/variants.py b/pandahub/api/routers/variants.py index 22e92ca..9526798 100644 --- a/pandahub/api/routers/variants.py +++ b/pandahub/api/routers/variants.py @@ -30,13 +30,23 @@ def get_variants(data: GetVariantsModel, ph=Depends(pandahub)): class CreateVariantModel(BaseModel): project_id: str - variant_data: dict + net_id: int + name: str | None = None + default_name: str | None = None + + +class CreateVariantResponseModel(BaseModel): + net_id: int + index: int + name: str | None = None + date_created: int + date_changed: int + @router.post("/create_variant") -def create_variant(data: CreateVariantModel, ph=Depends(pandahub)): - project_id = data.project_id - ph.set_active_project_by_id(project_id) - return ph.create_variant(data.variant_data) +def create_variant(data: CreateVariantModel, ph=Depends(pandahub)) -> CreateVariantResponseModel: + ph.set_active_project_by_id(data.project_id) + return ph.create_variant(net_id=data.net_id, name=data.name, default_name=data.default_name) class DeleteVariantModel(BaseModel): project_id: str diff --git a/pandahub/lib/PandaHub.py b/pandahub/lib/PandaHub.py index 6a0a505..15fd7cb 100644 --- a/pandahub/lib/PandaHub.py +++ b/pandahub/lib/PandaHub.py @@ -2,6 +2,7 @@ import builtins import json import logging +import time import warnings from inspect import signature, _empty from collections.abc import Callable @@ -91,6 +92,7 @@ def __init__( user_id=None, datatypes=DATATYPES, mongodb_indexes=MONGODB_INDEXES, + elements_without_vars = None, ): mongo_client_args = { "host": connection_url, @@ -115,6 +117,7 @@ def __init__( {"var_type": np.nan}, ] } + self._elements_without_vars = ["variant"] if elements_without_vars is None else elements_without_vars if check_server_available: self.server_is_available() @@ -1150,10 +1153,13 @@ def write_network_to_db( if element_data.empty: continue element_data = element_data.copy(deep=True) - if "var_type" in element_data: - element_data["var_type"] = element_data["var_type"].fillna("base") - else: - element_data["var_type"] = "base" + if element not in self._elements_without_vars: + if "var_type" in element_data: + element_data["var_type"] = element_data["var_type"].fillna("base") + else: + element_data["var_type"] = "base" + element_data["not_in_var"] = np.empty((len(element_data.index), 0)).tolist() + element_data["variant"] = None element_data = convert_element_to_dict(element_data, net_id, self._datatypes.get(element)) self._write_element_to_db(db, element, element_data) @@ -1228,9 +1234,12 @@ def _get_net_id_from_name(self, name, db): def _network_with_name_exists(self, name, db): return self._get_net_id_from_name(name, db) is not None + def _get_net_collections(self, db, with_areas=True): + return self.get_net_collections(db, with_areas) + def get_net_collections(self, db=None, with_areas=True): if db is None: - db = self._get_project_database() + db = self.get_project_database() if with_areas: collection_filter = {"name": {"$regex": "^net_"}} else: @@ -1402,7 +1411,7 @@ def delete_elements( """ if not isinstance(element_indexes, list): raise TypeError("Parameter element_indexes must be a list of ints!") - validate_variant_type(variant) + self.validate_variant(variant, element_type) if project_id: self.set_active_project_by_id(project_id) @@ -1446,7 +1455,7 @@ def set_net_value_in_db( project_id=None, **kwargs, ): - validate_variant_type(variant) + self.validate_variant(variant, element_type) if project_id: self.set_active_project_by_id(project_id) self.check_permission("write") @@ -1514,7 +1523,7 @@ def set_object_attribute( variant=None, project_id=None, ): - validate_variant_type(variant) + self.validate_variant(variant, element_type) if project_id: self.set_active_project_by_id(project_id) self.check_permission("write") @@ -1636,7 +1645,7 @@ def create_elements( list A list of the created elements (elements_data with added _id fields) """ - validate_variant_type(variant) + self.validate_variant(variant, element_type) if project_id: self.set_active_project_by_id(project_id) self.check_permission("write") @@ -1771,24 +1780,30 @@ def _get_int_index(self, collection: str, index_field: str = "_id", query_filter # Variants # ------------------------- - def create_variant(self, data, index: Optional[int] = None): + def create_variant( + self, + net_id: int, + name: str | None = None, + default_name: str = "Variant", + index: int | None = None, + ) -> dict: db = self._get_project_database() - net_id = int(data["net_id"]) if index is None: index = self._get_int_index("variant", index_field="index", query_filter={"net_id": net_id}) - data["index"] = index - if data.get("default_name") is not None and data.get("name") is None: - data["name"] = data.pop("default_name") + " " + str(index) - db["variant"].insert_one(data) - del data["_id"] - - if index == 1: - for coll in self.get_net_collections(db): - db[coll].update_many( - {"$or": [{"var_type": None}, {"var_type": np.nan}]}, - {"$set": {"var_type": "base", "not_in_var": [], "variant": None}}, - ) - return data + + if name is None: + name = f"{default_name} {index + 1}" + now = int(time.time()) + variant = { + "net_id": net_id, + "index": index, + "name": name, + "date_created": now, + "date_changed": now, + } + db["variant"].insert_one(variant) + del variant["_id"] + return variant def delete_variant(self, net_id, index): db = self._get_project_database() @@ -1828,12 +1843,17 @@ def get_variant_filter(self, variant: int | None) -> dict: dict mongodb query filter for the given variant """ - validate_variant_type(variant) + self.validate_variant(variant) if variant is None: return self.base_variant_filter return {"$or": [{"var_type": "base", "not_in_var": {"$ne": variant}}, {"var_type": {"$in": ["change", "addition"]}, "variant": variant}, ]} + def validate_variant(self, variant: int | None, element_type: str | None = None): + """Raise a ValueError if variant is not int | None or element_type does not support variants.""" + validate_variant_type(variant) + if element_type is not None and variant is not None and element_type in self._elements_without_vars: + raise ValueError(f"{element_type} does not support variants") # ------------------------- # Bulk operations diff --git a/pyproject.toml b/pyproject.toml index 0b81e1f..6ca7d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pandahub" -version = "0.3.9" # File format version '__format_version__' is tracked in __init__.py +version = "0.3.10" # File format version '__format_version__' is tracked in __init__.py authors=[ { name = "Jan Ulffers", email = "jan.ulffers@iee.fraunhofer.de" }, { name = "Leon Thurner", email = "leon.thurner@retoflow.de" },