Skip to content

Commit

Permalink
collection set/get support (#1624)
Browse files Browse the repository at this point in the history
* collection set support

* retro

* Apply suggestions from code review
  • Loading branch information
cbellot000 authored Jun 21, 2024
1 parent dcfc0a9 commit 13c7118
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 27 deletions.
23 changes: 23 additions & 0 deletions src/ansys/dpf/core/collection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,29 @@ def _set_time_freq_support(self, time_freq_support):
"""Set the time frequency support of the collection."""
self._api.collection_set_support(self, "time", time_freq_support)

@version_requires("5.0")
def set_support(self, label: str, support: Support) -> None:
"""Set the support of the collection for a given label.
Notes
-----
Available starting with DPF 2023 R1.
"""
self._api.collection_set_support(self, label, support)

@version_requires("5.0")
def get_support(self, label: str) -> Support:
"""Get the support of the collection for a given label.
Notes
-----
Available starting with DPF 2023 R1.
"""
from ansys.dpf.core.support import Support
return Support(support=self._api.collection_get_support(self, label), server=self._server)

def __str__(self):
"""Describe the entity.
Expand Down
48 changes: 32 additions & 16 deletions src/ansys/dpf/gate/collection_grpcapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from ansys.dpf.gate.generated import collection_abstract_api
from ansys.dpf.gate import object_handler, data_processing_grpcapi, grpc_stream_helpers, errors

#-------------------------------------------------------------------------------

# -------------------------------------------------------------------------------
# Collection
#-------------------------------------------------------------------------------
# -------------------------------------------------------------------------------

def _get_stub(server):
return server.get_stub(CollectionGRPCAPI.STUBNAME)
Expand All @@ -29,7 +30,8 @@ def init_collection_environment(object):
server.create_stub_if_necessary(
CollectionGRPCAPI.STUBNAME, collection_pb2_grpc.CollectionServiceStub)

object._deleter_func = (_get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj,collection_pb2.Collection) else None)
object._deleter_func = (
_get_stub(server).Delete, lambda obj: obj._internal_obj if isinstance(obj, collection_pb2.Collection) else None)

@staticmethod
def collection_of_scoping_new_on_client(client):
Expand Down Expand Up @@ -135,7 +137,7 @@ def collection_get_obj_by_index_for_label_space(collection, space, index):

@staticmethod
def collection_get_obj_by_index(collection, index):
return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference(
return data_processing_grpcapi.DataProcessingGRPCAPI.data_processing_duplicate_object_reference(
CollectionGRPCAPI._collection_get_entries(collection, index)[0].entry
)

Expand All @@ -145,7 +147,8 @@ def collection_get_obj_label_space_by_index(collection, index):

@staticmethod
def _collection_get_entries(collection, label_space_or_index):
from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, dpf_any_message_pb2
from ansys.grpc.dpf import collection_pb2, scoping_pb2, field_pb2, meshed_region_pb2, base_pb2, \
dpf_any_message_pb2
request = collection_pb2.EntryRequest()
request.collection.CopyFrom(collection._internal_obj)

Expand All @@ -154,7 +157,7 @@ def _collection_get_entries(collection, label_space_or_index):
else:
request.label_space.CopyFrom(label_space_or_index._internal_obj)

out = _get_stub(collection._server).GetEntries(request)
out = _get_stub(collection._server).GetEntries(request)
list_out = []
for obj in out.entries:
label_space = {}
Expand All @@ -163,15 +166,19 @@ def _collection_get_entries(collection, label_space_or_index):
label_space[key] = obj.label_space.label_space[key]
if obj.HasField("dpf_type"):
if collection._internal_obj.type == base_pb2.Type.Value("SCOPING"):
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, scoping_pb2.Scoping())
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
scoping_pb2.Scoping())
elif collection._internal_obj.type == base_pb2.Type.Value("FIELD"):
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, field_pb2.Field())
elif collection._internal_obj.type == base_pb2.Type.Value("MESHED_REGION"):
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, meshed_region_pb2.MeshedRegion())
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
meshed_region_pb2.MeshedRegion())
elif collection._internal_obj.type == base_pb2.Type.Value("ANY"):
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI, dpf_any_message_pb2.DpfAny())
entry = object_handler.ObjHandler(data_processing_grpcapi.DataProcessingGRPCAPI,
dpf_any_message_pb2.DpfAny())
else:
raise NotImplementedError(f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented")
raise NotImplementedError(
f"collection {base_pb2.Type.Name(collection._internal_obj.type)} type is not implemented")
obj.dpf_type.Unpack(entry._internal_obj)
entry._server = collection._server
list_out.append(_CollectionEntry(label_space, entry))
Expand All @@ -193,7 +200,7 @@ def collection_add_entry(collection, labelspace, obj):
request = collection_pb2.UpdateRequest()
request.collection.CopyFrom(collection._internal_obj)
if hasattr(obj, "_message"):
#TO DO: remove
# TO DO: remove
request.entry.dpf_type.Pack(obj._message)
else:
request.entry.dpf_type.Pack(obj._internal_obj)
Expand All @@ -206,7 +213,8 @@ def _collection_set_data_as_integral_type(collection, data, size):
metadata = [(u"size_bytes", f"{size * data.itemsize}")]
request = collection_pb2.UpdateAllDataRequest()
request.collection.CopyFrom(collection._internal_obj)
_get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data), metadata=metadata)
_get_stub(collection._server).UpdateAllData(grpc_stream_helpers._data_chunk_yielder(request, data),
metadata=metadata)

@staticmethod
def collection_set_data_as_int(collection, data, size):
Expand All @@ -219,9 +227,15 @@ def collection_set_data_as_double(collection, data, size):
@staticmethod
def collection_set_support(collection, label, support):
from ansys.grpc.dpf import collection_pb2
from ansys.grpc.dpf import time_freq_support_pb2
from ansys.grpc.dpf import support_pb2
request = collection_pb2.UpdateSupportRequest()
request.collection.CopyFrom(collection._internal_obj)
request.time_freq_support.CopyFrom(support._internal_obj)
if isinstance(support._internal_obj, time_freq_support_pb2.TimeFreqSupport):
request.time_freq_support.CopyFrom(support._internal_obj)
else:
supp = support_pb2.Support(id=support._internal_obj.id)
request.support.CopyFrom(supp)
request.label = label
_get_stub(collection._server).UpdateSupport(request)

Expand All @@ -230,7 +244,11 @@ def collection_get_support(collection, label):
from ansys.grpc.dpf import collection_pb2, base_pb2
request = collection_pb2.SupportRequest()
request.collection.CopyFrom(collection._internal_obj)
request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT")
if collection._server.meet_version("5.0"):
request.label = label
request.type = base_pb2.Type.Value("SUPPORT")
else:
request.type = base_pb2.Type.Value("TIME_FREQ_SUPPORT")
message = _get_stub(collection._server).GetSupport(request)
return message

Expand Down Expand Up @@ -284,5 +302,3 @@ def collection_add_string_entry(collection, obj):
class _CollectionEntry(NamedTuple):
label_space: dict
entry: object


11 changes: 7 additions & 4 deletions src/ansys/dpf/gate/support_grpcapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ def support_get_as_time_freq_support(support):
if isinstance(internal_obj, time_freq_support_pb2.TimeFreqSupport):
message = support
elif isinstance(internal_obj, support_pb2.Support):
message = time_freq_support_pb2.TimeFreqSupport()
if isinstance(message.id, int):
message.id = internal_obj.id
if hasattr(_get_stub(support._server), "GetSupport"):
message = _get_stub(support._server).GetSupport(internal_obj).time_freq_support
else:
message.id.CopyFrom(internal_obj.id)
message = time_freq_support_pb2.TimeFreqSupport()
if isinstance(message.id, int):
message.id = internal_obj.id
else:
message.id.CopyFrom(internal_obj.id)
else:
raise NotImplementedError(f"Tried to get {support} as TimeFreqSupport.")
return message
Expand Down
30 changes: 29 additions & 1 deletion tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import pytest
import numpy as np
from ansys.dpf.core import CustomTypeField, CustomTypeFieldsCollection, GenericDataContainersCollection, \
StringFieldsCollection, StringField, GenericDataContainer, operators, types, Workflow
StringFieldsCollection, StringField, GenericDataContainer, operators, Workflow, fields_factory
from ansys.dpf.core.collection import Collection
from ansys.dpf.core.time_freq_support import TimeFreqSupport
from ansys.dpf.core.generic_support import GenericSupport
import random
from dataclasses import dataclass, field

Expand Down Expand Up @@ -103,6 +105,32 @@ def test_fill_gdc_collection(server_type):
# assert "collection" in str(coll)


@pytest.mark.parametrize("subtype_creator",
[collection_helper, cust_type_field_collection_helper, string_field_collection_helper],
ids=[collection_helper.name, cust_type_field_collection_helper.name,
string_field_collection_helper.name])
@conftest.raises_for_servers_version_under("8.1")
def test_set_support_collection(server_type, subtype_creator):
coll = subtype_creator.type(server=server_type, **subtype_creator.kwargs)
coll.labels = ["body", "time"]
tfq = TimeFreqSupport(server=server_type)
frequencies = fields_factory.create_scalar_field(3, server=server_type)
frequencies.append([1.0], 1)
tfq.time_frequencies = frequencies

gen_support = GenericSupport(name="body", server=server_type)
str_f = StringField(server=server_type)
str_f.append(["inlet"], 1)
gen_support.set_support_of_property("name", str_f)

coll.set_support("time", tfq)
coll.set_support("body", gen_support)

assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"]
assert coll.get_support("body").available_string_field_supported_properties() == ["name"]
assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"]


@pytest.mark.parametrize("subtype_creator",
[collection_helper, cust_type_field_collection_helper, string_field_collection_helper,
gdc_collection_helper],
Expand Down
34 changes: 28 additions & 6 deletions tests/test_fieldscontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,17 +361,17 @@ def test_el_shape_fc(allkindofcomplexity):
mesh = model.metadata.meshed_region

f = fc.beam_field()
ids = f.scoping.ids[0 : int(len(f.scoping) / 4)]
ids = f.scoping.ids[0: int(len(f.scoping) / 4)]
for id in ids:
assert mesh.elements.element_by_id(id).shape == "beam"

f = fc.shell_field()
ids = f.scoping.ids[0 : int(len(f.scoping) / 10)]
ids = f.scoping.ids[0: int(len(f.scoping) / 10)]
for id in ids:
assert mesh.elements.element_by_id(id).shape == "shell"

f = fc.solid_field()
ids = f.scoping.ids[0 : int(len(f.scoping) / 10)]
ids = f.scoping.ids[0: int(len(f.scoping) / 10)]
for id in ids:
assert mesh.elements.element_by_id(id).shape == "solid"

Expand All @@ -389,15 +389,15 @@ def test_el_shape_time_fc():
mesh = model.metadata.meshed_region

f = fc.beam_field(3)
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 3)]:
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 3)]:
assert mesh.elements.element_by_id(id).shape == "beam"

f = fc.shell_field(4)
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]:
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]:
assert mesh.elements.element_by_id(id).shape == "shell"

f = fc.solid_field(5)
for id in f.scoping.ids[0 : int(len(f.scoping.ids) / 10)]:
for id in f.scoping.ids[0: int(len(f.scoping.ids) / 10)]:
assert mesh.elements.element_by_id(id).shape == "solid"


Expand Down Expand Up @@ -531,6 +531,28 @@ def test_fields_container_get_time_scoping(server_type, disp_fc):
assert freq_scoping.size == 1


@conftest.raises_for_servers_version_under("5.0")
def test_fields_container_set_tfsupport(server_type):
coll = dpf.FieldsContainer(server=server_type)
coll.labels = ["body", "time"]
tfq = TimeFreqSupport(server=server_type)
frequencies = fields_factory.create_scalar_field(3, server=server_type)
frequencies.append([1.0], 1)
tfq.time_frequencies = frequencies

gen_support = dpf.GenericSupport(name="body", server=server_type)
str_f = dpf.StringField(server=server_type)
str_f.append(["inlet"], 1)
gen_support.set_support_of_property("name", str_f)

coll.set_support("time", tfq)
coll.set_support("body", gen_support)

assert coll.get_support("time").available_field_supported_properties() == ["time_freqs"]
assert coll.get_support("body").available_string_field_supported_properties() == ["name"]
assert coll.get_support("body").string_field_support_by_property("name").data == ["inlet"]


@pytest.mark.skipif(
not conftest.SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_7_0, reason="Available for servers >=7.0"
)
Expand Down

0 comments on commit 13c7118

Please sign in to comment.