Skip to content

Commit

Permalink
Add list like stars to see if that improves the formatting on the web…
Browse files Browse the repository at this point in the history
…site.

PiperOrigin-RevId: 297442449
  • Loading branch information
yashk2810 authored and tensorflower-gardener committed Feb 26, 2020
1 parent 147cbf4 commit 08dd1e6
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions tensorflow_estimator/python/estimator/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ class Estimator(object):
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
See [estimators](https://tensorflow.org/guide/estimators) for more
information.
To warm-start an `Estimator`:
```python
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
```
For more details on warm-start configuration, see
`tf.estimator.WarmStartSettings`.
@compatibility(eager)
Calling methods of `Estimator` will work while eager execution is enabled.
However, the `model_fn` and `input_fn` is not executed eagerly, `Estimator`
Expand All @@ -114,42 +129,29 @@ def __init__(self,
warm_start_from=None):
"""Constructs an `Estimator` instance.
See [estimators](https://tensorflow.org/guide/estimators) for more
information.
To warm-start an `Estimator`:
```python
estimator = tf.estimator.DNNClassifier(
feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
hidden_units=[1024, 512, 256],
warm_start_from="/path/to/checkpoint/dir")
```
For more details on warm-start configuration, see
`tf.estimator.WarmStartSettings`.
Args:
model_fn: Model function. Follows the signature:
`features` -- This is the first item returned from the `input_fn`
* `features` -- This is the first item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same.
`labels` -- This is the second item returned from the `input_fn`
* `labels` -- This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models). If
mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will be
passed. If the `model_fn`'s signature does not accept `mode`, the
`model_fn` must still be able to handle `labels=None`.
`mode` -- Optional. Specifies if this is training, evaluation or
* `mode` -- Optional. Specifies if this is training, evaluation or
prediction. See `tf.estimator.ModeKeys`.
`params` -- Optional `dict` of hyperparameters. Will receive what is
passed to Estimator in `params` parameter. This allows to configure
Estimators from hyper parameter tuning.
`config` -- Optional `estimator.RunConfig` object. Will receive what
* `config` -- Optional `estimator.RunConfig` object. Will receive what
is passed to Estimator as its `config` parameter, or a default
value. Allows setting up things in your `model_fn` based on
configuration such as `num_ps_replicas`, or `model_dir`.
Returns -- `tf.estimator.EstimatorSpec`
* Returns -- `tf.estimator.EstimatorSpec`
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into an estimator to
continue training a previously saved model. If `PathLike` object, the
Expand Down Expand Up @@ -560,12 +562,12 @@ def predict(self,
https://tensorflow.org/guide/premade_estimators#create_input_functions)
for more information. The function should construct and return one of
the following:
`tf.data.Dataset` object -- Outputs of `Dataset` object must have
same constraints as below.
features -- A `tf.Tensor` or a dictionary of string feature name to
`Tensor`. features are consumed by `model_fn`. They should satisfy
the expectation of `model_fn` from inputs. * A tuple, in which case
the first item is extracted as features.
* `tf.data.Dataset` object -- Outputs of `Dataset` object must have
same constraints as below.
* features -- A `tf.Tensor` or a dictionary of string feature name to
`Tensor`. features are consumed by `model_fn`. They should satisfy
the expectation of `model_fn` from inputs. * A tuple, in which case
the first item is extracted as features.
predict_keys: list of `str`, name of the keys to predict. It is used if
the `tf.estimator.EstimatorSpec.predictions` is a `dict`. If
`predict_keys` is used then rest of the predictions will be filtered
Expand Down

0 comments on commit 08dd1e6

Please sign in to comment.