Skip to content

Commit

Permalink
Remove unused aggregators from tff.learning.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 687336045
  • Loading branch information
ZacharyGarrett authored and copybara-github committed Oct 18, 2024
1 parent 956b6c7 commit ad8c378
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 613 deletions.
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ and this project adheres to
### Removed

* `tff.types.tensorflow_to_type`, this function is no longer used.
* `tff.learning.dp_aggregator` removed. Prefer using the class methods on
`tff.aggregators.DifferentiallyPrivateFactory`.
* `tff.learning.ddp_secure_aggregator` and `tff.learning.secure_aggregator`
removed.

## Release 0.88.0

Expand Down
4 changes: 0 additions & 4 deletions tensorflow_federated/python/learning/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,6 @@ py_library(
name = "model_update_aggregator",
srcs = ["model_update_aggregator.py"],
deps = [
"//tensorflow_federated/python/aggregators:differential_privacy",
"//tensorflow_federated/python/aggregators:distributed_dp",
"//tensorflow_federated/python/aggregators:encoded",
"//tensorflow_federated/python/aggregators:factory",
"//tensorflow_federated/python/aggregators:mean",
"//tensorflow_federated/python/aggregators:quantile_estimation",
Expand All @@ -88,7 +85,6 @@ py_test(
"//tensorflow_federated/python/core/impl/types:type_analysis",
"//tensorflow_federated/python/core/templates:aggregation_process",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/test:static_assert",
],
)

Expand Down
4 changes: 0 additions & 4 deletions tensorflow_federated/python/learning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,4 @@
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements
from tensorflow_federated.python.learning.debug_measurements import add_debug_measurements_with_mixed_dtype
from tensorflow_federated.python.learning.loop_builder import LoopImplementation
from tensorflow_federated.python.learning.model_update_aggregator import compression_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import ddp_secure_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import dp_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import robust_aggregator
from tensorflow_federated.python.learning.model_update_aggregator import secure_aggregator
9 changes: 0 additions & 9 deletions tensorflow_federated/python/learning/algorithms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,9 @@ py_cpu_gpu_test(
deps = [
":fed_avg",
"//tensorflow_federated/python/aggregators:factory_utils",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_examples",
"//tensorflow_federated/python/learning/models:test_models",
"//tensorflow_federated/python/learning/optimizers:sgdm",
],
)
Expand Down Expand Up @@ -118,7 +115,6 @@ py_cpu_gpu_test(
shard_count = 10,
deps = [
":fed_avg_with_optimizer_schedule",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
Expand Down Expand Up @@ -164,10 +160,8 @@ py_cpu_gpu_test(
":fed_prox",
"//tensorflow_federated/python/aggregators:factory_utils",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/models:model_examples",
"//tensorflow_federated/python/learning/models:model_weights",
"//tensorflow_federated/python/learning/models:test_models",
Expand Down Expand Up @@ -330,7 +324,6 @@ py_cpu_gpu_test(
deps = [
":fed_sgd",
"//tensorflow_federated/python/core/environments/tensorflow_backend:tensorflow_test_utils",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
Expand Down Expand Up @@ -429,11 +422,9 @@ py_cpu_gpu_test(
"//tensorflow_federated/python/core/impl/types:placements",
"//tensorflow_federated/python/core/templates:iterative_process",
"//tensorflow_federated/python/core/templates:measured_process",
"//tensorflow_federated/python/core/test:static_assert",
"//tensorflow_federated/python/learning:client_weight_lib",
"//tensorflow_federated/python/learning:loop_builder",
"//tensorflow_federated/python/learning:model_update_aggregator",
"//tensorflow_federated/python/learning/metrics:aggregator",
"//tensorflow_federated/python/learning/metrics:counters",
"//tensorflow_federated/python/learning/models:functional",
"//tensorflow_federated/python/learning/models:keras_utils",
Expand Down
66 changes: 3 additions & 63 deletions tensorflow_federated/python/learning/algorithms/fed_avg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,25 @@
from absl.testing import parameterized

from tensorflow_federated.python.aggregators import factory_utils
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_avg
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.models import model_examples
from tensorflow_federated.python.learning.models import test_models
from tensorflow_federated.python.learning.optimizers import sgdm


class FedAvgTest(parameterized.TestCase):
"""Tests construction of the FedAvg training process."""

@parameterized.product(
optimizer_fn=[
sgdm.build_sgdm(learning_rate=0.1),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
fed_avg.build_weighted_fed_avg(
model_fn=mock_model_fn,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -125,34 +112,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self):
model_aggregator=model_aggregator,
)

def test_weighted_fed_avg_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg.build_weighted_fed_avg(
model_fn,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_unweighted_fed_avg_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg.build_unweighted_fed_avg(
model_fn,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=False
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


class FunctionalFedAvgTest(parameterized.TestCase):
"""Tests construction of the FedAvg training process."""
Expand All @@ -167,25 +126,6 @@ def test_raises_on_non_callable_or_functional_model(self, constructor):
model_fn=0, client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1)
)

@parameterized.named_parameters(
('weighted', fed_avg.build_weighted_fed_avg),
('unweighted', fed_avg.build_unweighted_fed_avg),
)
def test_weighted_fed_avg_with_only_secure_aggregation(self, constructor):
model = test_models.build_functional_linear_regression()
learning_process = constructor(
model_fn=model,
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
server_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=constructor is fed_avg.build_weighted_fed_avg
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
absltest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from absl.testing import parameterized
import tensorflow as tf

from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_avg_with_optimizer_schedule
Expand All @@ -30,17 +29,7 @@

class ClientScheduledFedAvgTest(parameterized.TestCase):

@parameterized.product(
optimizer_fn=[
lambda x: sgdm.build_sgdm(learning_rate=x),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
Expand All @@ -49,8 +38,8 @@ def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule(
model_fn=mock_model_fn,
client_learning_rate_fn=learning_rate_fn,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=lambda lr: sgdm.build_sgdm(learning_rate=lr),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -143,21 +132,6 @@ def test_raises_on_non_callable_model_fn(self):
client_optimizer_fn=lambda _: sgdm.build_sgdm(),
)

def test_construction_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_avg_with_optimizer_schedule.build_weighted_fed_avg_with_optimizer_schedule(
model_fn,
client_learning_rate_fn=lambda x: 0.5,
client_optimizer_fn=lambda x: sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_measurements_include_client_learning_rate(self):
client_work = fed_avg_with_optimizer_schedule.build_scheduled_client_work(
model_fn=model_examples.LinearRegression,
Expand Down
48 changes: 3 additions & 45 deletions tensorflow_federated/python/learning/algorithms/fed_prox_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,9 @@

from tensorflow_federated.python.aggregators import factory_utils
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_prox
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.models import model_examples
from tensorflow_federated.python.learning.models import model_weights
from tensorflow_federated.python.learning.models import test_models
Expand All @@ -34,26 +32,16 @@
class FedProxConstructionTest(parameterized.TestCase):
"""Tests construction of the FedProx training process."""

@parameterized.product(
optimizer_fn=[
sgdm.build_sgdm(learning_rate=0.1),
],
aggregation_factory=[
model_update_aggregator.robust_aggregator,
model_update_aggregator.compression_aggregator,
model_update_aggregator.secure_aggregator,
],
)
def test_construction_calls_model_fn(self, optimizer_fn, aggregation_factory):
def test_construction_calls_model_fn(self):
# Assert that the process building does not call `model_fn` too many times.
# `model_fn` can potentially be expensive (loading weights, processing, etc
# ).
mock_model_fn = mock.Mock(side_effect=model_examples.LinearRegression)
fed_prox.build_weighted_fed_prox(
model_fn=mock_model_fn,
proximal_strength=1.0,
client_optimizer_fn=optimizer_fn,
model_aggregator=aggregation_factory(),
client_optimizer_fn=sgdm.build_sgdm(learning_rate=0.1),
model_aggregator=model_update_aggregator.robust_aggregator(),
)
self.assertEqual(mock_model_fn.call_count, 3)

Expand Down Expand Up @@ -160,36 +148,6 @@ def test_unweighted_fed_avg_raises_on_weighted_aggregator(self):
model_aggregator=model_aggregator,
)

def test_weighted_fed_prox_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_prox.build_weighted_fed_prox(
model_fn,
proximal_strength=1.0,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=True
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)

def test_unweighted_fed_prox_with_only_secure_aggregation(self):
model_fn = model_examples.LinearRegression
learning_process = fed_prox.build_unweighted_fed_prox(
model_fn,
proximal_strength=1.0,
client_optimizer_fn=sgdm.build_sgdm(),
model_aggregator=model_update_aggregator.secure_aggregator(
weighted=False
),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
absltest.main()
28 changes: 0 additions & 28 deletions tensorflow_federated/python/learning/algorithms/fed_sgd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import tensorflow as tf

from tensorflow_federated.python.core.environments.tensorflow_backend import tensorflow_test_utils
from tensorflow_federated.python.core.test import static_assert
from tensorflow_federated.python.learning import loop_builder
from tensorflow_federated.python.learning import model_update_aggregator
from tensorflow_federated.python.learning.algorithms import fed_sgd
Expand Down Expand Up @@ -160,11 +159,6 @@ def test_client_tf_dataset_reduce_fn(self, loop_implementation, mock_method):

@parameterized.named_parameters(
('robust_aggregator', model_update_aggregator.robust_aggregator),
(
'compression_aggregator',
model_update_aggregator.compression_aggregator,
),
('secure_aggreagtor', model_update_aggregator.secure_aggregator),
)
def test_construction_calls_model_fn(self, aggregation_factory):
# Assert that the process building does not call `model_fn` too many times.
Expand All @@ -177,17 +171,6 @@ def test_construction_calls_model_fn(self, aggregation_factory):
# TODO: b/186451541 - reduce the number of calls to model_fn.
self.assertEqual(mock_model_fn.call_count, 3)

def test_no_unsecure_aggregation_with_secure_aggregator(self):
model_fn = model_examples.LinearRegression
learning_process = fed_sgd.build_fed_sgd(
model_fn,
model_aggregator=model_update_aggregator.secure_aggregator(),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


class FunctionalFederatedSgdTest(tf.test.TestCase, parameterized.TestCase):

Expand Down Expand Up @@ -276,17 +259,6 @@ def test_build_functional_fed_sgd_succeeds(self):
model = _build_functional_model()
fed_sgd.build_fed_sgd(model_fn=model)

def test_no_unsecure_aggregation_with_secure_aggregator(self):
model = _build_functional_model()
learning_process = fed_sgd.build_fed_sgd(
model,
model_aggregator=model_update_aggregator.secure_aggregator(),
metrics_aggregator=aggregator.secure_sum_then_finalize,
)
static_assert.assert_not_contains_unsecure_aggregation(
learning_process.next
)


if __name__ == '__main__':
tf.test.main()
Loading

0 comments on commit ad8c378

Please sign in to comment.