Commit a2497cb

* Rename gumbel_sample to logsoftmax_sample for clarity.
* Add the forgotten decoding import in supervised/__init__.
* Allow access to decoding.autoregressive_sample in streaming mode.

PiperOrigin-RevId: 323177464
Lukasz Kaiser authored and copybara-github committed Jul 25, 2020
1 parent 88b033c commit a2497cb
Showing 5 changed files with 67 additions and 32 deletions.
8 changes: 4 additions & 4 deletions trax/layers/core.py
@@ -378,12 +378,12 @@ def multigaussian_loss(preds, targets, ngauss=1):  # pylint: disable=invalid-name
   return fastmath.logsumexp(loglogits + glogprobs, axis=-1)


-def gumbel_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
-  """Returns a Gumbel sample from a categorical distribution, with temperature.
+def logsoftmax_sample(log_probs, temperature=1.0):  # pylint: disable=invalid-name
+  """Returns a sample from a log-softmax output, with temperature.

   Args:
-    log_probs: <tbd>
-    temperature: <tbd>
+    log_probs: Logarithms of probabilities (often coming from LogSoftmax)
+    temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax)
   """
   # This is equivalent to sampling from a softmax with temperature.
   u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
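For context, a minimal self-contained sketch of what logsoftmax_sample computes. The hunk is truncated at the np.random.uniform line, so the Gumbel-noise continuation and the final argmax below are an assumption based on the visible lines:

import numpy as np

def logsoftmax_sample_sketch(log_probs, temperature=1.0):
  # Uniform noise bounded away from 0 and 1 so the double log stays finite.
  u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape)
  g = -np.log(-np.log(u))  # Standard Gumbel noise (assumed continuation).
  # argmax(log_probs + g * temperature) == argmax(log_probs / temperature + g)
  # for temperature > 0, i.e. a sample from a softmax with that temperature;
  # at temperature 0.0 this reduces to a plain argmax.
  return np.argmax(log_probs + g * temperature, axis=-1)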
6 changes: 3 additions & 3 deletions trax/rl/distributions.py
@@ -93,12 +93,12 @@ def _unflatten_inputs(self, inputs):
     )

   def sample(self, inputs, temperature=1.0):
-    # No need for LogSoftmax with Gumbel sampling - softmax normalization is
-    # subtracting a constant from every logit, and Gumbel sampling is taking
+    # No need for LogSoftmax with sampling - softmax normalization is
+    # subtracting a constant from every logit, and sampling is taking
     # a max over logits plus noise, so invariant to adding a constant.
     if temperature == 0.0:
       return jnp.argmax(self._unflatten_inputs(inputs), axis=-1)
-    return tl.gumbel_sample(self._unflatten_inputs(inputs), temperature)
+    return tl.logsoftmax_sample(self._unflatten_inputs(inputs), temperature)

   def log_prob(self, inputs, point):
     inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs))
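A quick numeric check of the invariance claim in the comment above (an illustrative sketch; the logits values are made up):

import numpy as np

logits = np.array([2.0, 0.5, -1.0])
log_probs = logits - np.log(np.exp(logits).sum())  # LogSoftmax by hand.
g = -np.log(-np.log(np.random.uniform(1e-6, 1 - 1e-6, size=logits.shape)))
# LogSoftmax only subtracts a constant (the log partition function), and
# argmax is invariant to adding a constant, so both samples always agree.
assert np.argmax(logits + g) == np.argmax(log_probs + g)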
1 change: 1 addition & 0 deletions trax/supervised/__init__.py
@@ -15,6 +15,7 @@

 """Supervised learning imports in Trax."""

+from trax.supervised import decoding
 from trax.supervised import lr_schedules
 from trax.supervised import trainer_lib
 from trax.supervised import training
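With this one-line addition, the decoding module becomes reachable through the package. A minimal illustration (the attribute access is the point; before this commit it would typically fail with AttributeError unless something else had imported the module):

from trax import supervised

sample_fn = supervised.decoding.autoregressive_sample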
76 changes: 55 additions & 21 deletions trax/supervised/decoding.py
@@ -19,7 +19,51 @@
 from trax import layers as tl


-def autoregressive_sample(model, prefix=None, inputs=None,
+def autoregressive_sample_stream(model, inputs=None,
+                                 batch_size=1, temperature=1.0,
+                                 start_id=0, accelerate=True):
+  """Streams autoregressive samples from the provided model.
+
+  Note that the provided model should be an autoregressive model initialized
+  in 'predict' mode. In this mode, a model takes the outputs it is generating
+  one-by-one (instead of taking them all at once, as, e.g., during training).
+  Model state is used to store the intermediate information needed, and usually
+  the model performs inference in this mode faster than in 'eval' mode.
+
+  Args:
+    model: instance of trax.Layer, the model to sample from (at mode='predict')
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
+    batch_size: how many batches to sample (default: 1)
+    temperature: sampling temperature (default: 1.0)
+    start_id: int, id for the start symbol fed at the beginning (default: 0)
+    accelerate: whether to accelerate the model before decoding (default: True)
+
+  Yields:
+    Tensor of ints of shape [batch_size] containing subsequent
+    autoregressive samples from the model.
+  """
+  if inputs is not None and inputs.shape[0] != batch_size:
+    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
+  fast_model = tl.Accelerate(model) if accelerate else model
+  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
+  if inputs is not None and model.n_in == 1:  # use inputs as prefix
+    cur_symbol = np.concatenate([cur_symbol, inputs], axis=1)
+  while True:
+    model_input = cur_symbol
+    if inputs is not None and model.n_in > 1:
+      model_input = (inputs, cur_symbol)
+    logits = fast_model(model_input)
+    if inputs is not None and model.n_in > 1:
+      logits = logits[0]  # Pick first element from model output (a pair here)
+    sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature)
+    yield sample
+    # Note: we're using 'predict' mode autoregressive models here, so history
+    # is cached in the model state and we are only feeding one symbol next.
+    cur_symbol = sample[:, None]
+
+
+def autoregressive_sample(model, inputs=None,
                           batch_size=1, temperature=1.0,
                           start_id=0, eos_id=1, max_length=100,
                           accelerate=True):
@@ -33,8 +77,8 @@ def autoregressive_sample(model, prefix=None, inputs=None,
   Args:
     model: instance of trax.Layer, the model to sample from (at mode='predict')
-    prefix: optional tensor [batch_size, L]: prefix for decoding
-    inputs: optional tensor [batch_size, M]: inputs to provide to the model
+    inputs: optional tensor [batch_size, M]: inputs to provide to the model;
+      for language models (with n_in=1) we use inputs as prefix to the model
     batch_size: how many batches to sample (default: 1)
     temperature: sampling temperature (default: 1.0)
     start_id: int, id for the start symbol fed at the beginning (default: 0)
@@ -46,32 +90,22 @@
     a tensor of ints of shape [batch_size, N] with N <= max_length containing
     the autoregressively sampled output from the model
   """
-  if prefix is not None and prefix.shape[0] != batch_size:
-    raise ValueError(f'Prefix batch size {prefix.shape[0]} != {batch_size}.')
-  if inputs is not None and inputs.shape[0] != batch_size:
-    raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.')
-  fast_model = tl.Accelerate(model) if accelerate else model
-  cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32)
-  if prefix is not None:
-    cur_symbol = np.concatenate([cur_symbol, prefix], axis=1)
   result = []
   eos_seen = []
-  for _ in range(max_length):
-    model_input = cur_symbol if inputs is None else (inputs, cur_symbol)
-    logits = fast_model(model_input)
-    if inputs is not None:
-      logits = logits[0]  # Pick first element from model output (a pair here)
-    sample = tl.gumbel_sample(logits[:, -1, :], temperature=temperature)
+  counter = 0
+  for sample in autoregressive_sample_stream(
+      model, inputs, batch_size=batch_size, temperature=temperature,
+      start_id=start_id, accelerate=accelerate):
+    sample = sample[:, None]
     result.append(sample)
-    # Note: we're using 'predict' mode autoregressive models here, so history
-    # is caches in the model state and we are only feeding one symbol next.
-    cur_symbol = sample
+    counter += 1
+    if counter >= max_length:
+      return np.concatenate(result, axis=1)
     # Check at which batch positions have we already encountered EOS.
     for j in range(batch_size):
       if int(sample[j, 0]) == eos_id:
         eos_seen.append(j)
     # If EOS has been seen on all positions, stop.
     if all([j in eos_seen for j in range(batch_size)]):
       break
   return np.concatenate(result, axis=1)
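A hedged usage sketch of the new streaming entry point. The model configuration here is an assumption (any small autoregressive Trax model initialized at mode='predict' should do); note that, unlike autoregressive_sample, the stream has no max_length or eos_id, so the caller decides when to stop:

import numpy as np
from trax import models, shapes
from trax.supervised import decoding

# Assumed toy configuration; only mode='predict' is essential here.
model = models.TransformerLM(vocab_size=16, d_model=32, d_ff=64,
                             n_layers=1, n_heads=2, mode='predict')
model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))

# Consume samples one step at a time.
generated = []
for step, sample in enumerate(decoding.autoregressive_sample_stream(model)):
  generated.append(int(sample[0]))
  if generated[-1] == 1 or step >= 9:  # Stop on EOS (id 1) or after 10 steps.
    break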
8 changes: 4 additions & 4 deletions trax/supervised/decoding_test.py
@@ -53,8 +53,8 @@ def test_autoregressive_sample_transformerlm(self):
     self.assertLess(s2.shape[1], 11)
     model.init(shapes.ShapeDtype((1, 1), dtype=np.int32))
     prefix = np.array([[1, 2, 3]])
-    s3 = decoding.autoregressive_sample(model, eos_id=-1, max_length=10,
-                                        batch_size=1, prefix=prefix)
+    s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, max_length=10,
+                                        batch_size=1)
     self.assertEqual(s3.shape[0], 1)
     self.assertEqual(s3.shape[1], 10)

@@ -131,7 +131,7 @@ def test_autoregressive_sample_transformerlm_quality(self):
     pred_model.init_from_file(model_path, weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
-    s = decoding.autoregressive_sample(pred_model, prefix=inputs,
+    s = decoding.autoregressive_sample(pred_model, inputs,
                                        max_length=6, temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')

@@ -146,7 +146,7 @@ def test_autoregressive_sample_reformerlm_quality(self):
     pred_model.init_from_file(model_path, weights_only=True,
                               input_signature=(shape11, shape11))
     inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32)
-    s = decoding.autoregressive_sample(pred_model, prefix=inputs,
+    s = decoding.autoregressive_sample(pred_model, inputs,
                                        max_length=6, temperature=0.0)
     self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')
