From f03fe723cc90a1991503c13031154c940174f571 Mon Sep 17 00:00:00 2001 From: chandramouli Date: Wed, 21 Aug 2019 08:34:28 +0530 Subject: [PATCH] Added forward_features inside Estimator class of Adanet. --- adanet/core/estimator.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/adanet/core/estimator.py b/adanet/core/estimator.py index 9d600abc..ebf30e49 100644 --- a/adanet/core/estimator.py +++ b/adanet/core/estimator.py @@ -450,6 +450,8 @@ def __init__(self, enable_subnetwork_summaries=True, global_step_combiner_fn=tf.math.reduce_mean, max_iterations=None, + forward_feature_keys=None, + sparse_default_values=None, **kwargs): if subnetwork_generator is None: raise ValueError("subnetwork_generator can't be None.") @@ -596,6 +598,18 @@ def _latest_checkpoint_global_step(self): return tf.train.load_variable(latest_checkpoint, tf_compat.v1.GraphKeys.GLOBAL_STEP) + def forward_features(self, + estimator, + forward_feature_keys=None, + sparse_default_values=None): + if forward_feeature_keys is not None: + result = tf.contrib.estimator.forward_features(super(Estimator, self), + forward_feature_keys=forward_feature_keys, + sparse_default_values=sparse_default_values) + return result + else: + raise ValueError("No key/(s) provided to forward features.") + def train(self, input_fn, hooks=None,