diff --git a/trax/layers/core.py b/trax/layers/core.py index d17bc82eb..fb114f9d3 100644 --- a/trax/layers/core.py +++ b/trax/layers/core.py @@ -378,12 +378,12 @@ def multigaussian_loss(preds, targets, ngauss=1): # pylint: disable=invalid-nam return fastmath.logsumexp(loglogits + glogprobs, axis=-1) -def gumbel_sample(log_probs, temperature=1.0): # pylint: disable=invalid-name - """Returns a Gumbel sample from a categorical distribution, with temperature. +def logsoftmax_sample(log_probs, temperature=1.0): # pylint: disable=invalid-name + """Returns a sample from a log-softmax output, with temperature. Args: - log_probs: - temperature: + log_probs: Logarithms of probabilities (often coming from LogSoftmax) + temperature: For scaling before sampling (1.0 = default, 0.0 = pick argmax) """ # This is equivalent to sampling from a softmax with temperature. u = np.random.uniform(low=1e-6, high=1.0 - 1e-6, size=log_probs.shape) diff --git a/trax/rl/distributions.py b/trax/rl/distributions.py index 9cb4ed5f9..0002fab4b 100644 --- a/trax/rl/distributions.py +++ b/trax/rl/distributions.py @@ -93,12 +93,12 @@ def _unflatten_inputs(self, inputs): ) def sample(self, inputs, temperature=1.0): - # No need for LogSoftmax with Gumbel sampling - softmax normalization is - # subtracting a constant from every logit, and Gumbel sampling is taking + # No need for LogSoftmax with sampling - softmax normalization is + # subtracting a constant from every logit, and sampling is taking # a max over logits plus noise, so invariant to adding a constant. 
if temperature == 0.0: return jnp.argmax(self._unflatten_inputs(inputs), axis=-1) - return tl.gumbel_sample(self._unflatten_inputs(inputs), temperature) + return tl.logsoftmax_sample(self._unflatten_inputs(inputs), temperature) def log_prob(self, inputs, point): inputs = tl.LogSoftmax()(self._unflatten_inputs(inputs)) diff --git a/trax/supervised/__init__.py b/trax/supervised/__init__.py index 51b6c8995..2753623ee 100644 --- a/trax/supervised/__init__.py +++ b/trax/supervised/__init__.py @@ -15,6 +15,7 @@ """Supervised learning imports in Trax.""" +from trax.supervised import decoding from trax.supervised import lr_schedules from trax.supervised import trainer_lib from trax.supervised import training diff --git a/trax/supervised/decoding.py b/trax/supervised/decoding.py index ed4611d7d..5709ecc05 100644 --- a/trax/supervised/decoding.py +++ b/trax/supervised/decoding.py @@ -19,7 +19,51 @@ from trax import layers as tl -def autoregressive_sample(model, prefix=None, inputs=None, +def autoregressive_sample_stream(model, inputs=None, + batch_size=1, temperature=1.0, + start_id=0, accelerate=True): + """Stream autoregressive samples from the provided model. + + Note that the provided model should be an autoregressive model initialized + in 'predict' mode. In this mode, a model takes the outputs it is generating + one-by-one (instead of taking them all at once, as, e.g., during training). + Model state is used to store the intermediate information needed, and usually + the model performs inference in this mode faster than in 'eval' mode. 
+ + Args: + model: instance of trax.Layer, the model to sample from (at mode='predict') + inputs: optional tensor [batch_size, M]: inputs to provide to the model; + for language models (with n_in=1) we use inputs as prefix to the model + batch_size: how many sequences to sample in one batch (default: 1) + temperature: sampling temperature (default: 1.0) + start_id: int, id for the start symbol fed at the beginning (default: 0) + accelerate: whether to accelerate the model before decoding (default: True) + + Yields: + Tensor of ints of shape [batch_size] containing subsequent + autoregressive samples from the model. + """ + if inputs is not None and inputs.shape[0] != batch_size: + raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.') + fast_model = tl.Accelerate(model) if accelerate else model + cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) + if inputs is not None and model.n_in == 1: # use inputs as prefix + cur_symbol = np.concatenate([cur_symbol, inputs], axis=1) + while True: + model_input = cur_symbol + if inputs is not None and model.n_in > 1: + model_input = (inputs, cur_symbol) + logits = fast_model(model_input) + if inputs is not None and model.n_in > 1: + logits = logits[0] # Pick first element from model output (a pair here) + sample = tl.logsoftmax_sample(logits[:, -1, :], temperature=temperature) + yield sample + # Note: we're using 'predict' mode autoregressive models here, so history + # is cached in the model state and we are only feeding one symbol next. 
+ cur_symbol = sample[:, None] + + +def autoregressive_sample(model, inputs=None, batch_size=1, temperature=1.0, start_id=0, eos_id=1, max_length=100, accelerate=True): @@ -33,8 +77,8 @@ def autoregressive_sample(model, prefix=None, inputs=None, Args: model: instance of trax.Layer, the model to sample from (at mode='predict') - prefix: optional tensor [batch_size, L]: prefix for decoding - inputs: optional tensor [batch_size, M]: inputs to provide to the model + inputs: optional tensor [batch_size, M]: inputs to provide to the model; + for language models (with n_in=1) we use inputs as prefix to the model batch_size: how many batches to sample (default: 1) temperature: sampling temperature (default: 1.0) start_id: int, id for the start symbol fed at the beginning (default: 1) @@ -46,32 +90,22 @@ def autoregressive_sample(model, prefix=None, inputs=None, a tensor of ints of shape [batch_size, N] with N <= max_length containing the autoregressively sampled output from the model """ - if prefix is not None and prefix.shape[0] != batch_size: - raise ValueError(f'Prefix batch size {prefix.shape[0]} != {batch_size}.') - if inputs is not None and inputs.shape[0] != batch_size: - raise ValueError(f'Inputs batch size {inputs.shape[0]} != {batch_size}.') - fast_model = tl.Accelerate(model) if accelerate else model - cur_symbol = np.full((batch_size, 1), start_id, dtype=np.int32) - if prefix is not None: - cur_symbol = np.concatenate([cur_symbol, prefix], axis=1) result = [] eos_seen = [] - for _ in range(max_length): - model_input = cur_symbol if inputs is None else (inputs, cur_symbol) - logits = fast_model(model_input) - if inputs is not None: - logits = logits[0] # Pick first element from model output (a pair here) - sample = tl.gumbel_sample(logits[:, -1, :], temperature=temperature) + counter = 0 + for sample in autoregressive_sample_stream( + model, inputs, batch_size=batch_size, temperature=temperature, + start_id=start_id, accelerate=accelerate): sample = sample[:, 
None] result.append(sample) - # Note: we're using 'predict' mode autoregressive models here, so history - # is caches in the model state and we are only feeding one symbol next. - cur_symbol = sample + counter += 1 + if counter >= max_length: + return np.concatenate(result, axis=1) # Check at which batch positions have we already encountered EOS. for j in range(batch_size): if int(sample[j, 0]) == eos_id: eos_seen.append(j) # If EOS has been seen on all positions, stop. if all([j in eos_seen for j in range(batch_size)]): - break + return np.concatenate(result, axis=1) return np.concatenate(result, axis=1) diff --git a/trax/supervised/decoding_test.py b/trax/supervised/decoding_test.py index 4a9b3b6d1..10e876279 100644 --- a/trax/supervised/decoding_test.py +++ b/trax/supervised/decoding_test.py @@ -53,8 +53,8 @@ def test_autoregressive_sample_transformerlm(self): self.assertLess(s2.shape[1], 11) model.init(shapes.ShapeDtype((1, 1), dtype=np.int32)) prefix = np.array([[1, 2, 3]]) - s3 = decoding.autoregressive_sample(model, eos_id=-1, max_length=10, - batch_size=1, prefix=prefix) + s3 = decoding.autoregressive_sample(model, prefix, eos_id=-1, max_length=10, + batch_size=1) self.assertEqual(s3.shape[0], 1) self.assertEqual(s3.shape[1], 10) @@ -131,7 +131,7 @@ def test_autoregressive_sample_transformerlm_quality(self): pred_model.init_from_file(model_path, weights_only=True, input_signature=(shape11, shape11)) inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, prefix=inputs, + s = decoding.autoregressive_sample(pred_model, inputs, max_length=6, temperature=0.0) self.assertEqual(str(s[0]), '[3 7 5 3 2 4]') @@ -146,7 +146,7 @@ def test_autoregressive_sample_reformerlm_quality(self): pred_model.init_from_file(model_path, weights_only=True, input_signature=(shape11, shape11)) inputs = np.array([[0, 3, 7, 5, 3, 2, 4, 0]], dtype=np.int32) - s = decoding.autoregressive_sample(pred_model, prefix=inputs, + s = 
decoding.autoregressive_sample(pred_model, inputs, max_length=6, temperature=0.0) self.assertEqual(str(s[0]), '[3 7 5 3 2 4]')