Skip to content

Commit

Permalink
Remove LocalComputationFactory interface, this abstraction is no lo…
Browse files Browse the repository at this point in the history
…nger used.

PiperOrigin-RevId: 682439444
  • Loading branch information
michaelreneer authored and copybara-github committed Oct 4, 2024
1 parent 4f0a985 commit 7ee3768
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 87 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,6 @@ py_library(
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/compiler:local_computation_factory_base",
"//tensorflow_federated/python/core/impl/types:array_shape",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_analysis",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,26 @@
import numpy as np
import tensorflow as tf

from tensorflow_federated.proto.v0 import computation_pb2 as pb
from tensorflow_federated.proto.v0 import computation_pb2
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.environments.tensorflow_backend import serialization_utils
from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_utils
from tensorflow_federated.python.core.environments.tensorflow_backend import type_conversions
from tensorflow_federated.python.core.impl.compiler import local_computation_factory_base
from tensorflow_federated.python.core.impl.types import array_shape
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_serialization
from tensorflow_federated.python.core.impl.types import type_transformations


ComputationProtoAndType = local_computation_factory_base.ComputationProtoAndType
ComputationProtoAndType = tuple[
computation_pb2.Computation, computation_types.Type
]
T = TypeVar('T', bound=computation_types.Type)


class TensorFlowComputationFactory(
local_computation_factory_base.LocalComputationFactory
):
class TensorFlowComputationFactory:
"""An implementation of local computation factory for TF computations."""

def __init__(self):
Expand All @@ -49,21 +48,40 @@ def __init__(self):
def create_constant_from_scalar(
self, value, type_spec: computation_types.Type
) -> ComputationProtoAndType:
"""Creates a TFF computation returning a constant based on a scalar value.
The returned computation has the type signature `( -> T)`, where `T` may be
either a scalar, or a nested structure made up of scalars.
Args:
value: A numpy scalar representing the value to return from the
constructed computation (or to broadcast to all parts of a nested
structure if `type_spec` is a structured type).
type_spec: A `computation_types.Type` of the constructed constant. Must be
either a tensor, or a nested structure of tensors.
Returns:
A tuple `(pb.Computation, computation_types.Type)` with the first element
being a TFF computation with semantics as described above, and the second
element representing the formal type of that computation.
"""
return create_constant(value, type_spec)


def _tensorflow_comp(
tensorflow_proto: pb.TensorFlow,
tensorflow_proto: computation_pb2.TensorFlow,
type_signature: T,
) -> tuple[pb.Computation, T]:
) -> tuple[computation_pb2.Computation, T]:
serialized_type = type_serialization.serialize_type(type_signature)
comp = pb.Computation(type=serialized_type, tensorflow=tensorflow_proto)
comp = computation_pb2.Computation(
type=serialized_type, tensorflow=tensorflow_proto
)
return (comp, type_signature)


def create_constant(
value, type_spec: computation_types.Type
) -> tuple[pb.Computation, computation_types.FunctionType]:
) -> tuple[computation_pb2.Computation, computation_types.FunctionType]:
"""Returns a tensorflow computation returning a constant `value`.
The returned computation has the type signature `( -> T)`, where `T` is
Expand Down Expand Up @@ -154,7 +172,7 @@ def _create_result_tensor(type_spec, value):
)

type_signature = computation_types.FunctionType(None, result_type)
tensorflow = pb.TensorFlow(
tensorflow = computation_pb2.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=None,
result=result_binding,
Expand Down Expand Up @@ -202,7 +220,7 @@ def create_unary_operator(

type_signature = computation_types.FunctionType(operand_type, result_type)
parameter_binding = operand_binding
tensorflow = pb.TensorFlow(
tensorflow = computation_pb2.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding,
Expand Down Expand Up @@ -279,12 +297,12 @@ def create_binary_operator(
computation_types.StructType((operand_type, second_operand_type)),
result_type,
)
parameter_binding = pb.TensorFlow.Binding(
struct=pb.TensorFlow.StructBinding(
parameter_binding = computation_pb2.TensorFlow.Binding(
struct=computation_pb2.TensorFlow.StructBinding(
element=[operand_1_binding, operand_2_binding]
) # pytype: disable=wrong-arg-types
)
tensorflow = pb.TensorFlow(
tensorflow = computation_pb2.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding,
Expand Down Expand Up @@ -407,12 +425,12 @@ def _pack_into_type(to_pack: tf.Tensor, type_spec: computation_types.Type):
)

type_signature = computation_types.FunctionType(type_signature, result_type)
parameter_binding = pb.TensorFlow.Binding(
struct=pb.TensorFlow.StructBinding(
parameter_binding = computation_pb2.TensorFlow.Binding(
struct=computation_pb2.TensorFlow.StructBinding(
element=[operand_1_binding, operand_2_binding]
) # pytype: disable=wrong-arg-types
)
tensorflow = pb.TensorFlow(
tensorflow = computation_pb2.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding,
Expand Down Expand Up @@ -500,7 +518,7 @@ def create_computation_for_py_fn(
)

type_signature = computation_types.FunctionType(parameter_type, result_type)
tensorflow = pb.TensorFlow(
tensorflow = computation_pb2.TensorFlow(
graph_def=serialization_utils.pack_graph_def(graph.as_graph_def()),
parameter=parameter_binding,
result=result_binding,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ py_library(
srcs = ["compiler.py"],
deps = [
":xla_serialization",
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/common_libs:py_typecheck",
"//tensorflow_federated/python/common_libs:structure",
"//tensorflow_federated/python/core/impl/compiler:local_computation_factory_base",
"//tensorflow_federated/python/core/impl/types:computation_types",
"//tensorflow_federated/python/core/impl/types:type_analysis",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,40 @@
from jax.lib import xla_client
import numpy as np

from tensorflow_federated.proto.v0 import computation_pb2
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.environments.xla_backend import xla_serialization
from tensorflow_federated.python.core.impl.compiler import local_computation_factory_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_analysis


class XlaComputationFactory(
local_computation_factory_base.LocalComputationFactory
):
class XlaComputationFactory:
"""An implementation of local computation factory for XLA computations."""

def __init__(self):
pass

def create_constant_from_scalar(
self, value, type_spec: computation_types.Type
) -> local_computation_factory_base.ComputationProtoAndType:
) -> tuple[computation_pb2.Computation, computation_types.Type]:
"""Creates a TFF computation returning a constant based on a scalar value.
The returned computation has the type signature `( -> T)`, where `T` may be
either a scalar, or a nested structure made up of scalars.
Args:
value: A numpy scalar representing the value to return from the
constructed computation (or to broadcast to all parts of a nested
structure if `type_spec` is a structured type).
type_spec: A `computation_types.Type` of the constructed constant. Must be
either a tensor, or a nested structure of tensors.
Returns:
A tuple `(pb.Computation, computation_types.Type)` with the first element
being a TFF computation with semantics as described above, and the second
element representing the formal type of that computation.
"""
py_typecheck.check_type(type_spec, computation_types.Type)
if not type_analysis.is_structure_of_tensors(type_spec):
raise ValueError(
Expand Down
9 changes: 0 additions & 9 deletions tensorflow_federated/python/core/impl/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -218,15 +218,6 @@ py_test(
deps = [":intrinsic_defs"],
)

py_library(
name = "local_computation_factory_base",
srcs = ["local_computation_factory_base.py"],
deps = [
"//tensorflow_federated/proto/v0:computation_py_pb2",
"//tensorflow_federated/python/core/impl/types:computation_types",
],
)

py_library(
name = "transformations",
srcs = ["transformations.py"],
Expand Down

This file was deleted.

0 comments on commit 7ee3768

Please sign in to comment.