diff --git a/official/common/streamz_counters.py b/official/common/streamz_counters.py new file mode 100644 index 00000000000..ab3df36ce60 --- /dev/null +++ b/official/common/streamz_counters.py @@ -0,0 +1,27 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Global streamz counters.""" + +from tensorflow.python.eager import monitoring + + +progressive_policy_creation_counter = monitoring.Counter( + "/tensorflow/training/fast_training/progressive_policy_creation", + "Counter for the number of ProgressivePolicy creations.") + + +stack_vars_to_vars_call_counter = monitoring.Counter( + "/tensorflow/training/fast_training/tf_vars_to_vars", + "Counter for the number of low-level stacking API calls.") diff --git a/official/modeling/fast_training/progressive/policies.py b/official/modeling/fast_training/progressive/policies.py index f37bfbf0df3..de4006432c8 100644 --- a/official/modeling/fast_training/progressive/policies.py +++ b/official/modeling/fast_training/progressive/policies.py @@ -25,16 +25,11 @@ import six import tensorflow as tf -from tensorflow.python.eager import monitoring +from official.common import streamz_counters from official.modeling.fast_training.progressive import utils from official.modeling.hyperparams import base_config -_progressive_policy_creation_counter = monitoring.Counter( - '/tensorflow/training/fast_training/progressive_policy_creation', - 'Counter for the number of ProgressivePolicy creations.') - - @dataclasses.dataclass class ProgressiveConfig(base_config.Config): pass @@ -76,7 +71,8 @@ def __init__(self): optimizer=self.get_optimizer(stage_id), model=self.get_model(stage_id, old_model=None)) - _progressive_policy_creation_counter.get_cell().increase_by(1) + streamz_counters.progressive_policy_creation_counter.get_cell( + ).increase_by(1) def compute_stage_id(self, global_step: int) -> int: for stage_id in range(self.num_stages()):