Skip to content

Commit

Permalink
feat: add any_operator support (#1901)
Browse files Browse the repository at this point in the history
* feat: add any_operator support

* feat: add test for Operator copy construction
  • Loading branch information
Matteo-Baussart-ANSYS authored Nov 25, 2024
1 parent afa88c8 commit 7b33c37
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 6 deletions.
6 changes: 6 additions & 0 deletions src/ansys/dpf/core/any.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def _type_to_new_from_get_as_method(self, obj):
custom_type_field,
collection,
workflow,
dpf_operator,
)

if issubclass(obj, int):
Expand Down Expand Up @@ -196,6 +197,11 @@ def _type_to_new_from_get_as_method(self, obj):
self._api.any_new_from_int_collection,
self._api.any_get_as_int_collection,
)
elif issubclass(obj, dpf_operator.Operator):
return (
self._api.any_new_from_operator,
self._api.any_get_as_operator,
)

@staticmethod
def new_from(obj, server=None):
Expand Down
27 changes: 21 additions & 6 deletions src/ansys/dpf/core/dpf_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class Operator:
"""

def __init__(self, name, config=None, server=None):
def __init__(self, name=None, config=None, server=None, operator=None):
"""Initialize the operator with its name by connecting to a stub."""
self.name = name
self._internal_obj = None
Expand All @@ -136,14 +136,29 @@ def __init__(self, name, config=None, server=None):
# step 2: get api
self._api_instance = None # see _api property

# step3: init environment
# step 3: init environment
self._api.init_operator_environment(self) # creates stub when gRPC

# step4: if object exists: take instance, else create it (server)
if self._server.has_client():
self._internal_obj = self._api.operator_new_on_client(self.name, self._server.client)
# step 4: if object exists, take the instance, else create it
if operator is not None:
if isinstance(operator, Operator):
core_api = self._server.get_api_for_type(
capi=data_processing_capi.DataProcessingCAPI,
grpcapi=data_processing_grpcapi.DataProcessingGRPCAPI,
)
core_api.init_data_processing_environment(self)
self._internal_obj = core_api.data_processing_duplicate_object_reference(operator)
self.name = operator.name
else:
self._internal_obj = operator
self.name = self._api.operator_name(self)
else:
self._internal_obj = self._api.operator_new(self.name)
if self._server.has_client():
self._internal_obj = self._api.operator_new_on_client(
self.name, self._server.client
)
else:
self._internal_obj = self._api.operator_new(self.name)

if self._internal_obj is None:
raise KeyError(
Expand Down
10 changes: 10 additions & 0 deletions src/ansys/dpf/gate/any_grpcapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def _type_to_message_type():
custom_type_field,
collection_base,
workflow,
dpf_operator,
)

return [(int, base_pb2.Type.INT),
Expand All @@ -58,6 +59,7 @@ def _type_to_message_type():
(workflow.Workflow, base_pb2.Type.WORKFLOW),
(collection_base.CollectionBase, base_pb2.Type.COLLECTION, base_pb2.Type.ANY),
(dpf_vector.DPFVectorInt, base_pb2.Type.COLLECTION, base_pb2.Type.INT),
(dpf_operator.Operator, base_pb2.Type.OPERATOR),
]

@staticmethod
Expand Down Expand Up @@ -145,6 +147,10 @@ def any_get_as_int_collection(any):
def any_get_as_workflow(any):
return AnyGRPCAPI._get_as(any).workflow

@staticmethod
def any_get_as_operator(any):
return AnyGRPCAPI._get_as(any).operator

@staticmethod
def _new_from(any, client=None):
from ansys.grpc.dpf import dpf_any_pb2
Expand Down Expand Up @@ -230,3 +236,7 @@ def any_new_from_data_tree(any):
@staticmethod
def any_new_from_workflow(any):
return AnyGRPCAPI._new_from(any, any._server)

@staticmethod
def any_new_from_operator(any):
return AnyGRPCAPI._new_from(any, any._server)
4 changes: 4 additions & 0 deletions src/ansys/dpf/gate/operator_grpcapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def get_list(op):
def operator_get_config(op):
return OperatorGRPCAPI.get_list(op).config

@staticmethod
def operator_name(op):
return OperatorGRPCAPI.get_list(op).op_name

@staticmethod
def update_init(op, pin):
from ansys.grpc.dpf import operator_pb2
Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,9 @@ def return_ds(server=None):
return return_ds


SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_10_0 = meets_version(
get_server_version(core._global_server()), "10.0"
)
SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_9_1 = meets_version(
get_server_version(core._global_server()), "9.1"
)
Expand Down
12 changes: 12 additions & 0 deletions tests/test_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,15 @@ def test_cast_workflow_any(server_type):
new_entity = any_dpf.cast()

assert new_entity.input_names == []


@pytest.mark.skipif(
not conftest.SERVERS_VERSION_GREATER_THAN_OR_EQUAL_TO_10_0,
reason="any does not support operator below 8.0",
)
def test_cast_operator_any(server_type):
entity = dpf.Operator(server=server_type, name="U")
any_dpf = dpf.Any.new_from(entity)
new_entity = any_dpf.cast()

assert entity.name == new_entity.name
6 changes: 6 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@ def test_create_operator(server_type):
assert op._internal_obj


def test_create_operator_from_operator(server_type):
op = dpf.core.Operator("min_max", server=server_type)
op2 = dpf.core.Operator(operator=op, server=server_type)
assert op2._internal_obj


def test_invalid_operator_name(server_type):
# with pytest.raises(errors.DPFServerException):
with pytest.raises(Exception):
Expand Down

0 comments on commit 7b33c37

Please sign in to comment.