From 7ee3768c2f8107e3e28e070f558b983240df9da4 Mon Sep 17 00:00:00 2001 From: Michael Reneer Date: Fri, 4 Oct 2024 13:52:08 -0700 Subject: [PATCH] Remove `LocalComputationFactory` interface, this abstraction is no longer used. PiperOrigin-RevId: 682439444 --- .../environments/tensorflow_backend/BUILD | 1 - .../tensorflow_computation_factory.py | 56 ++++++++++++------- .../core/environments/xla_backend/BUILD | 2 +- .../core/environments/xla_backend/compiler.py | 25 +++++++-- .../python/core/impl/compiler/BUILD | 9 --- .../local_computation_factory_base.py | 52 ----------------- 6 files changed, 58 insertions(+), 87 deletions(-) delete mode 100644 tensorflow_federated/python/core/impl/compiler/local_computation_factory_base.py diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD b/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD index 0ec1d82d1d..e8ab315079 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/BUILD @@ -151,7 +151,6 @@ py_library( "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:local_computation_factory_base", "//tensorflow_federated/python/core/impl/types:array_shape", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:type_analysis", diff --git a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py index 013a660c8a..603aa7d71b 100644 --- a/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py +++ b/tensorflow_federated/python/core/environments/tensorflow_backend/tensorflow_computation_factory.py @@ -20,13 +20,12 @@ import numpy as np import tensorflow as tf -from tensorflow_federated.proto.v0 import computation_pb2 as pb +from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions -from tensorflow_federated.python.core.impl.compiler import local_computation_factory_base from tensorflow_federated.python.core.impl.types import array_shape from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import type_analysis @@ -34,13 +33,13 @@ from tensorflow_federated.python.core.impl.types import type_transformations -ComputationProtoAndType = local_computation_factory_base.ComputationProtoAndType +ComputationProtoAndType = tuple[ + computation_pb2.Computation, computation_types.Type +] T = TypeVar('T', bound=computation_types.Type) -class TensorFlowComputationFactory( - local_computation_factory_base.LocalComputationFactory -): +class TensorFlowComputationFactory: """An implementation of local computation factory for TF computations.""" def __init__(self): @@ -49,21 +48,40 @@ def __init__(self): def create_constant_from_scalar( self, value, type_spec: computation_types.Type ) -> ComputationProtoAndType: + """Creates a TFF computation returning a constant based on a scalar value. + + The returned computation has the type signature `( -> T)`, where `T` may be + either a scalar, or a nested structure made up of scalars. + + Args: + value: A numpy scalar representing the value to return from the + constructed computation (or to broadcast to all parts of a nested + structure if `type_spec` is a structured type). + type_spec: A `computation_types.Type` of the constructed constant. Must be + either a tensor, or a nested structure of tensors. + + Returns: + A tuple `(pb.Computation, computation_types.Type)` with the first element + being a TFF computation with semantics as described above, and the second + element representing the formal type of that computation. + """ return create_constant(value, type_spec) def _tensorflow_comp( - tensorflow_proto: pb.TensorFlow, + tensorflow_proto: computation_pb2.TensorFlow, type_signature: T, -) -> tuple[pb.Computation, T]: +) -> tuple[computation_pb2.Computation, T]: serialized_type = type_serialization.serialize_type(type_signature) - comp = pb.Computation(type=serialized_type, tensorflow=tensorflow_proto) + comp = computation_pb2.Computation( + type=serialized_type, tensorflow=tensorflow_proto + ) return (comp, type_signature) def create_constant( value, type_spec: computation_types.Type -) -> tuple[pb.Computation, computation_types.FunctionType]: +) -> tuple[computation_pb2.Computation, computation_types.FunctionType]: """Returns a tensorflow computation returning a constant `value`. The returned computation has the type signature `( -> T)`, where `T` is @@ -154,7 +172,7 @@ def _create_result_tensor(type_spec, value): ) type_signature = computation_types.FunctionType(None, result_type) - tensorflow = pb.TensorFlow( + tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=None, result=result_binding, @@ -202,7 +220,7 @@ def create_unary_operator( type_signature = computation_types.FunctionType(operand_type, result_type) parameter_binding = operand_binding - tensorflow = pb.TensorFlow( + tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, @@ -279,12 +297,12 @@ def create_binary_operator( computation_types.StructType((operand_type, second_operand_type)), result_type, ) - parameter_binding = pb.TensorFlow.Binding( - struct=pb.TensorFlow.StructBinding( + parameter_binding = computation_pb2.TensorFlow.Binding( + struct=computation_pb2.TensorFlow.StructBinding( element=[operand_1_binding, operand_2_binding] ) # pytype: disable=wrong-arg-types ) - tensorflow = pb.TensorFlow( + tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, @@ -407,12 +425,12 @@ def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type): ) type_signature = computation_types.FunctionType(type_signature, result_type) - parameter_binding = pb.TensorFlow.Binding( - struct=pb.TensorFlow.StructBinding( + parameter_binding = computation_pb2.TensorFlow.Binding( + struct=computation_pb2.TensorFlow.StructBinding( element=[operand_1_binding, operand_2_binding] ) # pytype: disable=wrong-arg-types ) - tensorflow = pb.TensorFlow( + tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, @@ -500,7 +518,7 @@ def create_computation_for_py_fn( ) type_signature = computation_types.FunctionType(parameter_type, result_type) - tensorflow = pb.TensorFlow( + tensorflow = computation_pb2.TensorFlow( graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()), parameter=parameter_binding, result=result_binding, diff --git a/tensorflow_federated/python/core/environments/xla_backend/BUILD b/tensorflow_federated/python/core/environments/xla_backend/BUILD index 7385743398..82c7660a6c 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/BUILD +++ b/tensorflow_federated/python/core/environments/xla_backend/BUILD @@ -28,9 +28,9 @@ py_library( srcs = ["compiler.py"], deps = [ ":xla_serialization", + "//tensorflow_federated/proto/v0:computation_py_pb2", "//tensorflow_federated/python/common_libs:py_typecheck", "//tensorflow_federated/python/common_libs:structure", - "//tensorflow_federated/python/core/impl/compiler:local_computation_factory_base", "//tensorflow_federated/python/core/impl/types:computation_types", "//tensorflow_federated/python/core/impl/types:type_analysis", ], diff --git a/tensorflow_federated/python/core/environments/xla_backend/compiler.py b/tensorflow_federated/python/core/environments/xla_backend/compiler.py index 198c4b1e39..d60c9eac78 100644 --- a/tensorflow_federated/python/core/environments/xla_backend/compiler.py +++ b/tensorflow_federated/python/core/environments/xla_backend/compiler.py @@ -16,17 +16,15 @@ from jax.lib import xla_client import numpy as np +from tensorflow_federated.proto.v0 import computation_pb2 from tensorflow_federated.python.common_libs import py_typecheck from tensorflow_federated.python.common_libs import structure from tensorflow_federated.python.core.environments.xla_backend import xla_serialization -from tensorflow_federated.python.core.impl.compiler import local_computation_factory_base from tensorflow_federated.python.core.impl.types import computation_types from tensorflow_federated.python.core.impl.types import type_analysis -class XlaComputationFactory( - local_computation_factory_base.LocalComputationFactory -): +class XlaComputationFactory: """An implementation of local computation factory for XLA computations.""" def __init__(self): @@ -34,7 +32,24 @@ def __init__(self): def create_constant_from_scalar( self, value, type_spec: computation_types.Type - ) -> local_computation_factory_base.ComputationProtoAndType: + ) -> tuple[computation_pb2.Computation, computation_types.Type]: + """Creates a TFF computation returning a constant based on a scalar value. + + The returned computation has the type signature `( -> T)`, where `T` may be + either a scalar, or a nested structure made up of scalars. + + Args: + value: A numpy scalar representing the value to return from the + constructed computation (or to broadcast to all parts of a nested + structure if `type_spec` is a structured type). + type_spec: A `computation_types.Type` of the constructed constant. Must be + either a tensor, or a nested structure of tensors. + + Returns: + A tuple `(pb.Computation, computation_types.Type)` with the first element + being a TFF computation with semantics as described above, and the second + element representing the formal type of that computation. + """ py_typecheck.check_type(type_spec, computation_types.Type) if not type_analysis.is_structure_of_tensors(type_spec): raise ValueError( diff --git a/tensorflow_federated/python/core/impl/compiler/BUILD b/tensorflow_federated/python/core/impl/compiler/BUILD index 7c1196776b..552f440d08 100644 --- a/tensorflow_federated/python/core/impl/compiler/BUILD +++ b/tensorflow_federated/python/core/impl/compiler/BUILD @@ -218,15 +218,6 @@ py_test( deps = [":intrinsic_defs"], ) -py_library( - name = "local_computation_factory_base", - srcs = ["local_computation_factory_base.py"], - deps = [ - "//tensorflow_federated/proto/v0:computation_py_pb2", - "//tensorflow_federated/python/core/impl/types:computation_types", - ], -) - py_library( name = "transformations", srcs = ["transformations.py"], diff --git a/tensorflow_federated/python/core/impl/compiler/local_computation_factory_base.py b/tensorflow_federated/python/core/impl/compiler/local_computation_factory_base.py deleted file mode 100644 index fcbfa32bc3..0000000000 --- a/tensorflow_federated/python/core/impl/compiler/local_computation_factory_base.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright 2021, The TensorFlow Federated Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Defines the interface for factories of framework-specific computations.""" - -import abc - -from tensorflow_federated.proto.v0 import computation_pb2 as pb -from tensorflow_federated.python.core.impl.types import computation_types - -ComputationProtoAndType = tuple[pb.Computation, computation_types.Type] - - -class LocalComputationFactory(metaclass=abc.ABCMeta): - """Interface for factories of backend framework-specific local computations. - - Implementations of this interface encapsulate the logic for constructing local - computations that are executable on a particular type of backend. - """ - - @abc.abstractmethod - def create_constant_from_scalar( - self, value, type_spec: computation_types.Type - ) -> ComputationProtoAndType: - """Creates a TFF computation returning a constant based on a scalar value. - - The returned computation has the type signature `( -> T)`, where `T` may be - either a scalar, or a nested structure made up of scalars. - - Args: - value: A numpy scalar representing the value to return from the - constructed computation (or to broadcast to all parts of a nested - structure if `type_spec` is a structured type). - type_spec: A `computation_types.Type` of the constructed constant. Must be - either a tensor, or a nested structure of tensors. - - Returns: - A tuple `(pb.Computation, computation_types.Type)` with the first element - being a TFF computation with semantics as described above, and the second - element representing the formal type of that computation. - """ - raise NotImplementedError