From 13c7118cfda6f2328066a79fa71298e51bd34509 Mon Sep 17 00:00:00 2001 From: Camille Bellot <80476446+cbellot000@users.noreply.github.com> Date: Fri, 21 Jun 2024 10:27:36 +0200 Subject: [PATCH] collection set/get support (#1624) * collection set support * retro * Apply suggestions from code review --- src/ansys/dpf/core/collection_base.py | 23 ++++++++++++ src/ansys/dpf/gate/collection_grpcapi.py | 48 ++++++++++++++++-------- src/ansys/dpf/gate/support_grpcapi.py | 11 ++++-- tests/test_collection.py | 30 ++++++++++++++- tests/test_fieldscontainer.py | 34 ++++++++++++++--- 5 files changed, 119 insertions(+), 27 deletions(-) diff --git a/src/ansys/dpf/core/collection_base.py b/src/ansys/dpf/core/collection_base.py index ee66520252..e6acfc638e 100644 --- a/src/ansys/dpf/core/collection_base.py +++ b/src/ansys/dpf/core/collection_base.py @@ -445,6 +445,29 @@ def _set_time_freq_support(self, time_freq_support): """Set the time frequency support of the collection.""" self._api.collection_set_support(self, "time", time_freq_support) + @version_requires("5.0") + def set_support(self, label: str, support: Support) -> None: + """Set the support of the collection for a given label. + + Notes + ----- + Available starting with DPF 2023 R1. + + """ + self._api.collection_set_support(self, label, support) + + @version_requires("5.0") + def get_support(self, label: str) -> Support: + """Get the support of the collection for a given label. + + Notes + ----- + Available starting with DPF 2023 R1. + + """ + from ansys.dpf.core.support import Support + return Support(support=self._api.collection_get_support(self, label), server=self._server) + def __str__(self): """Describe the entity. diff --git a/src/ansys/dpf/gate/collection_grpcapi.py b/src/ansys/dpf/gate/collection_grpcapi.py index 56d90f3f53..31bd55377c 100644 --- a/src/ansys/dpf/gate/collection_grpcapi.py +++ b/src/ansys/dpf/gate/collection_grpcapi.py @@ -5,9 +5,10 @@ from ansys.dpf.gate.generated import collection_abstract_api from ansys.dpf.gate import object_handler, data_processing_grpcapi, grpc_stream_helpers, errors -#------------------------------------------------------------------------------- + +# ------------------------------------------------------------------------------- # Collection -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- def _get_stub(server): return server.get_stub(CollectionGRPCAPI.STUBNAME) @@ -29,7 +30,8 @@ def init_collection_environment(object): server.create_stub_if_necessary( CollectionGRPCAPI.STUBNAME, collection_pb2_grpc.CollectionServiceStub) - object._deleter_func = (_get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj,collection_pb2.Collection) else None) + object._deleter_func = ( + _get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj, collection_pb2.Collection) else None) @staticmethod def collection_of_scoping_new_on_client(client): @@ -135,7 +137,7 @@ def collection_get_obj_by_index_for_label_space(collection, space, index): @staticmethod def collection_get_obj_by_index(collection, index): - return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference( + return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference( CollectionGRPCAPI._collection_get_entries(collection, index)[0].entry ) @@ -145,7 +147,8 @@ def collection_get_obj_label_space_by_index(collection, index): @staticmethod def _collection_get_entries(collection, label_space_or_index): - from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, dpf_any_message_pb2 + from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, \ + dpf_any_message_pb2 request = collection_pb2.EntryRequest() request.collection.CopyFrom(collection._internal_obj) @@ -154,7 +157,7 @@ def _collection_get_entries(collection, label_space_or_index): else: request.label_space.CopyFrom(label_space_or_index._internal_obj) - out = _get_stub(collection._server).GetEntries(request) + out = _get_stub(collection._server).GetEntries(request) list_out = [] for obj in out.entries: label_space = {} @@ -163,15 +166,19 @@ def _collection_get_entries(collection, label_space_or_index): label_space[key] = obj.label_space.label_space[key] if obj.HasField("dpf_type"): if collection._internal_obj.type == base_pb2.Type.Value("SCOPING"): - entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, scoping_pb2.Scoping()) + entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, + scoping_pb2.Scoping()) elif collection._internal_obj.type == base_pb2.Type.Value("FIELD"): entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, field_pb2.Field()) elif collection._internal_obj.type == base_pb2.Type.Value("MESHED_REGION"): - entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, meshed_region_pb2.MeshedRegion()) + entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, + meshed_region_pb2.MeshedRegion()) elif collection._internal_obj.type == base_pb2.Type.Value("ANY"): - entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, dpf_any_message_pb2.DpfAny()) + entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, + dpf_any_message_pb2.DpfAny()) else: - raise NotImplementedError(f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented") + raise NotImplementedError( + f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented") obj.dpf_type.Unpack(entry._internal_obj) entry._server = collection._server list_out.append(_CollectionEntry(label_space, entry)) @@ -193,7 +200,7 @@ def collection_add_entry(collection, labelspace, obj): request = collection_pb2.UpdateRequest() request.collection.CopyFrom(collection._internal_obj) if hasattr(obj, "_message"): - #TO DO: remove + # TO DO: remove request.entry.dpf_type.Pack(obj._message) else: request.entry.dpf_type.Pack(obj._internal_obj) @@ -206,7 +213,8 @@ def _collection_set_data_as_integral_type(collection, data, size): metadata = [(u"size_bytes", f"{size * data.itemsize}")] request = collection_pb2.UpdateAllDataRequest() request.collection.CopyFrom(collection._internal_obj) - _get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data), metadata=metadata) + _get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data), + metadata=metadata) @staticmethod def collection_set_data_as_int(collection, data, size): @@ -219,9 +227,15 @@ def collection_set_data_as_double(collection, data, size): @staticmethod def collection_set_support(collection, label, support): from ansys.grpc.dpf import collection_pb2 + from ansys.grpc.dpf import time_freq_support_pb2 + from ansys.grpc.dpf import support_pb2 request = collection_pb2.UpdateSupportRequest() request.collection.CopyFrom(collection._internal_obj) - request.time_freq_support.CopyFrom(support._internal_obj) + if isinstance(support._internal_obj, time_freq_support_pb2.TimeFreqSupport): + request.time_freq_support.CopyFrom(support._internal_obj) + else: + supp = support_pb2.Support(id=support._internal_obj.id) + request.support.CopyFrom(supp) request.label = label _get_stub(collection._server).UpdateSupport(request) @@ -230,7 +244,11 @@ def collection_get_support(collection, label): from ansys.grpc.dpf import collection_pb2, base_pb2 request = collection_pb2.SupportRequest() request.collection.CopyFrom(collection._internal_obj) - request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT") + if collection._server.meet_version("5.0"): + request.label = label + request.type = base_pb2.Type.Value("SUPPORT") + else: + request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT") message = _get_stub(collection._server).GetSupport(request) return message @@ -284,5 +302,3 @@ def collection_add_string_entry(collection, obj): class _CollectionEntry(NamedTuple): label_space: dict entry: object - - diff --git a/src/ansys/dpf/gate/support_grpcapi.py b/src/ansys/dpf/gate/support_grpcapi.py index 14f2cc24d7..e56c02754d 100644 --- a/src/ansys/dpf/gate/support_grpcapi.py +++ b/src/ansys/dpf/gate/support_grpcapi.py @@ -36,11 +36,14 @@ def support_get_as_time_freq_support(support): if isinstance(internal_obj, time_freq_support_pb2.TimeFreqSupport): message = support elif isinstance(internal_obj, support_pb2.Support): - message = time_freq_support_pb2.TimeFreqSupport() - if isinstance(message.id, int): - message.id = internal_obj.id + if hasattr(_get_stub(support._server), "GetSupport"): + message = _get_stub(support._server).GetSupport(internal_obj).time_freq_support else: - message.id.CopyFrom(internal_obj.id) + message = time_freq_support_pb2.TimeFreqSupport() + if isinstance(message.id, int): + message.id = internal_obj.id + else: + message.id.CopyFrom(internal_obj.id) else: raise NotImplementedError(f"Tried to get {support} as TimeFreqSupport.") return message diff --git a/tests/test_collection.py b/tests/test_collection.py index 63aad1d500..a1196d291c 100644 --- a/tests/test_collection.py +++ b/tests/test_collection.py @@ -5,8 +5,10 @@ import pytest import numpy as np from ansys.dpf.core import CustomTypeField, CustomTypeFieldsCollection, GenericDataContainersCollection, \ - StringFieldsCollection, StringField, GenericDataContainer, operators, types, Workflow + StringFieldsCollection, StringField, GenericDataContainer, operators, Workflow, fields_factory from ansys.dpf.core.collection import Collection +from ansys.dpf.core.time_freq_support import TimeFreqSupport +from ansys.dpf.core.generic_support import GenericSupport import random from dataclasses import dataclass, field @@ -103,6 +105,32 @@ def test_fill_gdc_collection(server_type): # assert "collection" in str(coll) +@pytest.mark.parametrize("subtype_creator", + [collection_helper, cust_type_field_collection_helper, string_field_collection_helper], + ids=[collection_helper.name, cust_type_field_collection_helper.name, + string_field_collection_helper.name]) +@conftest.raises_for_servers_version_under("8.1") +def test_set_support_collection(server_type, subtype_creator): + coll = subtype_creator.type(server=server_type, **subtype_creator.kwargs) + coll.labels = ["body", "time"] + tfq = TimeFreqSupport(server=server_type) + frequencies = fields_factory.create_scalar_field(3, server=server_type) + frequencies.append([1.0], 1) + tfq.time_frequencies = frequencies + + gen_support = GenericSupport(name="body", server=server_type) + str_f = StringField(server=server_type) + str_f.append(["inlet"], 1) + gen_support.set_support_of_property("name", str_f) + + coll.set_support("time", tfq) + coll.set_support("body", gen_support) + + assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"] + assert coll.get_support("body").available_string_field_supported_properties() == ["name"] + assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"] + + @pytest.mark.parametrize("subtype_creator", [collection_helper, cust_type_field_collection_helper, string_field_collection_helper, gdc_collection_helper], diff --git a/tests/test_fieldscontainer.py b/tests/test_fieldscontainer.py index a907c60ac3..92ee7964c2 100644 --- a/tests/test_fieldscontainer.py +++ b/tests/test_fieldscontainer.py @@ -361,17 +361,17 @@ def test_el_shape_fc(allkindofcomplexity): mesh = model.metadata.meshed_region f = fc.beam_field() - ids = f.scoping.ids[0 : int(len(f.scoping) / 4)] + ids = f.scoping.ids[0: int(len(f.scoping) / 4)] for id in ids: assert mesh.elements.element_by_id(id).shape == "beam" f = fc.shell_field() - ids = f.scoping.ids[0 : int(len(f.scoping) / 10)] + ids = f.scoping.ids[0: int(len(f.scoping) / 10)] for id in ids: assert mesh.elements.element_by_id(id).shape == "shell" f = fc.solid_field() - ids = f.scoping.ids[0 : int(len(f.scoping) / 10)] + ids = f.scoping.ids[0: int(len(f.scoping) / 10)] for id in ids: assert mesh.elements.element_by_id(id).shape == "solid" @@ -389,15 +389,15 @@ def test_el_shape_time_fc(): mesh = model.metadata.meshed_region f = fc.beam_field(3) - for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 3)]: + for id in f.scoping.ids[0: int(len(f.scoping.ids) / 3)]: assert mesh.elements.element_by_id(id).shape == "beam" f = fc.shell_field(4) - for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]: + for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]: assert mesh.elements.element_by_id(id).shape == "shell" f = fc.solid_field(5) - for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]: + for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]: assert mesh.elements.element_by_id(id).shape == "solid" @@ -531,6 +531,28 @@ def test_fields_container_get_time_scoping(server_type, disp_fc): assert freq_scoping.size == 1 +@conftest.raises_for_servers_version_under("5.0") +def test_fields_container_set_tfsupport(server_type): + coll = dpf.FieldsContainer(server=server_type) + coll.labels = ["body", "time"] + tfq = TimeFreqSupport(server=server_type) + frequencies = fields_factory.create_scalar_field(3, server=server_type) + frequencies.append([1.0], 1) + tfq.time_frequencies = frequencies + + gen_support = dpf.GenericSupport(name="body", server=server_type) + str_f = dpf.StringField(server=server_type) + str_f.append(["inlet"], 1) + gen_support.set_support_of_property("name", str_f) + + coll.set_support("time", tfq) + coll.set_support("body", gen_support) + + assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"] + assert coll.get_support("body").available_string_field_supported_properties() == ["name"] + assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"] + + @pytest.mark.skipif( not conftest.SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_7_0, reason="Available for servers >=7.0" )