From 7b33c37fd6a511794e5ae7ca1101cf470e883e76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matt=C3=A9o=20Baussart?= Date: Mon, 25 Nov 2024 10:10:32 +0100 Subject: [PATCH] feat: add any_operator support (#1901) * feat: add any_operator support * feat: add test for Operator copy construction --- src/ansys/dpf/core/any.py | 6 ++++++ src/ansys/dpf/core/dpf_operator.py | 27 ++++++++++++++++++++------ src/ansys/dpf/gate/any_grpcapi.py | 10 ++++++++++ src/ansys/dpf/gate/operator_grpcapi.py | 4 ++++ tests/conftest.py | 3 +++ tests/test_any.py | 12 ++++++++++++ tests/test_operator.py | 6 ++++++ 7 files changed, 62 insertions(+), 6 deletions(-) diff --git a/src/ansys/dpf/core/any.py b/src/ansys/dpf/core/any.py index fc7d0f4738..0ae1f8b530 100644 --- a/src/ansys/dpf/core/any.py +++ b/src/ansys/dpf/core/any.py @@ -123,6 +123,7 @@ def _type_to_new_from_get_as_method(self, obj): custom_type_field, collection, workflow, + dpf_operator, ) if issubclass(obj, int): @@ -196,6 +197,11 @@ def _type_to_new_from_get_as_method(self, obj): self._api.any_new_from_int_collection, self._api.any_get_as_int_collection, ) + elif issubclass(obj, dpf_operator.Operator): + return ( + self._api.any_new_from_operator, + self._api.any_get_as_operator, + ) @staticmethod def new_from(obj, server=None): diff --git a/src/ansys/dpf/core/dpf_operator.py b/src/ansys/dpf/core/dpf_operator.py index ca83d61667..b0e2add66b 100644 --- a/src/ansys/dpf/core/dpf_operator.py +++ b/src/ansys/dpf/core/dpf_operator.py @@ -121,7 +121,7 @@ class Operator: """ - def __init__(self, name, config=None, server=None): + def __init__(self, name=None, config=None, server=None, operator=None): """Initialize the operator with its name by connecting to a stub.""" self.name = name self._internal_obj = None @@ -136,14 +136,29 @@ def __init__(self, name, config=None, server=None): # step 2: get api self._api_instance = None # see _api property - # step3: init environment + # step 3: init environment self._api.init_operator_environment(self) # creates stub when gRPC - # step4: if object exists: take instance, else create it (server) - if self._server.has_client(): - self._internal_obj = self._api.operator_new_on_client(self.name, self._server.client) + # step 4: if object exists, take the instance, else create it + if operator is not None: + if isinstance(operator, Operator): + core_api = self._server.get_api_for_type( + capi=data_processing_capi.DataProcessingCAPI, + grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI, + ) + core_api.init_data_processing_environment(self) + self._internal_obj = core_api.data_processing_duplicate_object_reference(operator) + self.name = operator.name + else: + self._internal_obj = operator + self.name = self._api.operator_name(self) else: - self._internal_obj = self._api.operator_new(self.name) + if self._server.has_client(): + self._internal_obj = self._api.operator_new_on_client( + self.name, self._server.client + ) + else: + self._internal_obj = self._api.operator_new(self.name) if self._internal_obj is None: raise KeyError( diff --git a/src/ansys/dpf/gate/any_grpcapi.py b/src/ansys/dpf/gate/any_grpcapi.py index 69e178d98c..c263e904ff 100644 --- a/src/ansys/dpf/gate/any_grpcapi.py +++ b/src/ansys/dpf/gate/any_grpcapi.py @@ -42,6 +42,7 @@ def _type_to_message_type(): custom_type_field, collection_base, workflow, + dpf_operator, ) return [(int, base_pb2.Type.INT), @@ -58,6 +59,7 @@ def _type_to_message_type(): (workflow.Workflow, base_pb2.Type.WORKFLOW), (collection_base.CollectionBase, base_pb2.Type.COLLECTION, base_pb2.Type.ANY), (dpf_vector.DPFVectorInt, base_pb2.Type.COLLECTION, base_pb2.Type.INT), + (dpf_operator.Operator, base_pb2.Type.OPERATOR), ] @staticmethod @@ -145,6 +147,10 @@ def any_get_as_int_collection(any): def any_get_as_workflow(any): return AnyGRPCAPI._get_as(any).workflow + @staticmethod + def any_get_as_operator(any): + return AnyGRPCAPI._get_as(any).operator + @staticmethod def _new_from(any, client=None): from ansys.grpc.dpf import dpf_any_pb2 @@ -230,3 +236,7 @@ def any_new_from_data_tree(any): @staticmethod def any_new_from_workflow(any): return AnyGRPCAPI._new_from(any, any._server) + + @staticmethod + def any_new_from_operator(any): + return AnyGRPCAPI._new_from(any, any._server) diff --git a/src/ansys/dpf/gate/operator_grpcapi.py b/src/ansys/dpf/gate/operator_grpcapi.py index 2f42d451e9..7eaad43aed 100644 --- a/src/ansys/dpf/gate/operator_grpcapi.py +++ b/src/ansys/dpf/gate/operator_grpcapi.py @@ -51,6 +51,10 @@ def get_list(op): def operator_get_config(op): return OperatorGRPCAPI.get_list(op).config + @staticmethod + def operator_name(op): + return OperatorGRPCAPI.get_list(op).op_name + @staticmethod def update_init(op, pin): from ansys.grpc.dpf import operator_pb2 diff --git a/tests/conftest.py b/tests/conftest.py index 458a5db423..73c3f59ce9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -328,6 +328,9 @@ def return_ds(server=None): return return_ds +SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_10_0 = meets_version( + get_server_version(core._global_server()), "10.0" +) SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_9_1 = meets_version( get_server_version(core._global_server()), "9.1" ) diff --git a/tests/test_any.py b/tests/test_any.py index 5ec9486282..2ad9026466 100644 --- a/tests/test_any.py +++ b/tests/test_any.py @@ -132,3 +132,15 @@ def test_cast_workflow_any(server_type): new_entity = any_dpf.cast() assert new_entity.input_names == [] + + +@pytest.mark.skipif( + not conftest.SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_10_0, + reason="any does not support operator below 8.0", +) +def test_cast_operator_any(server_type): + entity = dpf.Operator(server=server_type, name="U") + any_dpf = dpf.Any.new_from(entity) + new_entity = any_dpf.cast() + + assert entity.name == new_entity.name diff --git a/tests/test_operator.py b/tests/test_operator.py index 7dc3d61715..8cfadd185b 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -54,6 +54,12 @@ def test_create_operator(server_type): assert op._internal_obj +def test_create_operator_from_operator(server_type): + op = dpf.core.Operator("min_max", server=server_type) + op2 = dpf.core.Operator(operator=op, server=server_type) + assert op2._internal_obj + + def test_invalid_operator_name(server_type): # with pytest.raises(errors.DPFServerException): with pytest.raises(Exception):