diff --git a/.github/stale.yml b/.github/stale.yml new file mode 100644 index 0000000000..b12867dab0 --- /dev/null +++ b/.github/stale.yml @@ -0,0 +1,30 @@ +# Configuration for probot-stale - https://github.com/probot/stale +# Mostly copied from github.com/facebook/react/blob/master/.github/stale.yml +# Number of days of inactivity before an issue becomes stale +daysUntilStale: 90 +# Number of days of inactivity before a stale issue is closed +daysUntilClose: 7 +# Issues with these labels will never be considered stale +exemptLabels: + - bug +# Label to use when marking an issue as stale +staleLabel: stale +issues: + # Comment to post when marking an issue as stale. + markComment: > + This issue has been automatically marked as stale. + **If this issue is still affecting you, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize it yet. If you have any new additional information, please include it with your comment! + # Comment to post when closing a stale issue. + closeComment: > + Closing this issue after a prolonged period of inactivity. If this issue is still present in the latest release, please create a new issue with up-to-date information. Thank you! +pulls: + # Comment to post when marking a pull request as stale. + markComment: > + This pull request has been automatically marked as stale. + **If this pull request is still relevant, please leave any comment** (for example, "bump"), and we'll keep it open. + We are sorry that we haven't been able to prioritize reviewing it yet. Your contribution is very much appreciated. + # Comment to post when closing a stale pull request. + closeComment: > + Closing this pull request after a prolonged period of inactivity. If this issue is still present in the latest release, please ask for this pull request to be reopened. Thank you! + diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6ae8093a8a..29e5254d33 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,26 +19,36 @@ jobs: runs-on: ${{ matrix.platform }} steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} + - name: Conditionally install pytorch if: matrix.platform == 'windows-latest' run: pip3 install torch -f https://download.pytorch.org/whl/torch_stable.html + - name: Install locally run: | python -m pip install --upgrade pip + git submodule update --init --recursive python setup.py build_ext --inplace python -m pip install --editable . + + - name: Install optional test requirements + run: | + python -m pip install fairscale iopath transformers + - name: Lint with flake8 run: | pip install flake8 # stop the build if there are Python syntax errors or undefined names - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --extend-exclude fairseq/model_parallel/megatron # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide - flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics + flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --extend-exclude fairseq/model_parallel/megatron + - name: Run tests run: | python setup.py test diff --git a/.github/workflows/build_wheels.yml b/.github/workflows/build_wheels.yml new file mode 100644 index 0000000000..7261708596 --- /dev/null +++ b/.github/workflows/build_wheels.yml @@ -0,0 +1,41 @@ +name: build_wheels + +on: + push: + branches: + - v[0-9]+.[0-9]+.[x0-9]+ + tags: + - v* + +jobs: + build_wheels: + name: Build wheels on ${{ matrix.os }} + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest, macos-latest] + + steps: + - uses: actions/checkout@v2 + + - name: Install Python + uses: actions/setup-python@v2 + with: + python-version: '3.7' + + - name: Install cibuildwheel + run: | + python -m pip install cibuildwheel + + - name: Build wheels for CPython + run: | + python -m cibuildwheel --output-dir dist + env: + CIBW_BUILD: "cp36-*64 cp37-*64 cp38-*64" + CIBW_MANYLINUX_X86_64_IMAGE: manylinux1 + CIBW_BEFORE_BUILD: git submodule update --init --recursive && pip install . + + - uses: actions/upload-artifact@v2 + with: + name: wheels + path: ./dist/*.whl diff --git a/.gitmodules b/.gitmodules index df0d3d3071..07a55d45d4 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,7 +1,3 @@ -[submodule "fairseq/models/huggingface/transformers"] - path = fairseq/models/huggingface/transformers - url = https://github.com/myleott/transformers.git - branch = fairseq [submodule "fairseq/model_parallel/megatron"] path = fairseq/model_parallel/megatron url = https://github.com/ngoyal2707/Megatron-LM diff --git a/README.md b/README.md index 0648da15f7..cc1c76ec36 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,8 @@ We provide reference implementations of various sequence modeling papers: + [Understanding Back-Translation at Scale (Edunov et al., 2018)](examples/backtranslation/README.md) + [Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)](examples/language_model/README.adaptive_inputs.md) + [Lexically constrained decoding with dynamic beam allocation (Post & Vilar, 2018)](examples/constrained_decoding/README.md) + + [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context (Dai et al., 2019)](examples/truncated_bptt/README.md) + + [Adaptive Attention Span in Transformers (Sukhbaatar et al., 2019)](examples/adaptive_span/README.md) + [Mixture Models for Diverse Machine Translation: Tricks of the Trade (Shen et al., 2019)](examples/translation_moe/README.md) + [RoBERTa: A Robustly Optimized BERT Pretraining Approach (Liu et al., 2019)](examples/roberta/README.md) + [Facebook FAIR's WMT19 News Translation Task Submission (Ng et al., 2019)](examples/wmt19/README.md) @@ -59,8 +61,10 @@ We provide reference implementations of various sequence modeling papers: ### What's New: -* November 2020: Adopted [Hydra](https://github.com/facebookresearch/hydra) as a configuration framework; -[added documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* December 2020: [GottBERT model and code released](examples/gottbert/README.md) +* November 2020: Adopted the [Hydra](https://github.com/facebookresearch/hydra) configuration framework + * [see documentation explaining how to use it for new and existing projects](docs/hydra_integration.md) +* November 2020: [fairseq 0.10.0 released](https://github.com/pytorch/fairseq/releases/tag/v0.10.0) * October 2020: [Added R3F/R4F (Better Fine-Tuning) code](examples/rxf/README.md) * October 2020: [Deep Transformer with Latent Depth code released](examples/latent_depth/README.md) * October 2020: [Added CRISS models and code](examples/criss/README.md) @@ -69,13 +73,13 @@ We provide reference implementations of various sequence modeling papers: * August 2020: [Added lexically constrained decoding](examples/constrained_decoding/README.md) * August 2020: [wav2vec2 models and code released](examples/wav2vec/README.md) * July 2020: [Unsupervised Quality Estimation code released](examples/unsupervised_quality_estimation/README.md) + +
Previous updates

+ * May 2020: [Follow fairseq on Twitter](https://twitter.com/fairseq) * April 2020: [Monotonic Multihead Attention code released](examples/simultaneous_translation/README.md) * April 2020: [Quant-Noise code released](examples/quant_noise/README.md) * April 2020: [Initial model parallel support and 11B parameters unidirectional LM released](examples/megatron_11b/README.md) - -

Previous updates

- * March 2020: [Byte-level BPE code released](examples/byte_level_bpe/README.md) * February 2020: [mBART model and code released](examples/mbart/README.md) * February 2020: [Added tutorial for back-translation](https://github.com/pytorch/fairseq/tree/master/examples/backtranslation#training-your-own-model-wmt18-english-german) @@ -99,10 +103,10 @@ We provide reference implementations of various sequence modeling papers: + beam search + Diverse Beam Search ([Vijayakumar et al., 2016](https://arxiv.org/abs/1610.02424)) + sampling (unconstrained, top-k and top-p/nucleus) - + lexically constrained decoding ([Post & Vilar, 2018](examples/constrained_decoding/README.md)) -* large mini-batch training even on a single GPU via delayed updates -* mixed precision training (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) -* extensible: easily register new models, criterions, tasks, optimizers and learning rate schedulers + + [lexically constrained decoding](examples/constrained_decoding/README.md) (Post & Vilar, 2018) +* [gradient accumulation](https://fairseq.readthedocs.io/en/latest/getting_started.html#large-mini-batch-training-with-delayed-updates) enables training with large mini-batches even on a single GPU +* [mixed precision training](https://fairseq.readthedocs.io/en/latest/getting_started.html#training-with-half-precision-floating-point-fp16) (trains faster with less GPU memory on [NVIDIA tensor cores](https://developer.nvidia.com/tensor-cores)) +* [extensible](https://fairseq.readthedocs.io/en/latest/overview.html): easily register new models, criterions, tasks, optimizers and learning rate schedulers * [flexible configuration](docs/hydra_integration.md) based on [Hydra](https://github.com/facebookresearch/hydra) allowing a combination of code, command-line and file based configuration We also provide [pre-trained models for translation and language modeling](#pre-trained-models-and-examples) @@ -131,6 +135,9 @@ pip install --editable ./ # on MacOS: # CFLAGS="-stdlib=libc++" pip install --editable ./ + +# to install the latest stable release (0.10.0) +# pip install fairseq==0.10.0 ``` * **For faster training** install NVIDIA's [apex](https://github.com/NVIDIA/apex) library: diff --git a/docs/getting_started.rst b/docs/getting_started.rst index d227b95544..745ad7763c 100644 --- a/docs/getting_started.rst +++ b/docs/getting_started.rst @@ -90,7 +90,7 @@ well for the IWSLT 2014 dataset: > mkdir -p checkpoints/fconv > CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt14.tokenized.de-en \ - --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ + --optimizer nag --lr 0.25 --clip-norm 0.1 --dropout 0.2 --max-tokens 4000 \ --arch fconv_iwslt_de_en --save-dir checkpoints/fconv By default, :ref:`fairseq-train` will use all available GPUs on your machine. Use the @@ -182,9 +182,10 @@ sure to update ``--master_addr`` to the IP address of the first node: --arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ - --lr 0.0005 --min-lr 1e-09 \ + --lr 0.0005 \ --dropout 0.3 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --max-tokens 3584 \ + --max-epoch 70 \ --fp16 On SLURM clusters, fairseq will automatically detect the number of nodes and diff --git a/docs/hydra_integration.md b/docs/hydra_integration.md index f924de961b..04c797fe50 100644 --- a/docs/hydra_integration.md +++ b/docs/hydra_integration.md @@ -1,57 +1,70 @@ ## Hydra -[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python framework that simplifies the development of -research and other complex applications. The key feature is the ability to dynamically create a hierarchical -configuration by composition and override it through config files and the command line. The name Hydra comes from its -ability to run multiple similar jobs - much like a Hydra with multiple heads. +[Hydra](https://github.com/facebookresearch/hydra) is an open-source Python +framework that simplifies the development of research and other complex +applications. The key feature is the ability to dynamically create a +hierarchical configuration by composition and override it through config files +and the command line. The name Hydra comes from its ability to run multiple +similar jobs - much like a Hydra with multiple heads. ## Motivation -Until recently, all components in fairseq were configured through a shared "args" namespace that was created at -application startup. Components declared their own "add_args" method to update the argparse parser, hoping that -the names would not clash with arguments from other components. While this model works for smaller applications, -as fairseq grew and became integrated into other applications, this became problematic. -In order to determine how to configure each component, one needed to a) examine what args were added by this component, and -b) read the code to figure out what shared arguments it is using that were added in other places. Reproducing -models involved sharing commands that often contained dozens of command line switches. - -The model described above is still supported by fairseq for backward compatibility, but will be deprecated some time -in the future. - -New components in fairseq should now create a dataclass that encapsulates all parameters required to configure this -component. The dataclass is registered along with the component, and fairseq takes care of constructing and -providing this configuration object to the component's constructor. Note that sharing parameters can optionally -still work, but one has to explicitly point to the "source of truth" (see inheritance example below). -These changes make components in fairseq -more independent and re-usable by other applications: all that is needed to create a component is to initialize its -dataclass and overwrite some of the defaults. - -While configuring fairseq through command line (using either the legacy argparse based or the new Hydra based entry points) is still -fully supported, you can now take advantage of configuring fairseq completely or piece-by-piece through -hierarchical YAML configuration files. These files can also be shipped as examples that others can use to run -an identically configured job. - -Additionally, Hydra has a rich and growing -[library of plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that provide functionality such as -hyperparameter sweeping (including using bayesian optimization through the [Ax](https://github.com/facebook/Ax) library), -job launching across various platforms, and more. +Until recently, all components in fairseq were configured through a shared +`args` namespace that was created at application startup. Components declared +their own `add_args` method to update the argparse parser, hoping that the names +would not clash with arguments from other components. While this model works for +smaller applications, as fairseq grew and became integrated into other +applications, this became problematic. In order to determine how to configure +each component, one needed to a) examine what args were added by this component, +and b) read the code to figure out what shared arguments it is using that were +added in other places. Reproducing models involved sharing commands that often +contained dozens of command line switches. + +The model described above is still supported by fairseq for backward +compatibility, but will be deprecated some time in the future. + +New components in fairseq should now create a dataclass that encapsulates all +parameters required to configure this component. The dataclass is registered +along with the component, and fairseq takes care of constructing and providing +this configuration object to the component's constructor. Note that sharing +parameters can optionally still work, but one has to explicitly point to the +"source of truth" (see inheritance example below). These changes make components +in fairseq more independent and re-usable by other applications: all that is +needed to create a component is to initialize its dataclass and overwrite some +of the defaults. + +While configuring fairseq through command line (using either the legacy argparse +based or the new Hydra based entry points) is still fully supported, you can now +take advantage of configuring fairseq completely or piece-by-piece through +hierarchical YAML configuration files. These files can also be shipped as +examples that others can use to run an identically configured job. + +Additionally, Hydra has a rich and growing [library of +plugins](https://github.com/facebookresearch/hydra/tree/master/plugins) that +provide functionality such as hyperparameter sweeping (including using bayesian +optimization through the [Ax](https://github.com/facebook/Ax) library), job +launching across various platforms, and more. ## Creating or migrating components -In general, each new (or updated) component should provide a companion [dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are typically located in the same -file as the component and are passed as arguments to the register_*() functions. Top-level configs that should be -present in every fairseq application are placed in the [global](fairseq/dataclass/configs.py) config file and added -to the FairseqConfig object. - -Each dataclass is a plain-old-data object, similar to a NamedTuple. These classes are decorated with a @dataclass -decorator, and typically inherit from `FairseqDataclass` (which adds some functionality for backward compatibility). -Each field must have a type, and generally has metadata (such as a help string) and a default value. Only primitive types or other config objects are allowed as +In general, each new (or updated) component should provide a companion +[dataclass](https://www.python.org/dev/peps/pep-0557/). These dataclass are +typically located in the same file as the component and are passed as arguments +to the `register_*()` functions. Top-level configs that should be present in +every fairseq application are placed in the +[global](fairseq/dataclass/configs.py) config file and added to the +`FairseqConfig` object. + +Each dataclass is a plain-old-data object, similar to a `NamedTuple`. These +classes are decorated with a `@dataclass` decorator, and typically inherit from +`FairseqDataclass` (which adds some functionality for backward compatibility). +Each field must have a type, and generally has metadata (such as a help string) +and a default value. Only primitive types or other config objects are allowed as data types for each field. - Example: - +#### Example: -``` python +```python from dataclasses import dataclass, field from fairseq.dataclass import FairseqDataclass @@ -71,11 +84,12 @@ class InteractiveConfig(FairseqDataclass): ### Inherting values -Some components require sharing a value. For example, a learning rate scheduler and an optimizer may both need to -know the initial learning rate value. One can declare a field that, by default, will -inherit its value from another config node in the same hierarchy: +Some components require sharing a value. For example, a learning rate scheduler +and an optimizer may both need to know the initial learning rate value. One can +declare a field that, by default, will inherit its value from another config +node in the same hierarchy: -``` python +```python @dataclass FairseqAdamConfig(FairseqDataclass): ... @@ -83,18 +97,21 @@ FairseqAdamConfig(FairseqDataclass): ... ``` -`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"` , which is the value one can use in a YAML config file or through -command line to achieve the same effect. Note that this assumes that there is an "optimization" config object -in the root config and it has a field called "lr". +`II("optimization.lr")` is syntactic sugar for `"${optimization.lr}"`, which is +the value one can use in a YAML config file or through command line to achieve +the same effect. Note that this assumes that there is an "optimization" config +object in the root config and it has a field called "lr". ### Tasks and Models -Creating Tasks and Models works same as before, except that legacy implementations now inherit from Legacy* base classes, -while new components inherit from FairseqTask and FairseqModel and provide a dataclass to the register_*() functions. +Creating Tasks and Models works same as before, except that legacy +implementations now inherit from `LegacyFairseq*` base classes, while new +components inherit from `FairseqTask` and `FairseqModel` and provide a dataclass +to the `register_*()` functions. -Task example: +#### Task example: -``` python +```python @dataclass class LanguageModelingConfig(FairseqDataclass): data: Optional[str] = field( @@ -110,9 +127,9 @@ class LanguageModelingTask(LegacyFairseqTask): ... ``` -Model example: +#### Model example: -``` python +```python @dataclass class TransformerLanguageModelConfig(FairseqDataclass): activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( @@ -131,9 +148,10 @@ class TransformerLanguageModel(FairseqLanguageModel): ### Other components -Other components work as before, but they now take their configuration dataclass as the only constructor argument: +Other components work as before, but they now take their configuration dataclass +as the only constructor argument: -``` python +```python @dataclass class MosesTokenizerConfig(FairseqDataclass): source_lang: str = field(default="en", metadata={"help": "source language"}) @@ -145,50 +163,61 @@ class MosesTokenizer(object): ... ``` -Note that if you are adding a new registry for a new set of components, you need to add it to the FairseqConfig object in -fairseq/dataclass/configs.py: +Note that if you are adding a new registry for a new set of components, you need +to add it to the `FairseqConfig` object in `fairseq/dataclass/configs.py`: -``` python +```python @dataclass class FairseqConfig(object): ... my_new_registry: Any = None ``` -## Training with hydra_train.py +## Training with `fairseq-hydra-train` -To fully take advantage of configuration flexibility offered by Hydra, you may want to train new models using the -hydra_train.py entry point located in the fairseq_cli directory. Legacy CLI tools such as train.py, -will remain supported for the foreseeable future but will be deprecated eventually. +To fully take advantage of configuration flexibility offered by Hydra, you may +want to train new models using the `fairseq-hydra-train` entry point. Legacy CLI +tools such as `fairseq-train` will remain supported for the foreseeable future +but will be deprecated eventually. -On startup, Hydra will create a configuration object that contains a hierarchy of all the necessary dataclasses -populated with their default values in the code. The default values are overwritten by values found in YAML files in -fairseq/config directory (which currently just set default task, optimizer, etc) and then further overwritten by values -provided through command line arguments. +On startup, Hydra will create a configuration object that contains a hierarchy +of all the necessary dataclasses populated with their default values in the +code. The default values are overwritten by values found in YAML files in +`fairseq/config` directory (which currently sets minimal defaults) and then +further overwritten by values provided through command line arguments. Some of the most common use cases are shown below: -### 1. Overwrite default values through command line: +### 1. Override default values through command line: ```shell script -python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 task.data=data-bin \ -model=transformer_lm/transformer_lm_gpt task=language_modeling optimization.max_update=5000 - +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=data-bin \ + model=transformer_lm/transformer_lm_gpt \ + task=language_modeling \ + optimization.max_update=5000 ``` -Note that along with explicitly providing values for parameters such as dataset.batch_size, this also tells Hydra to overlay configuration found in `fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` -over the default values in the dataclass. If you want to train a model without specifying a particular architecture -you can simply specify model=transformer_lm. This only works for migrated tasks and models. +Note that along with explicitly providing values for parameters such as +`dataset.batch_size`, this also tells Hydra to overlay configuration found in +`fairseq/config/model/transformer_lm/transformer_lm_gpt.yaml` over the default +values in the dataclass. If you want to train a model without specifying a +particular architecture you can simply specify `model=transformer_lm`. This only +works for migrated tasks and models. ### 2. Replace bundled configs with an external config: ```shell script -python fairseq_cli/hydra_train.py --config-path /path/to/external/configs --config-name wiki103 +$ fairseq-hydra-train \ + --config-dir /path/to/external/configs \ + --config-name wiki103 ``` -where /path/to/external/configs/wiki103.yaml contains: +where `/path/to/external/configs/wiki103.yaml` contains: -``` yaml +```yaml # @package _group_ model: @@ -211,24 +240,38 @@ lr_scheduler: _name: cosine ``` -Note that here bundled configs from `fairseq/config` directory are not used, however the defaults from each dataclass will still be used (unless overwritten by your external config). +Note that here bundled configs from `fairseq/config` directory are not used, +however the defaults from each dataclass will still be used (unless overwritten +by your external config). -Additionally you can choose to break up your configs by creating a directory structure in the same location as your main config file, with the names of the top-level fields -(such as "model", "dataset", etc), and placing config files with meaningful names that would populate that specific section of your -top-level config file (for example, you might have model/small_transformer_lm.yaml, model/big_transformer_lm.yaml, etc). You can then specify the correct configuration via command line, defaults in the main config, or even launch all of them as a sweep (see Hydra documentation on how to do this). +Additionally you can choose to break up your configs by creating a directory +structure in the same location as your main config file, with the names of the +top-level fields (such as "model", "dataset", etc), and placing config files +with meaningful names that would populate that specific section of your +top-level config file (for example, you might have +`model/small_transformer_lm.yaml`, `model/big_transformer_lm.yaml`, etc). You +can then specify the correct configuration via command line, defaults in the +main config, or even launch all of them as a sweep (see Hydra documentation on +how to do this). ### 3. Add an external config directory to Hydra search path: -This allows combining default configuration (including using any bundled config files), while specifying your own config files for some parts of the configuration. +This allows combining default configuration (including using any bundled config +files), while specifying your own config files for some parts of the +configuration. ```shell script -python fairseq_cli/hydra_train.py distributed_training.distributed_world_size=1 dataset.batch_size=2 \ -task.data=/path/to/data/ model=transformer_lm/2_layers task=language_modeling optimization.max_update=5000 \ ---config-dir /path/to/external/configs - +$ fairseq-hydra-train \ + distributed_training.distributed_world_size=1 \ + dataset.batch_size=2 \ + task.data=/path/to/data/ \ + model=transformer_lm/2_layers \ + task=language_modeling \ + optimization.max_update=5000 \ + --config-dir /path/to/external/configs ``` -where /path/to/external/configs has the following structure: +where `/path/to/external/configs` has the following structure: ``` . +-- model @@ -236,5 +279,6 @@ where /path/to/external/configs has the following structure: | | +-- 2_layers.yaml ``` -and 2_layers.yaml contains a copy of transformer_lm_gpt.yaml but with decoder_layers set to 2. You can add -other configs to configure other components as well. +and `2_layers.yaml` contains a copy of `transformer_lm_gpt.yaml` but with +`decoder_layers` set to 2. You can add other configs to configure other +components as well. diff --git a/examples/__init__.py b/examples/__init__.py index 80d95f5fe7..44bb24ae61 100644 --- a/examples/__init__.py +++ b/examples/__init__.py @@ -3,4 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from fairseq.version import __version__ # noqa +try: + from fairseq.version import __version__ # noqa +except ImportError: + pass diff --git a/examples/adaptive_span/README.md b/examples/adaptive_span/README.md new file mode 100644 index 0000000000..913a873386 --- /dev/null +++ b/examples/adaptive_span/README.md @@ -0,0 +1,90 @@ +# Adaptive Span + +Adaptive Span is a novel self-attention mechanism that can learn its optimal +attention span. This allows us to extend significantly the maximum context size +used in Transformer, while maintaining control over their memory footprint +and computational time. It uses the Truncated BPTT technique for training, +as in [transformerXL](https://github.com/pytorch/fairseq/blob/master/examples/truncated_bptt/README.md). + +Adaptive Span was introduced by paper: +[Adaptive Attention Span in Transformers](https://arxiv.org/abs/1905.07799), +which achieved state-of-the-art language modeling results at the time of publication. + +We manage to reproduce their result in fairseq and keep most of the +[original implementation](https://github.com/facebookresearch/adaptive-span) untouched. +You can refer to the their sweep file as well if any combination of hyperparameter is not clear. + +##### 0. Setup + +First you need to process the Enwik8 dataset, we use the pre-tokenized dataset +from [adaptive span paper](https://github.com/facebookresearch/adaptive-span/blob/master/get_data.sh). +You can download the dataset, and then run: +```bash +fairseq-preprocess --only-source --trainpref ~/data/enwik8/train.txt \ + --validpref ~/data/enwik8/valid.txt --testpref ~/data/enwik8/test.txt \ + --destdir ~/data/enwik8/data-bin/ --joined-dictionary --workers 20 +``` + +##### 1. Train a Adaptive Span model on Enwik8 + +We will train a 12-layer Adaptive Span model following the [hyperparameters +used in the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). + +The following command assumes 4 GPUs, so that the total batch size is 64 +sequences (4 x 16). Training should take 2-3 days on 4 V100 GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/adaptive_span \ + --data ~/data/enwik8/data-bin/ \ + --fp16 --fp16-no-flatten-grads --max-update 600000 \ + --task truncated_bptt_lm --tokens-per-sample 512 --arch adaptive_span \ + --n-layer 12 --d-model 512 --n-head 8 --d-inner 2048 --dropout 0.3 \ + --attn-span 8192 --optimizer adagrad_with_grad_clip --adagrad-clip 0.03 \ + --validate-interval-updates 1000 \ + --lr-scheduler fixed --warmup-updates 32000 --batch-size-valid 32 \ + --lr 0.07 --criterion adaptive_span_loss --batch-size 16 --update-freq 1 \ + --seed 2 --log-format json --log-interval 25 --aux-loss-scaler 5e-07 +``` +This should land around 1.05 on validation, 1.03 on test. You can lower the +--aux-loss-scaler for better performance (longer span). It gives ~0.03 bpc +improvement to the transformerXL baseline here. +If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients +and simulate training on 4 GPUs. +You can also reproduce the transformerXL result on enwik8 using this code base. +It should land around 1.06 on test,matching the [original paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_enwik8_base.sh). +You can try by +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/truncated_bptt \ + ~/data/enwik8/data-bin/ \ + --task truncated_bptt_lm --fp16 --max-update 400000 \ + --tokens-per-sample 512 --arch transformer_xl --n-layer 12 \ + --d-model 512 --n-head 8 --d-head 64 --d-inner 2048 --dropout 0.1 \ + --dropatt 0.0 --mem-len 512 --optimizer adam --clip-norm 0.25 \ + --lr-scheduler cosine --warmup-updates 0 \ + --lr 0.0 --lr 0.00025 --batch-size 15 \ + --update-freq 1 --seed 2 --log-format json --log-interval 25 \ + --fp16 +``` + +##### 2. Evaluate +For Adaptive Span: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/adaptive_span \ + --task truncated_bptt_lm --batch-size 8 --tokens-per-sample 512 --gen-subset test +``` +For Transformer-XL evaluation: +```bash +fairseq-eval-lm ~/data/enwik8/data-bin/ --path model/checkpoint_best.pt \ + --user-dir examples/truncated_bptt/ --task truncated_bptt_lm --batch-size 8 \ + --tokens-per-sample 80 \ + --model-overrides '{"mem_len":2100,"clamp_len":820,"same_length":True}' \ + --gen-subset valid +``` + +*Note:* During training the model saw 512 tokens of context +(``--tokens-per-sample=512``), with batch size 8. These settings match the evaluation +settings from [the original +paper](https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8.sh). diff --git a/examples/adaptive_span/__init__.py b/examples/adaptive_span/__init__.py new file mode 100644 index 0000000000..e0a142a769 --- /dev/null +++ b/examples/adaptive_span/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +# automatically import any Python files in the current directory +cur_dir = os.path.dirname(__file__) +for file in os.listdir(cur_dir): + path = os.path.join(cur_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + mod_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module(__name__ + "." + mod_name) diff --git a/examples/adaptive_span/adagrad_with_grad_clip.py b/examples/adaptive_span/adagrad_with_grad_clip.py new file mode 100644 index 0000000000..585ce184ab --- /dev/null +++ b/examples/adaptive_span/adagrad_with_grad_clip.py @@ -0,0 +1,128 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from torch.optim import Adagrad + +from fairseq.optim import LegacyFairseqOptimizer, register_optimizer + + +@register_optimizer("adagrad_with_grad_clip") +class FairseqAdagradWithGradClip(LegacyFairseqOptimizer): + def __init__(self, args, params): + super().__init__(args) + self._optimizer = AdagradWithGradClip(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--adagrad-clip', default=0.0, type=float, metavar='D', + help='internal grad clip') + # fmt: on + + @property + def optimizer_config(self): + """ + Return a kwarg dictionary that will be used to override optimizer + args stored in checkpoints. This allows us to load a checkpoint and + resume training using a different set of optimizer args, e.g., with a + different learning rate. + """ + return { + "lr": self.args.lr[0], + "weight_decay": self.args.weight_decay, + "grad_clip": self.args.adagrad_clip, + } + + @property + def supports_flat_params(self): + return False + + +def _clip_grad(clr, grad, group_grad_clip): + if group_grad_clip > 0: + norm = grad.norm(2).item() + if norm > group_grad_clip: + clr *= group_grad_clip / (norm + 1e-10) + return clr + + +class AdagradWithGradClip(Adagrad): + """Adagrad algorithm with custom gradient clipping""" + + def __init__( + self, + params, + lr=1e-2, + lr_decay=0, + weight_decay=0, + initial_accumulator_value=0, + grad_clip=0, + ): + Adagrad.__init__( + self, + params, + lr=lr, + lr_decay=lr_decay, + weight_decay=weight_decay, + initial_accumulator_value=initial_accumulator_value, + ) + self.defaults["grad_clip"] = grad_clip + self.param_groups[0].setdefault("grad_clip", grad_clip) + + def step(self, closure=None): + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + + grad = p.grad.data + state = self.state[p] + + state["step"] += 1 + + if group["weight_decay"] != 0: + if p.grad.data.is_sparse: + raise RuntimeError( + "weight_decay option is " + "not compatible with sparse " + "gradients" + ) + grad = grad.add(group["weight_decay"], p.data) + + clr = group["lr"] / (1 + (state["step"] - 1) * group["lr_decay"]) + + # clip + clr = _clip_grad(clr=clr, grad=grad, group_grad_clip=group["grad_clip"]) + + if grad.is_sparse: + # the update is non-linear so indices must be unique + grad = grad.coalesce() + grad_indices = grad._indices() + grad_values = grad._values() + size = grad.size() + + def make_sparse(values): + constructor = grad.new + if grad_indices.dim() == 0 or values.dim() == 0: + return constructor().resize_as_(grad) + return constructor(grad_indices, values, size) + + state["sum"].add_(make_sparse(grad_values.pow(2))) + std = state["sum"]._sparse_mask(grad) + std_values = std._values().sqrt_().add_(1e-10) + p.data.add_(-clr, make_sparse(grad_values / std_values)) + else: + state["sum"].addcmul_(1, grad, grad) + std = state["sum"].sqrt().add_(1e-10) + p.data.addcdiv_(-clr, grad, std) + + return loss diff --git a/examples/adaptive_span/adaptive_span_attention.py b/examples/adaptive_span/adaptive_span_attention.py new file mode 100644 index 0000000000..07f757bb8e --- /dev/null +++ b/examples/adaptive_span/adaptive_span_attention.py @@ -0,0 +1,160 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class AdaptiveMask(nn.Module): + """Soft masking function for adaptive size. + It masks out the last K values of an input. The masking value + goes from 1 to 0 gradually, so K can be learned with + back-propagation. + Args: + max_size: maximum size (i.e. input dimension) + ramp_size: size of the ramp going from 0 to 1 + init_val: initial size proportion not to be masked out + shape: learn multiple sizes independent of each other + """ + + def __init__(self, max_size, ramp_size, init_val=0, shape=(1,)): + nn.Module.__init__(self) + self._max_size = max_size + self._ramp_size = ramp_size + self.current_val = nn.Parameter(torch.zeros(*shape) + init_val) + mask_template = torch.linspace(1 - max_size, 0, steps=max_size) + self.register_buffer("mask_template", mask_template) + + def forward(self, x): + mask = self.mask_template.float() + self.current_val.float() * self._max_size + mask = mask / self._ramp_size + 1 + mask = mask.clamp(0, 1) + if x.size(-1) < self._max_size: + # the input could have been trimmed beforehand to save computation + mask = mask.narrow(-1, self._max_size - x.size(-1), x.size(-1)) + x = (x * mask).type_as(x) + return x + + def get_current_max_size(self, include_ramp=True): + current_size = math.ceil(self.current_val.max().item() * self._max_size) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def get_current_avg_size(self, include_ramp=True): + current_size = math.ceil( + self.current_val.float().mean().item() * self._max_size + ) + if include_ramp: + current_size += self._ramp_size + current_size = max(0, min(self._max_size, current_size)) + return current_size + + def clamp_param(self): + """this need to be called after each update""" + self.current_val.data.clamp_(0, 1) + + +class AdaptiveSpan(nn.Module): + """Adaptive attention span for Transformerself. + This module learns an attention span length from data for each + self-attention head. + Args: + attn_span: maximum attention span + adapt_span_loss: loss coefficient for the span length + adapt_span_ramp: length of the masking ramp + adapt_span_init: initial size ratio + adapt_span_cache: adapt cache size to reduce memory usage + """ + + def __init__( + self, + attn_span, + adapt_span_ramp, + adapt_span_init, + n_head, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + self._max_span = attn_span + self._n_head = n_head + self._adapt_span_layer = adapt_span_layer + if self._adapt_span_layer: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + ) + else: + self._mask = AdaptiveMask( + max_size=self._max_span, + ramp_size=adapt_span_ramp, + init_val=adapt_span_init, + shape=(n_head, 1, 1), + ) + + def forward(self, attn, normalize=True): + """mask attention with the right span""" + # batch and head dimensions are merged together, so separate them first + self.clamp_param() + if self._adapt_span_layer: + attn = self._mask(attn) + else: + B = attn.size(0) # batch size + M = attn.size(1) # block size + attn = attn.reshape(B // self._n_head, self._n_head, M, -1) + attn = self._mask(attn) + attn = attn.view(B, M, -1) + return attn + + def get_trim_len(self): + """how much of memory can be trimmed to reduce computation""" + L = self._max_span + trim_len = min(L - 1, L - self._mask.get_current_max_size()) + # too fine granularity might be bad for the memory management + trim_len = math.floor(trim_len / 64) * 64 + return trim_len + + def trim_memory(self, query, key, value, key_pe): + """trim out unnecessary memory beforehand to reduce computation""" + trim_len = self.get_trim_len() + cache_size = key.size(1) - query.size(1) + trim_len_cache = trim_len - (self._max_span - cache_size) + if trim_len_cache > 0: + key = key[:, trim_len_cache:, :] + value = value[:, trim_len_cache:, :] + elif trim_len_cache < 0: + # cache is too short! this happens when validation resumes + # after a lot of updates. + key = F.pad(key, [0, 0, -trim_len_cache, 0]) + value = F.pad(value, [0, 0, -trim_len_cache, 0]) + if trim_len > 0: + if key_pe is not None: + key_pe = key_pe[:, :, trim_len:] + return key, value, key_pe + + def get_cache_size(self): + """determine how long the cache should be""" + trim_len = self.get_trim_len() + # give a buffer of 64 steps since a span might increase + # in future updates + return min(self._max_span, self._max_span - trim_len + 64) + + def get_loss(self): + """a loss term for regularizing the span length""" + return self._max_span * self._mask.current_val.float().mean() + + def get_current_max_span(self): + return self._mask.get_current_max_size() + + def get_current_avg_span(self): + return self._mask.get_current_avg_size() + + def clamp_param(self): + self._mask.clamp_param() diff --git a/examples/adaptive_span/adaptive_span_loss.py b/examples/adaptive_span/adaptive_span_loss.py new file mode 100644 index 0000000000..056245807e --- /dev/null +++ b/examples/adaptive_span/adaptive_span_loss.py @@ -0,0 +1,106 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass + +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import register_criterion +from fairseq.criterions.cross_entropy import CrossEntropyCriterion +from fairseq.dataclass import FairseqDataclass +from omegaconf import II + + +@dataclass +class AdaptiveSpanCriterionConfig(FairseqDataclass): + sentence_avg: bool = II("optimization.sentence_avg") + + +@register_criterion("adaptive_span_loss", dataclass=AdaptiveSpanCriterionConfig) +class AdaptiveSpanCriterion(CrossEntropyCriterion): + def __init__(self, task, sentence_avg): + super().__init__(task, sentence_avg) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss here is summed, different from the adaptive span code + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss, aux_loss, avg_span, max_span = self.compute_loss( + model, net_output, sample, reduce=reduce + ) + sample_size = ( + sample["target"].size(0) if self.sentence_avg else sample["ntokens"] + ) + loss /= sample_size + total_loss = loss + aux_loss + sample_size = 1 + + logging_output = { + "loss": loss.data, + "ntokens": sample["ntokens"], + "nsentences": sample["target"].size(0), + "sample_size": sample_size, + "total_loss": total_loss.data, + "avg_span": avg_span * sample_size, + "max_span": max_span * sample_size, + } + return total_loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + loss, _ = super().compute_loss(model, net_output, sample, reduce) + aux_loss = model.get_aux_loss() + avg_span = model.get_current_avg_span() + max_span = model.get_current_max_span() + return loss, aux_loss, avg_span, max_span + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + total_loss_sum = sum(log.get("total_loss", 0) for log in logging_outputs) + avg_span_sum = sum(log.get("avg_span", 0) for log in logging_outputs) + max_span_sum = sum(log.get("max_span", 0) for log in logging_outputs) + + # we divide by log(2) to convert the loss from base e to base 2 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar("avg_span", avg_span_sum / sample_size, sample_size, round=3) + metrics.log_scalar("max_span", max_span_sum / sample_size, sample_size, round=3) + # total loss contains the L1 norm on adaptive-span + metrics.log_scalar( + "total_loss", + total_loss_sum / sample_size / math.log(2), + sample_size, + round=3, + ) + if sample_size != ntokens: + metrics.log_scalar( + "nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3 + ) + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg) + ) + else: + metrics.log_derived( + "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg) + ) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/examples/adaptive_span/adaptive_span_model.py b/examples/adaptive_span/adaptive_span_model.py new file mode 100644 index 0000000000..d96c95b85d --- /dev/null +++ b/examples/adaptive_span/adaptive_span_model.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fairseq.modules.layer_norm import LayerNorm + +from .adaptive_span_attention import AdaptiveSpan + +# Size notations: +# B = batch_size, H = d_model, M = block_size, L = attn_span + + +def _skew(X, pad_value): + """shift every row 1 step to right""" + # X = B x M x L + B, M, L = X.size() + X = F.pad(X, (0, M + 1), value=pad_value) # B x M x (L+M+1) + X = X.view(B, -1) # B x ML+MM+M + X = X[:, :-M] # B x ML+MM + X = X.view(B, M, M + L) # B x M x L+M + return X + + +def _unskew(X): + """reverse _skew operation""" + # X = B x M x L+M + B, M, L = X.size() + L -= M + X = X.view(B, -1) # B x ML+MM + X = F.pad(X, (0, M)) # B x ML+MM+M + X = X.view(B, M, M + L + 1) # B x M x L+M+1 + X = X[:, :, :L] # B x M x L + return X + + +class SeqAttention(nn.Module): + """Sequential self-attention layer. + Each token will attend to its previous fixed number of steps. + Note that attention doesn't include the current step itself. + """ + + def __init__(self, d_model, n_head, attn_span, dropout, adapt_span_layer, **kargs): + nn.Module.__init__(self) + self.dropout = nn.Dropout(dropout) + self.d_model = d_model # size of a single head + self.attn_span = attn_span + self.adaptive_span = AdaptiveSpan( + attn_span=attn_span, + n_head=n_head, + adapt_span_layer=adapt_span_layer, + **kargs + ) + + def forward(self, query, key, value, key_pe): + # query size = B x M x H + # key, value sizes = B x (M+L) x H + + key, value, key_pe = self.adaptive_span.trim_memory(query, key, value, key_pe) + + # compute attention from context + # B x M (dest) x (M+L) (src) + attn_cont = torch.matmul(query, key.transpose(-1, -2)) + attn_cont = _unskew(attn_cont) # B x M x L + + # compute the effect of position embedding + attn_pos = torch.matmul(query, key_pe) # B x M x L_pos + attn = attn_cont + attn_pos + + attn = attn / math.sqrt(self.d_model) # B x M X L_pos + + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + + # trim attention lengths according to the learned span + attn = self.adaptive_span(attn) + + attn = self.dropout(attn) # B x M X L_pos + + attn_cont = _skew(attn, 0) # B x M X (L+M) + out = torch.matmul(attn_cont, value) # B x M x H + return out + + def get_cache_size(self): + return self.adaptive_span.get_cache_size() + + +class MultiHeadSeqAttention(nn.Module): + def __init__(self, d_model, n_head, **kargs): + nn.Module.__init__(self) + assert d_model % n_head == 0 + self.n_head = n_head + self.head_dim = d_model // n_head + self.attn = SeqAttention(d_model=self.head_dim, n_head=n_head, **kargs) + self.proj_query = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_query.weight) + self.proj_out = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_out.weight) + self.proj_val = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_val.weight) + self.proj_key = nn.Linear(d_model, d_model, bias=False) + nn.init.xavier_normal_(self.proj_key.weight) + + def head_reshape(self, x): + K = self.n_head + D = self.head_dim + x = x.view(x.size()[:-1] + (K, D)) # B x (M+L) x K x D + x = x.transpose(1, 2).contiguous() # B x K x (M+L) x D + x = x.view(-1, x.size(-2), x.size(-1)) # B_K x (M+L) x D + return x + + def forward(self, query, key, value, key_pe): + B = query.size(0) + K = self.n_head + D = self.head_dim + M = query.size(1) + + query = self.proj_query(query) + query = self.head_reshape(query) + value = self.proj_val(value) + value = self.head_reshape(value) + key = self.proj_key(key) + key = self.head_reshape(key) + + out = self.attn(query, key, value, key_pe) # B_K x M x D + out = out.view(B, K, M, D) # B x K x M x D + out = out.transpose(1, 2).contiguous() # B x M x K x D + out = out.view(B, M, -1) # B x M x K_D + out = self.proj_out(out) + return out + + +class FeedForwardLayer(nn.Module): + def __init__(self, d_model, d_inner, dropout, **kargs): + nn.Module.__init__(self) + self.fc1 = nn.Linear(d_model, d_inner) + self.fc2 = nn.Linear(d_inner, d_model) + nn.init.xavier_uniform_(self.fc1.weight) + nn.init.xavier_uniform_(self.fc2.weight) + self.dropout = nn.Dropout(dropout) + + def forward(self, h): + h1 = F.relu(self.fc1(h)) + h1 = self.dropout(h1) + h2 = self.fc2(h1) + return h2 + + +class TransformerSeqLayer(nn.Module): + def __init__(self, d_model, **kargs): + nn.Module.__init__(self) + self.attn = MultiHeadSeqAttention(d_model=d_model, **kargs) + self.norm1 = LayerNorm(d_model) + self.ff = FeedForwardLayer(d_model=d_model, **kargs) + self.norm2 = LayerNorm(d_model) + + def forward(self, h, h_cache, key_pe): + # h = B x M x H + # h_cache = B x L x H + h_all = torch.cat([h_cache, h], dim=1) # B x (M+L) x H + attn_out = self.attn(h, h_all, h_all, key_pe) + h = self.norm1(h + attn_out) # B x M x H + if self.ff is not None: + ff_out = self.ff(h) + out = self.norm2(h + ff_out) # B x M x H + else: + out = h + return out + + def get_cache_size(self): + return self.attn.attn.get_cache_size() + + +class TransformerSeq(nn.Module): + def __init__( + self, + vocab_size, + d_model, + n_head, + n_layer, + attn_span, + emb_dropout, + aux_loss_scaler, + adapt_span_layer, + **kargs + ): + nn.Module.__init__(self) + # token embeddings + self.in_emb = nn.Embedding(vocab_size, d_model) + nn.init.normal_(self.in_emb.weight, mean=0, std=d_model ** -0.5) + self.out_emb = nn.Linear(d_model, vocab_size) + self.aux_loss_scaler = aux_loss_scaler + if emb_dropout > 0: + self.emb_dropout = nn.Dropout(emb_dropout) + else: + self.emb_dropout = None + # position embeddings + self.key_pe = nn.Parameter(torch.randn(1, d_model // n_head, attn_span)) + + self.layers = nn.ModuleList() + self.layers.extend( + TransformerSeqLayer( + d_model=d_model, + n_head=n_head, + attn_span=attn_span, + adapt_span_layer=adapt_span_layer, + **kargs + ) + for _ in range(n_layer) + ) + + def forward(self, x, h_cache, target=None): + # x size = B x M + block_size = x.size(1) + h = self.in_emb(x) # B x M x H + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + h_cache_next = [] + for l, layer in enumerate(self.layers): + cache_size = layer.attn.attn.get_cache_size() + if cache_size > block_size: + h_cache_next_l = torch.cat( + [h_cache[l][:, -cache_size + block_size :, :], h], dim=1 + ).detach() + else: + h_cache_next_l = h[:, -cache_size:, :].detach() + h_cache_next.append(h_cache_next_l) + h = layer(h, h_cache[l], self.key_pe) # B x M x H + + if self.emb_dropout is not None: + h = self.emb_dropout(h) + + out = F.log_softmax(self.out_emb(h).float(), dim=-1).type_as(h) + dummy_loss = None + + return out, h_cache_next, dummy_loss + + def get_aux_loss(self): + loss = 0.0 + for layer in self.layers: + loss += layer.attn.attn.adaptive_span.get_loss() + return self.aux_loss_scaler * loss + + def get_current_max_span(self): + max_span = 0.0 + for layer in self.layers: + max_span = max( + max_span, layer.attn.attn.adaptive_span.get_current_max_span() + ) + return max_span + + def get_current_avg_span(self): + avg_span = 0.0 + for layer in self.layers: + avg_span += layer.attn.attn.adaptive_span.get_current_avg_span() + return avg_span / len(self.layers) diff --git a/examples/adaptive_span/adaptive_span_model_wrapper.py b/examples/adaptive_span/adaptive_span_model_wrapper.py new file mode 100644 index 0000000000..5b147fe11f --- /dev/null +++ b/examples/adaptive_span/adaptive_span_model_wrapper.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, +) +from .adaptive_span_model import TransformerSeq as AdaptiveSpanTransformerModel + + +logger = logging.getLogger(__name__) + + +@dataclass +class AdaptiveSpanSmallConfig(FairseqDataclass): + # defaults come from https://github.com/facebookresearch/adaptive-span/blob/master/experiments/enwik8_small.sh + vocab_size: int = 50 + d_model: int = 256 + n_head: int = 4 + d_inner: int = 1024 + n_layer: int = 8 + attn_span: int = 1024 + dropout: float = 0.0 + emb_dropout: float = 0.0 + adapt_span_ramp: int = 32 + adapt_span_init: float = 0.0 + aux_loss_scaler: float = 0.000002 + adapt_span_layer: bool = False + + +@register_model("adaptive_span", dataclass=AdaptiveSpanSmallConfig) +class AdaptiveSpanTransformer(FairseqLanguageModel): + @classmethod + def build_model(cls, cfg: AdaptiveSpanSmallConfig, task): + return cls(AdaptiveSpanDecoder(cfg, task)) + + def get_aux_loss(self): + return self.decoder.get_aux_loss() + + def get_current_max_span(self): + return self.decoder.get_current_max_span() + + def get_current_avg_span(self): + return self.decoder.get_current_avg_span() + + +class AdaptiveSpanDecoder(FairseqIncrementalDecoder): + def __init__(self, cfg, task): + + super().__init__(task.target_dictionary) + + self.config = cfg + config = AdaptiveSpanSmallConfig( + vocab_size=len(task.target_dictionary), + d_model=cfg.d_model, + n_head=cfg.n_head, + d_inner=cfg.d_inner, + n_layer=cfg.n_layer, + attn_span=cfg.attn_span, + dropout=cfg.dropout, + emb_dropout=cfg.emb_dropout, + adapt_span_ramp=cfg.adapt_span_ramp, + adapt_span_init=cfg.adapt_span_init, + aux_loss_scaler=cfg.aux_loss_scaler, + adapt_span_layer=cfg.adapt_span_layer, + ) + logger.info(config) + self.model = AdaptiveSpanTransformerModel(**config.__dict__) + + self._mems = None + + def forward( + self, + src_tokens, + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + bsz = src_tokens.size(0) + if incremental_state is not None: # used during inference + mems = self.get_incremental_state("mems") + src_tokens = src_tokens[:, -1:] # only keep the most recent token + else: + mems = self._mems + + if mems is None: + # first time init + mems = self.init_hid_cache(bsz) + output = self.model(x=src_tokens, h_cache=mems,) + if incremental_state is not None: + self.set_incremental_state(incremental_state, "mems", output[1]) + else: + self._mems = output[1] + return (output[0],) + + def max_positions(self): + return self.config.attn_span + + def init_hid_cache(self, batch_sz): + hid = [] + for layer in self.model.layers: + param = next(self.model.parameters()) + h = torch.zeros( + batch_sz, + layer.get_cache_size(), + self.config.d_model, + dtype=param.dtype, + device=param.device, + ) + hid.append(h) + return hid + + def get_aux_loss(self): + return self.model.get_aux_loss() + + def get_current_max_span(self): + return self.model.get_current_max_span() + + def get_current_avg_span(self): + return self.model.get_current_avg_span() + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + new_order: torch.Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + raise NotImplementedError("This is required for generation/beam search") + # mems = self.get_incremental_state(incremental_state, "mems") + # if mems is not None: + # new_mems = [mems_i.index_select(1, new_order) for mems_i in mems] + # self.set_incremental_state(incremental_state, "mems", new_mems) diff --git a/examples/adaptive_span/truncated_bptt_lm_task.py b/examples/adaptive_span/truncated_bptt_lm_task.py new file mode 120000 index 0000000000..a92da3a298 --- /dev/null +++ b/examples/adaptive_span/truncated_bptt_lm_task.py @@ -0,0 +1 @@ +../truncated_bptt/truncated_bptt_lm_task.py \ No newline at end of file diff --git a/examples/criss/mining/mine.py b/examples/criss/mining/mine.py index c86f73ae87..c872da196f 100644 --- a/examples/criss/mining/mine.py +++ b/examples/criss/mining/mine.py @@ -7,7 +7,12 @@ import glob from subprocess import check_call -import faiss +try: + import faiss + + has_faiss = True +except ImportError: + has_faiss = False import numpy as np @@ -40,6 +45,8 @@ def load_batch(emb_file, dim): def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"): + if not has_faiss: + raise ImportError("Please install Faiss") sims = [] inds = [] xfrom = 0 diff --git a/examples/cross_lingual_language_model/README.md b/examples/cross_lingual_language_model/README.md index a78f86d8da..f4c76cfed5 100644 --- a/examples/cross_lingual_language_model/README.md +++ b/examples/cross_lingual_language_model/README.md @@ -61,7 +61,7 @@ fairseq-train \ --max-update 2400000 --save-interval 1 --no-epoch-checkpoints \ --arch xlm_base \ --optimizer adam --lr-scheduler reduce_lr_on_plateau \ ---lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \ +--lr-shrink 0.5 --lr 0.0001 --stop-min-lr 1e-09 \ --dropout 0.1 \ --criterion legacy_masked_lm_loss \ --max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \ diff --git a/examples/fast_noisy_channel/README.md b/examples/fast_noisy_channel/README.md new file mode 100644 index 0000000000..a04151a796 --- /dev/null +++ b/examples/fast_noisy_channel/README.md @@ -0,0 +1,345 @@ +# Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling + +## Introduction +- [Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) introduce a simple and effective noisy channel modeling approach for neural machine translation. However, the noisy channel online decoding approach introduced in this paper is too slow to be practical. +- To address this, [Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 simple approximations to make this approach very fast and practical without much loss in accuracy. +- This README provides intructions on how to run online decoding or generation with the noisy channel modeling approach, including ways to make it very fast without much loss in accuracy. + +## Noisy Channel Modeling + +[Yee et al. (2019)](https://www.aclweb.org/anthology/D19-1571.pdf) applies the Bayes Rule to predict `P(y|x)`, the probability of the target `y` given the source `x`. +```P(y|x) = P(x|y) * P(y) / P(x)``` +- `P(x|y)` predicts the source `x` given the target `y` and is referred to as the **channel model** +- `P(y)` is a **language model** over the target `y` +- `P(x)` is generally not modeled since it is constant for all `y`. + +We use Transformer models to parameterize the direct model `P(y|x)`, the channel model `P(x|y)` and the language model `P(y)`. + +During online decoding with beam search, we generate the top `K2` candidates per beam and score them with the following linear combination of the channel model, the language model as well as the direct model scores. + +```(1 / t) * log(P(y|x) + (1 / s) * ( λ1 * log(P(x|y)) + λ2 * log(P(y) ) )``` +- `t` - Target Prefix Length +- `s` - Source Length +- `λ1` - Channel Model Weight +- `λ2` - Language Model Weight + +The top `beam_size` candidates based on the above combined scores are chosen to continue the beams in beam search. In beam search with a direct model alone, the scores from the direct model `P(y|x)` are used to choose the top candidates in beam search. + +This framework provides a great way to utlize strong target language models trained on large amounts of unlabeled data. Language models can prefer targets unrelated to the source, so we also need a channel model whose role is to ensure that the target preferred by the language model also translates back to the source. + +### Training Translation Models and Language Models + +For training Transformer models in fairseq for machine translation, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/translation) + +For training Transformer models in fairseq for language modeling, refer to instructions [here](https://github.com/pytorch/fairseq/tree/master/examples/language_model) + +### Generation with Language Model for German-English translation with fairseq + +Here are instructions to generate using a direct model and a target-side language model. + +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt + +k2=10 +lenpen=0.16 +lm_wt=0.14 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --k2 ${k2} \ + --combine-method lm_only \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --gen-subset valid \ + --remove-bpe \ + --fp16 \ + --batch-size 10 +``` +### Noisy Channel Generation for German-English translation with fairseq + +Here are instructions for noisy channel generation with a direct model, channel model and language model as explained in section [Noisy Channel Modeling](#noisy-channel-modeling). + +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +ch_model=en_de.big.seed4.pt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt -O ${ch_model} + +k2=10 +lenpen=0.21 +lm_wt=0.50 +bw_wt=0.30 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --channel-model ${ch_model} \ + --k2 ${k2} \ + --combine-method noisy_channel \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --ch-wt ${bw_wt} \ + --gen-subset test \ + --remove-bpe \ + --fp16 \ + --batch-size 1 +``` +## Fast Noisy Channel Modeling + +[Bhosale et al. (2020)](http://www.statmt.org/wmt20/pdf/2020.wmt-1.68.pdf) introduces 3 approximations that speed up online noisy channel decoding - +- Smaller channel models (`Tranformer Base` with 1 encoder and decoder layer each vs. `Transformer Big`) + - This involves training a channel model that is possibly smaller and less accurate in terms of BLEU than a channel model of the same size as the direct model. + - Since the role of the channel model is mainly to assign low scores to generations from the language model if they don't translate back to the source, we may not need the most accurate channel model for this purpose. +- Smaller output vocabulary size for the channel model (~30,000 -> ~1000) + - The channel model doesn't need to score the full output vocabulary, it just needs to score the source tokens, which are completely known. + - This is specified using the arguments `--channel-scoring-type src_vocab --top-k-vocab 500` + - This means that the output vocabulary for the channel model will be the source tokens for all examples in the batch and the top-K most frequent tokens in the vocabulary + - This reduces the memory consumption needed to store channel model scores significantly +- Smaller number of candidates (`k2`) scored per beam + - This is specified by reducing the argument `--k2` + + +### Fast Noisy Channel Generation for German-English translation with fairseq + +Here are instructions for **fast** noisy channel generation with a direct model, channel model and language model as explained in section [Fast Noisy Channel Modeling](#fast-noisy-channel-modeling). The main differences are that we use a smaller channel model, reduce `--k2`, set `--channel-scoring-type src_vocab --top-k-vocab 500` and increase the `--batch-size`. + +Note: +- Download and install fairseq as per instructions [here](https://github.com/pytorch/fairseq) +- Preprocess and binarize the dataset as per instructions in section [Test Data Preprocessing](#test-data-preprocessing) + +```sh +binarized_data=data_dir/binarized +direct_model=de_en_seed4.pt +lm_model=en_lm.pt +lm_data=lm_data +small_ch_model=en_de.base_1_1.seed4.pt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt -O ${direct_model} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt -O ${lm_model} +mkdir -p ${lm_data} +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/dict.txt -O ${lm_data}/dict.txt +wget https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt -O ${small_ch_model} + +k2=3 +lenpen=0.23 +lm_wt=0.58 +bw_wt=0.26 +fairseq-generate ${binarized_data} \ + --user-dir examples/fast_noisy_channel \ + --beam 5 \ + --path ${direct_model} \ + --lm-model ${lm_model} \ + --lm-data ${lm_data} \ + --channel-model ${small_ch_model} \ + --k2 ${k2} \ + --combine-method noisy_channel \ + --task noisy_channel_translation \ + --lenpen ${lenpen} \ + --lm-wt ${lm_wt} \ + --ch-wt ${bw_wt} \ + --gen-subset test \ + --remove-bpe \ + --fp16 \ + --batch-size 50 \ + --channel-scoring-type src_vocab --top-k-vocab 500 +``` + +## Test Data Preprocessing + +For preprocessing and binarizing the test sets for Romanian-English and German-English translation, we use the following script - + +```sh +FAIRSEQ=/path/to/fairseq +cd $FAIRSEQ +SCRIPTS=$FAIRSEQ/mosesdecoder/scripts +if [ ! -d "${SCRIPTS}" ]; then + echo 'Cloning Moses github repository (for tokenization scripts)...' + git clone https://github.com/moses-smt/mosesdecoder.git +fi +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +NORMALIZE=$SCRIPTS/tokenizer/normalize-punctuation.perl + +s=de +t=en +test=wmt18 + +mkdir -p data_dir + +# Tokenization +if [ $s == "ro" ] ; then + # Note: Get normalise-romanian.py and remove-diacritics.py from + # https://github.com/rsennrich/wmt16-scripts/tree/master/preprocess + sacrebleu -t $test -l $s-$t --echo src | \ + $NORMALIZE -l $s | \ + python normalise-romanian.py | \ + python remove-diacritics.py | \ + $TOKENIZER -l $s -a -q > data_dir/$test.$s-$t.$s +else + sacrebleu -t $test -l $s-$t --echo src | perl $NORMALIZE -l $s | perl $TOKENIZER -threads 8 -a -l $s > data_dir/$test.$s-$t.$s +fi + +sacrebleu -t $test -l $s-$t --echo ref | perl $NORMALIZE -l $t | perl $TOKENIZER -threads 8 -a -l $t > data_dir/$test.$s-$t.$t + + +# Applying BPE +src_bpe_code=/path/to/source/language/bpe/code +tgt_bpe_code=/path/to/target/language/bpe/code +src_dict=/path/to/source/language/dict +tgt_dict=/path/to/target/language/dict + +FASTBPE=$FAIRSEQ/fastBPE +if [ ! -d "${FASTBPE}" ] ; then + git clone https://github.com/glample/fastBPE.git + # Follow compilation instructions at https://github.com/glample/fastBPE + g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast +fi + +${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${src_bpe_code} +${FASTBPE}/fast applybpe data_dir/bpe.$test.$s-$t.$s data_dir/$test.$s-$t.$s ${tgt_bpe_code} + +fairseq-preprocess -s $s -t $t \ + --testpref data_dir/bpe.$test.$s-$t \ + --destdir data_dir/binarized \ + --srcdict ${src_dict} \ + --tgtdict ${tgt_dict} +``` + +## Calculating BLEU + +```sh +DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl +cat ${generation_output} | grep -P "^H" | sort -V | cut -f 3- | $DETOKENIZER -l $t -q -a | sacrebleu -t $test -l $s-$t +``` + + +## Romanian-English Translation + +The direct and channel models are trained using bitext data (WMT16) combined with backtranslated data (The monolingual data used for backtranslation comes from http://data.statmt.org/rsennrich/wmt16_backtranslations/ (Sennrich et al., 2016c)) + +The backtranslated data is generated using an ensemble of 3 English-Romanian models trained on bitext training data (WMT16) with unrestricted sampling. + +### BPE Codes and Dictionary + +We learn a joint BPE vocabulary of 18K types on the bitext training data which is used for both the source and target. +||Path| +|----------|------| +| BPE Code | [joint_bpe_18k](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/bpe_18k) | +| Dictionary | [dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/dict) | + +### Direct Models +For Ro-En with backtranslation, the direct and channel models use a Transformer-Big architecture. + +| Seed | Model | +|----|----| +| 2 | [ro_en_seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed2.pt) +| 4 | [ro_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed4.pt) +| 6 | [ro_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/direct_models/seed6.pt) + +### Channel Models +For channel models, we follow the same steps as for the direct models. But backtranslated data is generated in the opposite direction using [this Romanian monolingual data](http://data.statmt.org/rsennrich/wmt16_backtranslations/). +The best lenpen, LM weight and CH weight are obtained by sweeping over the validation set (wmt16/dev) using beam 5. +| Model Size | Lenpen | LM Weight | CH Weight | Seed 2 | Seed 4 | Seed 6 | +|----|----|----|----|----|----|----| +| `big` | 0.84 | 0.64 | 0.56 | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | [big.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/big.seed2.pt) | +| `base_1_1` | 0.63 | 0.40 | 0.37 | [base_1_1.seed2.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed2.pt) | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/channel_models/base_1_1.seed6.pt) | + +### Language Model +The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization. +| | Path | +|----|----| +| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/transformer_lm.pt) | +| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/ro_en/lm_model/lm_dict) + +## German-English Translation + +### BPE Codes and Dictionaries + +| | Path| +|----------|------| +| Source BPE Code | [de_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_bpe_code_24K) | +| Target BPE Code | [en_bpe_code_24K](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_bpe_code_24K) +| Source Dictionary | [de_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/de_dict) | +| Target Dictionary | [en_dict](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/en_dict) | + +### Direct Models +We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs. +We use the Transformer-Big architecture for the direct model. + +| Seed | Model | +|:----:|----| +| 4 | [de_en_seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed4.pt) +| 5 | [de_en_seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed5.pt) +| 6 | [de_en_seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/direct_models/seed6.pt) + +### Channel Models + +We train on WMT’19 training data. Following [Ng et al., 2019](http://statmt.org/wmt19/pdf/53/WMT33.pdf), we apply language identification filtering and remove sentences longer than 250 tokens as well as sentence pairs with a source/target length ratio exceeding 1.5. This results in 26.8M sentence pairs. + +| Model Size | Seed 4 | Seed 5 | Seed 6 | +|----|----|----|----| +| `big` | [big.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed4.pt) | [big.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed5.pt) | [big.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big.seed6.pt) | +| `big_1_1` | [big_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed4.pt) | [big_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed5.pt) | [big_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/big_1_1.seed6.pt) | +| `base` | [base.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed4.pt) | [base.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed5.pt) | [base.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base.seed6.pt) | +| `base_1_1` | [base_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed4.pt) | [base_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed5.pt) | [base_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/base_1_1.seed6.pt) | +| `half` | [half.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed4.pt) | [half.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed5.pt) | [half.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half.seed6.pt) | +| `half_1_1` | [half_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed4.pt) | [half_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed5.pt) | [half_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/half_1_1.seed6.pt) | +| `quarter` | [quarter.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed4.pt) | [quarter.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed5.pt) | [quarter.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter.seed6.pt) | +| `quarter_1_1` | [quarter_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed4.pt) | [quarter_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed5.pt) | [quarter_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/quarter_1_1.seed6.pt) | +| `8th` | [8th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed4.pt) | [8th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed5.pt) | [8th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th.seed6.pt) | +| `8th_1_1` | [8th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed4.pt) | [8th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed5.pt) | [8th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/8th_1_1.seed6.pt) | +| `16th` | [16th.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed4.pt) | [16th.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed5.pt) | [16th.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th.seed6.pt) | +| `16th_1_1` | [16th_1_1.seed4.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed4.pt) | [16th_1_1.seed5.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed5.pt) | [16th_1_1.seed6.pt](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/channel_models/16th_1_1.seed6.pt) | + +### Language Model +The model is trained on de-duplicated English Newscrawl data from 2007-2018 comprising 186 million sentences or 4.5B words after normalization and tokenization. +| | Path | +|----|----| +| `--lm-model` | [transformer_en_lm](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/transformer_lm.pt) | +| `--lm-data` | [lm_data](https://dl.fbaipublicfiles.com/fast_noisy_channel/de_en/lm_model/lm_dict/) + + +## Citation + +```bibtex +@inproceedings{bhosale2020language, + title={Language Models not just for Pre-training: Fast Online Neural Noisy Channel Modeling}, + author={Shruti Bhosale and Kyra Yee and Sergey Edunov and Michael Auli}, + booktitle={Proceedings of the Fifth Conference on Machine Translation (WMT)}, + year={2020}, +} + +@inproceedings{yee2019simple, + title={Simple and Effective Noisy Channel Modeling for Neural Machine Translation}, + author={Yee, Kyra and Dauphin, Yann and Auli, Michael}, + booktitle={Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP)}, + pages={5700--5705}, + year={2019} +} +``` diff --git a/examples/fast_noisy_channel/__init__.py b/examples/fast_noisy_channel/__init__.py new file mode 100644 index 0000000000..9b248c3a24 --- /dev/null +++ b/examples/fast_noisy_channel/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import noisy_channel_translation # noqa +from . import noisy_channel_sequence_generator # noqa +from . import noisy_channel_beam_search # noqa diff --git a/examples/fast_noisy_channel/noisy_channel_beam_search.py b/examples/fast_noisy_channel/noisy_channel_beam_search.py new file mode 100644 index 0000000000..23869ebcd0 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_beam_search.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from fairseq.search import Search + + +class NoisyChannelBeamSearch(Search): + + def __init__(self, tgt_dict): + super().__init__(tgt_dict) + self.fw_scores_buf = None + self.lm_scores_buf = None + + def _init_buffers(self, t): + # super()._init_buffers(t) + if self.fw_scores_buf is None: + self.scores_buf = t.new() + self.indices_buf = torch.LongTensor().to(device=t.device) + self.beams_buf = torch.LongTensor().to(device=t.device) + self.fw_scores_buf = t.new() + self.lm_scores_buf = t.new() + + def combine_fw_bw(self, combine_method, fw_cum, bw, step): + if combine_method == "noisy_channel": + fw_norm = fw_cum.div(step + 1) + lprobs = bw + fw_norm + elif combine_method == "lm_only": + lprobs = bw + fw_cum + + return lprobs + + def step(self, step, fw_lprobs, scores, bw_lprobs, lm_lprobs, combine_method): + self._init_buffers(fw_lprobs) + bsz, beam_size, vocab_size = fw_lprobs.size() + + if step == 0: + # at the first step all hypotheses are equally likely, so use + # only the first beam + fw_lprobs = fw_lprobs[:, ::beam_size, :].contiguous() + bw_lprobs = bw_lprobs[:, ::beam_size, :].contiguous() + # nothing to add since we are at the first step + fw_lprobs_cum = fw_lprobs + + else: + # make probs contain cumulative scores for each hypothesis + raw_scores = (scores[:, :, step - 1].unsqueeze(-1)) + fw_lprobs_cum = (fw_lprobs.add(raw_scores)) + + combined_lprobs = self.combine_fw_bw(combine_method, fw_lprobs_cum, bw_lprobs, step) + + # choose the top k according to the combined noisy channel model score + torch.topk( + combined_lprobs.view(bsz, -1), + k=min( + # Take the best 2 x beam_size predictions. We'll choose the first + # beam_size of these which don't predict eos to continue with. + beam_size * 2, + combined_lprobs.view(bsz, -1).size(1) - 1, # -1 so we never select pad + ), + out=(self.scores_buf, self.indices_buf), + ) + # save corresponding fw and lm scores + self.fw_scores_buf = torch.gather(fw_lprobs_cum.view(bsz, -1), 1, self.indices_buf) + self.lm_scores_buf = torch.gather(lm_lprobs.view(bsz, -1), 1, self.indices_buf) + # Project back into relative indices and beams + self.beams_buf = self.indices_buf // vocab_size + self.indices_buf.fmod_(vocab_size) + return self.scores_buf, self.fw_scores_buf, self.lm_scores_buf, self.indices_buf, self.beams_buf diff --git a/examples/fast_noisy_channel/noisy_channel_sequence_generator.py b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py new file mode 100644 index 0000000000..ea8fae98e8 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_sequence_generator.py @@ -0,0 +1,842 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict, List, Optional + +import math +import numpy as np + +import torch +import torch.nn.functional as F +from torch import Tensor + +from .noisy_channel_beam_search import NoisyChannelBeamSearch +from fairseq.sequence_generator import EnsembleModel + + +class NoisyChannelSequenceGenerator(object): + def __init__( + self, + combine_method, + tgt_dict, + src_dict=None, + beam_size=1, + max_len_a=0, + max_len_b=200, + min_len=1, + len_penalty=1.0, + unk_penalty=0.0, + retain_dropout=False, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + normalize_scores=True, + channel_models=None, + k2=10, + ch_weight=1.0, + channel_scoring_type='log_norm', + top_k_vocab=0, + lm_models=None, + lm_dict=None, + lm_weight=1.0, + normalize_lm_scores_by_tgt_len=False, + ): + """Generates translations of a given source sentence, + using beam search with noisy channel decoding. + + Args: + combine_method (string, optional): Method to combine direct, LM and + channel model scores (default: None) + tgt_dict (~fairseq.data.Dictionary): target dictionary + src_dict (~fairseq.data.Dictionary): source dictionary + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + retain_dropout (bool, optional): use dropout when generating + (default: False) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + no_repeat_ngram_size (int, optional): Size of n-grams that we avoid + repeating in the generation (default: 0) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + channel_models (List[~fairseq.models.FairseqModel]): ensemble of models + translating from the target to the source + k2 (int, optional): Top K2 candidates to score per beam at each step (default:10) + ch_weight (int, optional): Weight associated with the channel model score + assuming that the direct model score has weight 1.0 (default: 1.0) + channel_scoring_type (str, optional): String specifying how to score + the channel model (default: 'log_norm') + top_k_vocab (int, optional): If `channel_scoring_type` is `'src_vocab'` or + `'src_vocab_batched'`, then this parameter specifies the number of + most frequent tokens to include in the channel model output vocabulary, + in addition to the source tokens in the input batch (default: 0) + lm_models (List[~fairseq.models.FairseqModel]): ensemble of models + generating text in the target language + lm_dict (~fairseq.data.Dictionary): LM Model dictionary + lm_weight (int, optional): Weight associated with the LM model score + assuming that the direct model score has weight 1.0 (default: 1.0) + normalize_lm_scores_by_tgt_len (bool, optional): Should we normalize LM scores + by the target length? By default, we normalize the combination of + LM and channel model scores by the source length + """ + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.retain_dropout = retain_dropout + self.temperature = temperature + self.match_source_len = match_source_len + self.no_repeat_ngram_size = no_repeat_ngram_size + self.channel_models = channel_models + self.src_dict = src_dict + self.tgt_dict = tgt_dict + self.combine_method = combine_method + self.k2 = k2 + self.ch_weight = ch_weight + self.channel_scoring_type = channel_scoring_type + self.top_k_vocab = top_k_vocab + self.lm_models = lm_models + self.lm_dict = lm_dict + self.lm_weight = lm_weight + self.log_softmax_fn = torch.nn.LogSoftmax(dim=1) + self.normalize_lm_scores_by_tgt_len = normalize_lm_scores_by_tgt_len + + self.share_tgt_dict = (self.lm_dict == self.tgt_dict) + self.tgt_to_lm = make_dict2dict(tgt_dict, lm_dict) + + self.ch_scoring_bsz = 3072 + + assert temperature > 0, '--temperature must be greater than 0' + + self.search = NoisyChannelBeamSearch(tgt_dict) + + @torch.no_grad() + def generate( + self, + models, + sample, + prefix_tokens=None, + bos_token=None, + **kwargs + ): + """Generate a batch of translations. + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + """ + model = EnsembleModel(models) + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(model.models_size) + ], + ) + if not self.retain_dropout: + model.eval() + + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in sample['net_input'].items() + if k != 'prev_output_tokens' + } + src_tokens = encoder_input['src_tokens'] + src_lengths_no_eos = (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + input_size = src_tokens.size() + # batch dimension goes first followed by source lengths + bsz = input_size[0] + src_len = input_size[1] + beam_size = self.beam_size + + if self.match_source_len: + max_len = src_lengths_no_eos.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + # exclude the EOS marker + model.max_decoder_positions() - 1, + ) + + # compute the encoder output for each beam + encoder_outs = model.forward_encoder(encoder_input) + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_tokens.device).long() + encoder_outs = model.reorder_encoder_out(encoder_outs, new_order) + + src_lengths = encoder_input['src_lengths'] + # initialize buffers + scores = src_tokens.new(bsz * beam_size, max_len + 1).float().fill_(0) + lm_prefix_scores = src_tokens.new(bsz * beam_size).float().fill_(0) + + scores_buf = scores.clone() + tokens = src_tokens.new(bsz * beam_size, max_len + 2).long().fill_(self.pad) + tokens_buf = tokens.clone() + tokens[:, 0] = self.eos if bos_token is None else bos_token + + # reorder source tokens so they may be used as a reference in generating P(S|T) + src_tokens = reorder_all_tokens(src_tokens, src_lengths, self.src_dict.eos_index) + + src_tokens = src_tokens.repeat(1, beam_size).view(-1, src_len) + src_lengths = src_lengths.view(bsz, -1).repeat(1, beam_size).view(bsz*beam_size, -1) + + attn, attn_buf = None, None + nonpad_idxs = None + + # The cands_to_ignore indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then the cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = src_tokens.new_zeros(bsz, beam_size).eq(-1) # forward and backward-compatible False mask + + # list of completed sentences + finalized = [[] for i in range(bsz)] + finished = [False for i in range(bsz)] + num_remaining_sent = bsz + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) + cand_offsets = torch.arange(0, cand_size).type_as(tokens) + + # helper function for allocating buffers on the fly + buffers = {} + + def buffer(name, type_of=tokens): # noqa + if name not in buffers: + buffers[name] = type_of.new() + return buffers[name] + + def is_finished(sent, step, unfin_idx): + """ + Check whether we've finished generation for a given sentence, by + comparing the worst score among finalized hypotheses to the best + possible score among unfinalized hypotheses. + """ + assert len(finalized[sent]) <= beam_size + if len(finalized[sent]) == beam_size: + return True + return False + + def finalize_hypos(step, bbsz_idx, eos_scores, combined_noisy_channel_eos_scores): + """ + Finalize the given hypotheses at this step, while keeping the total + number of finalized hypotheses per sentence <= beam_size. + + Note: the input must be in the desired finalization order, so that + hypotheses that appear earlier in the input are preferred to those + that appear later. + + Args: + step: current time step + bbsz_idx: A vector of indices in the range [0, bsz*beam_size), + indicating which hypotheses to finalize + eos_scores: A vector of the same size as bbsz_idx containing + fw scores for each hypothesis + combined_noisy_channel_eos_scores: A vector of the same size as bbsz_idx containing + combined noisy channel scores for each hypothesis + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors + tokens_clone = tokens.index_select(0, bbsz_idx) + tokens_clone = tokens_clone[:, 1:step + 2] # skip the first index, which is EOS + assert not tokens_clone.eq(self.eos).any() + tokens_clone[:, step] = self.eos + attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1:step+2] if attn is not None else None + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, :step+1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + combined_noisy_channel_eos_scores /= (step + 1) ** self.len_penalty + + cum_unfin = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + sents_seen = set() + for i, (idx, score) in enumerate(zip(bbsz_idx.tolist(), combined_noisy_channel_eos_scores.tolist())): + unfin_idx = idx // beam_size + sent = unfin_idx + cum_unfin[unfin_idx] + + sents_seen.add((sent, unfin_idx)) + + if self.match_source_len and step > src_lengths_no_eos[unfin_idx]: + score = -math.inf + + def get_hypo(): + + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i][nonpad_idxs[sent]] + _, alignment = hypo_attn.max(dim=0) + else: + hypo_attn = None + alignment = None + + return { + 'tokens': tokens_clone[i], + 'score': score, + 'attention': hypo_attn, # src_len x tgt_len + 'alignment': alignment, + 'positional_scores': pos_scores[i], + } + + if len(finalized[sent]) < beam_size: + finalized[sent].append(get_hypo()) + + newly_finished = [] + for sent, unfin_idx in sents_seen: + # check termination conditions for this sentence + if not finished[sent] and is_finished(sent, step, unfin_idx): + finished[sent] = True + newly_finished.append(unfin_idx) + return newly_finished + + def noisy_channel_rescoring(lprobs, beam_size, bsz, src_tokens, tokens, k): + """Rescore the top k hypothesis from each beam using noisy channel modeling + Returns: + new_fw_lprobs: the direct model probabilities after pruning the top k + new_ch_lm_lprobs: the combined channel and language model probabilities + new_lm_lprobs: the language model probabilities after pruning the top k + """ + with torch.no_grad(): + lprobs_size = lprobs.size() + if prefix_tokens is not None and step < prefix_tokens.size(1): + probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :] + cand_scores = torch.gather( + probs_slice, dim=1, + index=prefix_tokens[:, step].view(-1, 1).data + ).expand(-1, beam_size).contiguous().view(bsz*beam_size, 1) + cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, beam_size).data.contiguous().view(bsz*beam_size, 1) + + # need to calculate and save fw and lm probs for prefix tokens + fw_top_k = cand_scores + fw_top_k_idx = cand_indices + k = 1 + else: + # take the top k best words for every sentence in batch*beam + fw_top_k, fw_top_k_idx = torch.topk(lprobs.view(beam_size*bsz, -1), k=k) + eos_idx = torch.nonzero(fw_top_k_idx.view(bsz*beam_size*k, -1) == self.eos)[:, 0] + ch_scores = fw_top_k.new_full((beam_size*bsz*k, ), 0) + src_size = torch.sum(src_tokens[:, :] != self.src_dict.pad_index, dim=1, keepdim=True, dtype=fw_top_k.dtype) + + if self.combine_method != "lm_only": + temp_src_tokens_full = src_tokens[:, :].repeat(1, k).view(bsz*beam_size*k, -1) + not_padding = temp_src_tokens_full[:, 1:] != self.src_dict.pad_index + cur_tgt_size = step+2 + + # add eos to all candidate sentences except those that already end in eos + eos_tokens = tokens[:, 0].repeat(1, k).view(-1, 1) + eos_tokens[eos_idx] = self.tgt_dict.pad_index + + if step == 0: + channel_input = torch.cat((fw_top_k_idx.view(-1, 1), eos_tokens), 1) + else: + # move eos from beginning to end of target sentence + channel_input = torch.cat((tokens[:, 1:step + 1].repeat(1, k).view(-1, step), fw_top_k_idx.view(-1, 1), eos_tokens), 1) + + ch_input_lengths = torch.tensor(np.full(channel_input.size(0), cur_tgt_size)) + ch_input_lengths[eos_idx] = cur_tgt_size-1 + if self.channel_scoring_type == "unnormalized": + ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths) + ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) + del ch_encoder_output + ch_intermed_scores = channel_model.decoder.unnormalized_scores_given_target(ch_decoder_output, target_ids=temp_src_tokens_full[:, 1:]) + ch_intermed_scores = ch_intermed_scores.float() + ch_intermed_scores *= not_padding.float() + ch_scores = torch.sum(ch_intermed_scores, dim=1) + elif self.channel_scoring_type == "k2_separate": + for k_idx in range(k): + k_eos_tokens = eos_tokens[k_idx::k, :] + if step == 0: + k_ch_input = torch.cat((fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1) + else: + # move eos from beginning to end of target sentence + k_ch_input = torch.cat((tokens[:, 1:step + 1], fw_top_k_idx[:, k_idx:k_idx+1], k_eos_tokens), 1) + k_ch_input_lengths = ch_input_lengths[k_idx::k] + k_ch_output = channel_model(k_ch_input, k_ch_input_lengths, src_tokens) + k_ch_lprobs = channel_model.get_normalized_probs(k_ch_output, log_probs=True) + k_ch_intermed_scores = torch.gather(k_ch_lprobs[:, :-1, :], 2, src_tokens[:, 1:].unsqueeze(2)).squeeze(2) + k_ch_intermed_scores *= not_padding.float() + ch_scores[k_idx::k] = torch.sum(k_ch_intermed_scores, dim=1) + elif self.channel_scoring_type == "src_vocab": + ch_encoder_output = channel_model.encoder(channel_input, src_lengths=ch_input_lengths) + ch_decoder_output, _ = channel_model.decoder(temp_src_tokens_full, encoder_out=ch_encoder_output, features_only=True) + + del ch_encoder_output + ch_lprobs = normalized_scores_with_batch_vocab( + channel_model.decoder, + ch_decoder_output, src_tokens, k, bsz, beam_size, + self.src_dict.pad_index, top_k=self.top_k_vocab) + ch_scores = torch.sum(ch_lprobs, dim=1) + elif self.channel_scoring_type == "src_vocab_batched": + ch_bsz_size = temp_src_tokens_full.shape[0] + ch_lprobs_list = [None] * len(range(0, ch_bsz_size, self.ch_scoring_bsz)) + for i, start_idx in enumerate(range(0, ch_bsz_size, self.ch_scoring_bsz)): + end_idx = min(start_idx + self.ch_scoring_bsz, ch_bsz_size) + temp_src_tokens_full_batch = temp_src_tokens_full[start_idx:end_idx, :] + channel_input_batch = channel_input[start_idx:end_idx, :] + ch_input_lengths_batch = ch_input_lengths[start_idx:end_idx] + ch_encoder_output_batch = channel_model.encoder(channel_input_batch, src_lengths=ch_input_lengths_batch) + ch_decoder_output_batch, _ = channel_model.decoder(temp_src_tokens_full_batch, encoder_out=ch_encoder_output_batch, features_only=True) + ch_lprobs_list[i] = normalized_scores_with_batch_vocab( + channel_model.decoder, + ch_decoder_output_batch, src_tokens, k, bsz, beam_size, + self.src_dict.pad_index, top_k=self.top_k_vocab, + start_idx=start_idx, end_idx=end_idx) + ch_lprobs = torch.cat(ch_lprobs_list, dim=0) + ch_scores = torch.sum(ch_lprobs, dim=1) + else: + ch_output = channel_model(channel_input, ch_input_lengths, temp_src_tokens_full) + ch_lprobs = channel_model.get_normalized_probs(ch_output, log_probs=True) + ch_intermed_scores = torch.gather(ch_lprobs[:, :-1, :], 2, temp_src_tokens_full[:, 1:].unsqueeze(2)).squeeze().view(bsz*beam_size*k, -1) + ch_intermed_scores *= not_padding.float() + ch_scores = torch.sum(ch_intermed_scores, dim=1) + + else: + cur_tgt_size = 0 + ch_scores = ch_scores.view(bsz*beam_size, k) + expanded_lm_prefix_scores = lm_prefix_scores.unsqueeze(1).expand(-1, k).flatten() + + if self.share_tgt_dict: + lm_scores = get_lm_scores(lm, tokens[:, :step + 1].view(-1, step+1), lm_incremental_states, fw_top_k_idx.view(-1, 1), torch.tensor(np.full(tokens.size(0), step+1)), k) + else: + new_lm_input = dict2dict(tokens[:, :step + 1].view(-1, step+1), self.tgt_to_lm) + new_cands = dict2dict(fw_top_k_idx.view(-1, 1), self.tgt_to_lm) + lm_scores = get_lm_scores(lm, new_lm_input, lm_incremental_states, new_cands, torch.tensor(np.full(tokens.size(0), step+1)), k) + + lm_scores.add_(expanded_lm_prefix_scores) + ch_lm_scores = combine_ch_lm(self.combine_method, ch_scores, lm_scores, src_size, cur_tgt_size) + # initialize all as min value + new_fw_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_ch_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_lm_lprobs = ch_scores.new(lprobs_size).fill_(-1e17).view(bsz*beam_size, -1) + new_fw_lprobs[:, self.pad] = -math.inf + new_ch_lm_lprobs[:, self.pad] = -math.inf + new_lm_lprobs[:, self.pad] = -math.inf + + new_fw_lprobs.scatter_(1, fw_top_k_idx, fw_top_k) + new_ch_lm_lprobs.scatter_(1, fw_top_k_idx, ch_lm_scores) + new_lm_lprobs.scatter_(1, fw_top_k_idx, lm_scores.view(-1, k)) + return new_fw_lprobs, new_ch_lm_lprobs, new_lm_lprobs + + def combine_ch_lm(combine_type, ch_scores, lm_scores1, src_size, tgt_size): + if self.channel_scoring_type == "unnormalized": + ch_scores = self.log_softmax_fn( + ch_scores.view(-1, self.beam_size * self.k2) + ).view(ch_scores.shape) + ch_scores = ch_scores * self.ch_weight + lm_scores1 = lm_scores1 * self.lm_weight + + if combine_type == "lm_only": + # log P(T|S) + log P(T) + ch_scores = lm_scores1.view(ch_scores.size()) + elif combine_type == "noisy_channel": + # 1/t log P(T|S) + 1/s log P(S|T) + 1/t log P(T) + if self.normalize_lm_scores_by_tgt_len: + ch_scores.div_(src_size) + lm_scores_norm = lm_scores1.view(ch_scores.size()).div(tgt_size) + ch_scores.add_(lm_scores_norm) + # 1/t log P(T|S) + 1/s log P(S|T) + 1/s log P(T) + else: + ch_scores.add_(lm_scores1.view(ch_scores.size())) + ch_scores.div_(src_size) + + return ch_scores + + if self.channel_models is not None: + channel_model = self.channel_models[0] # assume only one channel_model model + else: + channel_model = None + + lm = EnsembleModel(self.lm_models) + lm_incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(lm.models_size) + ], + ) + + reorder_state = None + batch_idxs = None + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(batch_idxs) + reorder_state.view(-1, beam_size).add_(corr.unsqueeze(-1) * beam_size) + model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = model.reorder_encoder_out(encoder_outs, reorder_state) + + lm.reorder_incremental_state(lm_incremental_states, reorder_state) + + fw_lprobs, avg_attn_scores = model.forward_decoder( + tokens[:, :step + 1], encoder_outs, incremental_states, temperature=self.temperature, + ) + + fw_lprobs[:, self.pad] = -math.inf # never select pad + fw_lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + fw_lprobs, ch_lm_lprobs, lm_lprobs = noisy_channel_rescoring(fw_lprobs, beam_size, bsz, src_tokens, tokens, self.k2) + + # handle min and max length constraints + if step >= max_len: + fw_lprobs[:, :self.eos] = -math.inf + fw_lprobs[:, self.eos + 1:] = -math.inf + elif step < self.min_len: + fw_lprobs[:, self.eos] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if prefix_tokens is not None and step < prefix_tokens.size(1): + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_mask = prefix_toks.ne(self.pad) + + prefix_fw_lprobs = fw_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + fw_lprobs[prefix_mask] = -math.inf + fw_lprobs[prefix_mask] = fw_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_fw_lprobs + ) + + prefix_ch_lm_lprobs = ch_lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + ch_lm_lprobs[prefix_mask] = -math.inf + ch_lm_lprobs[prefix_mask] = ch_lm_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_ch_lm_lprobs + ) + + prefix_lm_lprobs = lm_lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + lm_lprobs[prefix_mask] = -math.inf + lm_lprobs[prefix_mask] = lm_lprobs[prefix_mask].scatter_( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lm_lprobs + ) + + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[:, 0, 1:step + 1] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + def replicate_first_beam(tensor, mask): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = replicate_first_beam(tokens, eos_mask_batch_dim) + scores = replicate_first_beam(scores, eos_mask_batch_dim) + + fw_lprobs = replicate_first_beam(fw_lprobs, eos_mask_batch_dim) + ch_lm_lprobs = replicate_first_beam(ch_lm_lprobs, eos_mask_batch_dim) + lm_lprobs = replicate_first_beam(lm_lprobs, eos_mask_batch_dim) + + if self.no_repeat_ngram_size > 0: + # for each beam and batch sentence, generate a list of previous ngrams + gen_ngrams = [{} for bbsz_idx in range(bsz * beam_size)] + for bbsz_idx in range(bsz * beam_size): + gen_tokens = tokens[bbsz_idx].tolist() + for ngram in zip(*[gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]): + gen_ngrams[bbsz_idx][tuple(ngram[:-1])] = \ + gen_ngrams[bbsz_idx].get(tuple(ngram[:-1]), []) + [ngram[-1]] + + # Record attention scores + if avg_attn_scores is not None: + if attn is None: + attn = scores.new(bsz * beam_size, src_tokens.size(1), max_len + 2) + attn_buf = attn.clone() + nonpad_idxs = src_tokens.ne(self.pad) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(fw_lprobs) + scores_buf = scores_buf.type_as(fw_lprobs) + + self.search.set_src_lengths(src_lengths_no_eos) + + if self.no_repeat_ngram_size > 0: + def calculate_banned_tokens(bbsz_idx): + # before decoding the next token, prevent decoding of ngrams that have already appeared + ngram_index = tuple(tokens[bbsz_idx, step + 2 - self.no_repeat_ngram_size:step + 1].tolist()) + return gen_ngrams[bbsz_idx].get(ngram_index, []) + + if step + 2 - self.no_repeat_ngram_size >= 0: + # no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + banned_tokens = [calculate_banned_tokens(bbsz_idx) for bbsz_idx in range(bsz * beam_size)] + else: + banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)] + + for bbsz_idx in range(bsz * beam_size): + fw_lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf + + combined_noisy_channel_scores, fw_lprobs_top_k, lm_lprobs_top_k, cand_indices, cand_beams = self.search.step( + step, + fw_lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], ch_lm_lprobs.view(bsz, -1, self.vocab_size), + lm_lprobs.view(bsz, -1, self.vocab_size), self.combine_method + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos (except for candidates to be ignored) + eos_mask = cand_indices.eq(self.eos) + eos_mask[:, :beam_size] &= ~cands_to_ignore + + # only consider eos when it's among the top beam_size indices + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = set() + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + fw_lprobs_top_k[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + combined_noisy_channel_eos_scores = torch.masked_select( + combined_noisy_channel_scores[:, :beam_size], + mask=eos_mask[:, :beam_size], + ) + + # finalize hypo using channel model score + finalized_sents = finalize_hypos( + step, eos_bbsz_idx, eos_scores, combined_noisy_channel_eos_scores) + + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + + if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = cand_indices.new_ones(bsz) + batch_mask[cand_indices.new(finalized_sents)] = 0 + batch_idxs = torch.nonzero(batch_mask).squeeze(-1) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + lm_lprobs_top_k = lm_lprobs_top_k[batch_idxs] + + fw_lprobs_top_k = fw_lprobs_top_k[batch_idxs] + cand_indices = cand_indices[batch_idxs] + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths_no_eos = src_lengths_no_eos[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + scores_buf.resize_as_(scores) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens_buf.resize_as_(tokens) + src_tokens = src_tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + src_lengths = src_lengths.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + lm_prefix_scores = lm_prefix_scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1).squeeze() + + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, attn.size(1), -1) + attn_buf.resize_as_(attn) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos or + # ignored hypos and values < cand_size indicate candidate + # active hypos. After this, the min values per row are the top + # candidate active hypos. + eos_mask[:, :beam_size] |= cands_to_ignore + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just the hypos + # with the smallest values in active_mask + active_hypos, new_cands_to_ignore = buffer('active_hypos'), buffer('new_cands_to_ignore') + torch.topk( + active_mask, k=beam_size, dim=1, largest=False, + out=(new_cands_to_ignore, active_hypos) + ) + + # update cands_to_ignore to ignore any finalized hypos + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + assert (~cands_to_ignore).any(dim=1).all() + + active_bbsz_idx = buffer('active_bbsz_idx') + torch.gather( + cand_bbsz_idx, dim=1, index=active_hypos, + out=active_bbsz_idx, + ) + active_scores = torch.gather( + fw_lprobs_top_k, dim=1, index=active_hypos, + out=scores[:, step].view(bsz, beam_size), + ) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + torch.index_select( + tokens[:, :step + 1], dim=0, index=active_bbsz_idx, + out=tokens_buf[:, :step + 1], + ) + torch.gather( + cand_indices, dim=1, index=active_hypos, + out=tokens_buf.view(bsz, beam_size, -1)[:, :, step + 1], + ) + if step > 0: + torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx, + out=scores_buf[:, :step], + ) + torch.gather( + fw_lprobs_top_k, dim=1, index=active_hypos, + out=scores_buf.view(bsz, beam_size, -1)[:, :, step], + ) + torch.gather( + lm_lprobs_top_k, dim=1, index=active_hypos, + out=lm_prefix_scores.view(bsz, beam_size) + ) + + # copy attention for active hypotheses + if attn is not None: + torch.index_select( + attn[:, :, :step + 2], dim=0, index=active_bbsz_idx, + out=attn_buf[:, :, :step + 2], + ) + + # swap buffers + tokens, tokens_buf = tokens_buf, tokens + scores, scores_buf = scores_buf, scores + if attn is not None: + attn, attn_buf = attn_buf, attn + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + finalized[sent] = sorted(finalized[sent], key=lambda r: r['score'], reverse=True) + + return finalized + + +def get_lm_scores(model, input_tokens, incremental_states, cand_tokens, input_len, k): + with torch.no_grad(): + lm_lprobs, avg_attn_scores = model.forward_decoder( + input_tokens, encoder_outs=None, incremental_states=incremental_states, + ) + + lm_lprobs_size = lm_lprobs.size(0) + probs_next_wrd = torch.gather(lm_lprobs.repeat(1, k).view(lm_lprobs_size*k, -1), 1, cand_tokens).squeeze().view(-1) + + return probs_next_wrd + + +def make_dict2dict(old_dict, new_dict): + dict2dict_map = {} + for sym in old_dict.symbols: + dict2dict_map[old_dict.index(sym)] = new_dict.index(sym) + return dict2dict_map + + +def dict2dict(tokens, dict2dict_map): + if tokens.device == torch.device('cpu'): + tokens_tmp = tokens + else: + tokens_tmp = tokens.cpu() + return tokens_tmp.map_( + tokens_tmp, + lambda _, val, dict2dict_map=dict2dict_map : dict2dict_map[float(val)] + ).to(tokens.device) + + +def reorder_tokens(tokens, lengths, eos): + # reorder source tokens so they may be used as reference for P(S|T) + return torch.cat((tokens.new([eos]), tokens[-lengths:-1], tokens[:-lengths]), 0) + + +def reorder_all_tokens(tokens, lengths, eos): + # used to reorder src tokens from [ .. ] to [ ...] + # so source tokens can be used to predict P(S|T) + return torch.stack([reorder_tokens(token, length, eos) for token, length in zip(tokens, lengths)]) + + +def normalized_scores_with_batch_vocab( + model_decoder, features, target_ids, k, bsz, beam_size, + pad_idx, top_k=0, vocab_size_meter=None, start_idx=None, + end_idx=None, **kwargs): + """ + Get normalized probabilities (or log probs) from a net's output + w.r.t. vocab consisting of target IDs in the batch + """ + if model_decoder.adaptive_softmax is None: + weight = model_decoder.output_projection.weight + vocab_ids = torch.unique( + torch.cat( + (torch.unique(target_ids), torch.arange(top_k, device=target_ids.device)) + ) + ) + id_map = dict(zip(vocab_ids.tolist(), range(len(vocab_ids)))) + mapped_target_ids = target_ids.cpu().apply_( + lambda x, id_map=id_map: id_map[x] + ).to(target_ids.device) + expanded_target_ids = mapped_target_ids[:, :].repeat(1, k).view(bsz*beam_size*k, -1) + if start_idx is not None and end_idx is not None: + expanded_target_ids = expanded_target_ids[start_idx:end_idx, :] + logits = F.linear(features, weight[vocab_ids, :]) + log_softmax = F.log_softmax(logits, dim=-1, dtype=torch.float32) + intermed_scores = torch.gather( + log_softmax[:, :-1, :], + 2, + expanded_target_ids[:, 1:].unsqueeze(2), + ).squeeze() + not_padding = expanded_target_ids[:, 1:] != pad_idx + intermed_scores *= not_padding.float() + return intermed_scores + else: + raise ValueError("adaptive softmax doesn't work with " + + "`normalized_scores_with_batch_vocab()`") diff --git a/examples/fast_noisy_channel/noisy_channel_translation.py b/examples/fast_noisy_channel/noisy_channel_translation.py new file mode 100644 index 0000000000..b74bdfd456 --- /dev/null +++ b/examples/fast_noisy_channel/noisy_channel_translation.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from fairseq.tasks.translation import TranslationTask +from fairseq.tasks.language_modeling import LanguageModelingTask +from fairseq import checkpoint_utils +import argparse +from fairseq.tasks import register_task +import torch + + +@register_task("noisy_channel_translation") +class NoisyChannelTranslation(TranslationTask): + """ + Rescore the top k candidates from each beam using noisy channel modeling + """ + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + TranslationTask.add_args(parser) + # fmt: off + parser.add_argument('--channel-model', metavar='FILE', + help='path to P(S|T) model. P(S|T) and P(T|S) must share source and target dictionaries.') + parser.add_argument('--combine-method', default='lm_only', + choices=['lm_only', 'noisy_channel'], + help="""method for combining direct and channel model scores. + lm_only: decode with P(T|S)P(T) + noisy_channel: decode with 1/t P(T|S) + 1/s(P(S|T)P(T))""") + parser.add_argument('--normalize-lm-scores-by-tgt-len', action='store_true', default=False, + help='normalize lm score by target length instead of source length') + parser.add_argument('--channel-scoring-type', default='log_norm', choices=['unnormalized', 'log_norm', 'k2_separate', 'src_vocab', 'src_vocab_batched'], + help="Normalize bw scores with log softmax or return bw scores without log softmax") + parser.add_argument('--top-k-vocab', default=0, type=int, + help='top k vocab IDs to use with `src_vocab` in channel model scoring') + parser.add_argument('--k2', default=50, type=int, + help='the top k2 candidates to rescore with the noisy channel model for each beam') + parser.add_argument('--ch-wt', default=1, type=float, + help='weight for the channel model') + parser.add_argument('--lm-model', metavar='FILE', + help='path to lm model file, to model P(T). P(T) must share the same vocab as the direct model on the target side') + parser.add_argument('--lm-data', metavar='FILE', + help='path to lm model training data for target language, used to properly load LM with correct dictionary') + parser.add_argument('--lm-wt', default=1, type=float, + help='the weight of the lm in joint decoding') + # fmt: on + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None + ): + if getattr(args, "score_reference", False): + raise NotImplementedError() + else: + from .noisy_channel_sequence_generator import NoisyChannelSequenceGenerator + use_cuda = torch.cuda.is_available() and not self.args.cpu + assert self.args.lm_model is not None, '--lm-model required for noisy channel generation!' + assert self.args.lm_data is not None, '--lm-data required for noisy channel generation to map between LM and bitext vocabs' + if self.args.channel_model is not None: + import copy + ch_args_task = copy.deepcopy(self.args) + tmp = ch_args_task.source_lang + ch_args_task.source_lang = ch_args_task.target_lang + ch_args_task.target_lang = tmp + ch_args_task._name = 'translation' + channel_task = TranslationTask.setup_task(ch_args_task) + + arg_dict = {} + arg_dict['task'] = 'language_modeling' + arg_dict['sample_break_mode'] = 'eos' + arg_dict['data'] = self.args.lm_data + arg_dict['output_dictionary_size'] = -1 + lm_args = argparse.Namespace(**arg_dict) + lm_task = LanguageModelingTask.setup_task(lm_args) + lm_dict = lm_task.output_dictionary + + if self.args.channel_model is not None: + channel_models, _ = checkpoint_utils.load_model_ensemble(self.args.channel_model.split(':'), task=channel_task) + + for model in channel_models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if self.args.fp16: + model.half() + if use_cuda: + model.cuda() + else: + channel_models = None + + lm_models, _ = checkpoint_utils.load_model_ensemble(self.args.lm_model.split(':'), task=lm_task) + + for model in lm_models: + model.make_generation_fast_( + beamable_mm_beam_size=None if args.no_beamable_mm else args.beam, + need_attn=args.print_alignment, + ) + if self.args.fp16: + model.half() + if use_cuda: + model.cuda() + return NoisyChannelSequenceGenerator( + combine_method=self.args.combine_method, + tgt_dict=self.target_dictionary, + src_dict=self.source_dictionary, + beam_size=getattr(args, 'beam', 5), + max_len_a=getattr(args, 'max_len_a', 0), + max_len_b=getattr(args, 'max_len_b', 200), + min_len=getattr(args, 'min_len', 1), + len_penalty=getattr(args, 'lenpen', 1), + unk_penalty=getattr(args, 'unkpen', 0), + temperature=getattr(args, 'temperature', 1.), + match_source_len=getattr(args, 'match_source_len', False), + no_repeat_ngram_size=getattr(args, 'no_repeat_ngram_size', 0), + normalize_scores=(not getattr(args, 'unnormalized', False)), + channel_models=channel_models, + k2=getattr(self.args, 'k2', 50), + ch_weight=getattr(self.args, 'ch_wt', 1), + channel_scoring_type=self.args.channel_scoring_type, + top_k_vocab=self.args.top_k_vocab, + lm_models=lm_models, + lm_dict=lm_dict, + lm_weight=getattr(self.args, 'lm_wt', 1), + normalize_lm_scores_by_tgt_len=getattr(self.args, 'normalize_lm_scores_by_tgt_len', False), + ) diff --git a/examples/gottbert/README.md b/examples/gottbert/README.md new file mode 100644 index 0000000000..1d58feb279 --- /dev/null +++ b/examples/gottbert/README.md @@ -0,0 +1,64 @@ +# GottBERT: a pure German language model + +## Introduction + +[GottBERT](http://arxiv.org/abs/2012.02110) is a pretrained language model trained on 145GB of German text based on RoBERTa. + +## Example usage + +### fairseq +##### Load GottBERT from torch.hub (PyTorch >= 1.1): +```python +import torch +gottbert = torch.hub.load('pytorch/fairseq', 'gottbert-base') +gottbert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Load GottBERT (for PyTorch 1.0 or custom models): +```python +# Download gottbert model +wget https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz +tar -xzvf gottbert.tar.gz + +# Load the model in fairseq +from fairseq.models.roberta import GottbertModel +gottbert = GottbertModel.from_pretrained('/path/to/gottbert') +gottbert.eval() # disable dropout (or leave in train mode to finetune) +``` + +##### Filling masks: +```python +masked_line = 'Gott ist ! :)' +gottbert.fill_mask(masked_line, topk=3) +# [('Gott ist gut ! :)', 0.3642110526561737, ' gut'), +# ('Gott ist überall ! :)', 0.06009674072265625, ' überall'), +# ('Gott ist großartig ! :)', 0.0370681993663311, ' großartig')] +``` + +##### Extract features from GottBERT + +```python +# Extract the last layer's features +line = "Der erste Schluck aus dem Becher der Naturwissenschaft macht atheistisch , aber auf dem Grunde des Bechers wartet Gott !" +tokens = gottbert.encode(line) +last_layer_features = gottbert.extract_features(tokens) +assert last_layer_features.size() == torch.Size([1, 27, 768]) + +# Extract all layer's features (layer 0 is the embedding layer) +all_layers = gottbert.extract_features(tokens, return_all_hiddens=True) +assert len(all_layers) == 13 +assert torch.all(all_layers[-1] == last_layer_features) +``` +## Citation +If you use our work, please cite: + +```bibtex +@misc{scheible2020gottbert, + title={GottBERT: a pure German Language Model}, + author={Raphael Scheible and Fabian Thomczyk and Patric Tippmann and Victor Jaravine and Martin Boeker}, + year={2020}, + eprint={2012.02110}, + archivePrefix={arXiv}, + primaryClass={cs.CL} +} +``` diff --git a/examples/language_model/README.adaptive_inputs.md b/examples/language_model/README.adaptive_inputs.md index 6873467115..98043c5377 100644 --- a/examples/language_model/README.adaptive_inputs.md +++ b/examples/language_model/README.adaptive_inputs.md @@ -19,8 +19,8 @@ fairseq-train --task language_modeling \ data-bin/wikitext-103 \ --save-dir checkpoints/transformer_wikitext-103 \ --arch transformer_lm_wiki103 \ - --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ - --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \ + --max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \ + --warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \ --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \ --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d ``` diff --git a/examples/latent_depth/README.md b/examples/latent_depth/README.md index a0ec55a3f6..e70e16405c 100644 --- a/examples/latent_depth/README.md +++ b/examples/latent_depth/README.md @@ -25,7 +25,7 @@ fairseq-train ${databin_dir} \ --share-decoder-input-output-embed \ --dropout 0.3 --attention-dropout 0.3 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \ + --lr-scheduler inverse_sqrt --stop-min-lr 1e-9 --warmup-init-lr 1e-7 --warmup-updates 8000 \ --max-tokens 4096 --update-freq 1 \ --lr 0.0015 \ --clip-norm 1.0 \ diff --git a/examples/linformer/linformer_src/modules/multihead_linear_attention.py b/examples/linformer/linformer_src/modules/multihead_linear_attention.py index ba2c36b1ef..6be1007279 100644 --- a/examples/linformer/linformer_src/modules/multihead_linear_attention.py +++ b/examples/linformer/linformer_src/modules/multihead_linear_attention.py @@ -111,14 +111,10 @@ def __init__( self.compress_v.weight.requires_grad = False self.onnx_trace = False - self.tpu = False def prepare_for_onnx_export_(self): self.onnx_trace = True - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with diff --git a/examples/m2m_100/README.md b/examples/m2m_100/README.md index 0bacd4c8b1..f1b465c7b9 100644 --- a/examples/m2m_100/README.md +++ b/examples/m2m_100/README.md @@ -14,8 +14,8 @@ sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en # WAT -wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip -unzip wat2019.my-en.zip +wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip +unzip wat2020.my-en.zip # FLORES # download from: https://github.com/facebookresearch/flores @@ -116,7 +116,22 @@ If you use any of the resources listed here, please cite: ## Trained Models -More models coming up soon. +### 418M and 1.2B Model +We include the last checkpoint for both of these models. + +```bash +wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt +wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs_small_models.txt + +# 418M parameter model +wget https://dl.fbaipublicfiles.com/m2m_100/418M_last_checkpoint.pt + +# 1.2B parameter model +wget https://dl.fbaipublicfiles.com/m2m_100/1.2B_last_checkpoint.pt + +# Generation: +fairseq-generate $binarized_data_path --batch-size 32 --path $path_to_model -s en -t fr --remove-bpe 'sentencepiece' --beam 5 --task translation_multi_simple_epoch --lang-pairs language_pairs_small_models --decoder-langtok --encoder-langtok src --gen-subset test > gen_out +``` ### 12B Model 12B parameter model trained on many-to-many training data for 100 languages. We include the last checkpoint, average of last 5 checkpoints, average of last 10 checkpoints. There isn't a universally best choice out of these three, but all three versions are pretty close in accuracy. You can either sweep over the 3 checkpoints on a dev test and use the best performing checkpoint for final testing. Or the last checkpoint can be a good default choice. diff --git a/examples/mbart/README.md b/examples/mbart/README.md index 510edeff64..8a3e22d425 100644 --- a/examples/mbart/README.md +++ b/examples/mbart/README.md @@ -9,7 +9,7 @@ MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scal Model | Description | # params | Download ---|---|---|--- -`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on 25 languages' monolingual corpus | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz) +`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on 25 languages' monolingual corpus | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz) `mbart.ft.ro_en` | finetune mBART cc25 model on ro-en language pairs | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz) ## Results @@ -26,7 +26,7 @@ Model | en-ro | ro-en ## BPE data # download model -wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz +wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.v2.tar.gz tar -xzvf mbart.CC25.tar.gz # bpe data install SPM [here](https://github.com/google/sentencepiece) @@ -73,7 +73,7 @@ fairseq-train path_2_data \ --source-lang en_XX --target-lang ro_RO \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \ + --lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --total-num-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/README.md b/examples/multilingual/README.md index 3559c244e2..35eca89804 100644 --- a/examples/multilingual/README.md +++ b/examples/multilingual/README.md @@ -41,7 +41,7 @@ fairseq-train $path_2_data \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ @@ -69,7 +69,7 @@ fairseq-train $path_2_data \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/finetune_multilingual_model.sh b/examples/multilingual/finetune_multilingual_model.sh index cfa9a86113..ffcf1fc722 100644 --- a/examples/multilingual/finetune_multilingual_model.sh +++ b/examples/multilingual/finetune_multilingual_model.sh @@ -20,7 +20,7 @@ fairseq-train "$path_2_data" \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/multilingual/train_multilingual_model.sh b/examples/multilingual/train_multilingual_model.sh index 09014c8217..c41730dfcd 100644 --- a/examples/multilingual/train_multilingual_model.sh +++ b/examples/multilingual/train_multilingual_model.sh @@ -16,7 +16,7 @@ fairseq-train "$path_2_data" \ --lang-pairs "$lang_pairs" \ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \ - --lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \ + --lr-scheduler inverse_sqrt --lr 3e-05 --warmup-updates 2500 --max-update 40000 \ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \ --max-tokens 1024 --update-freq 2 \ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \ diff --git a/examples/nonautoregressive_translation/README.md b/examples/nonautoregressive_translation/README.md index dfc592f0a0..7b2d42a91d 100644 --- a/examples/nonautoregressive_translation/README.md +++ b/examples/nonautoregressive_translation/README.md @@ -44,7 +44,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ diff --git a/examples/nonautoregressive_translation/scripts.md b/examples/nonautoregressive_translation/scripts.md index 63b945c1d3..a3a33e6e02 100644 --- a/examples/nonautoregressive_translation/scripts.md +++ b/examples/nonautoregressive_translation/scripts.md @@ -14,7 +14,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -43,7 +43,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -76,7 +76,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -109,7 +109,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -136,7 +136,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ @@ -165,7 +165,7 @@ fairseq-train \ --share-all-embeddings \ --optimizer adam --adam-betas '(0.9,0.98)' \ --lr 0.0005 --lr-scheduler inverse_sqrt \ - --min-lr '1e-09' --warmup-updates 10000 \ + --stop-min-lr '1e-09' --warmup-updates 10000 \ --warmup-init-lr '1e-07' --label-smoothing 0.1 \ --dropout 0.3 --weight-decay 0.01 \ --decoder-learned-pos \ diff --git a/examples/pay_less_attention_paper/README.md b/examples/pay_less_attention_paper/README.md index 3fb93b23d1..d5b19af6cc 100644 --- a/examples/pay_less_attention_paper/README.md +++ b/examples/pay_less_attention_paper/README.md @@ -110,7 +110,7 @@ mkdir -p $SAVE CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \ --clip-norm 0 --optimizer adam --lr 0.0005 \ --source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \ - --log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \ + --log-interval 100 --stop-min-lr '1e-09' --weight-decay 0.0001 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ --lr-scheduler inverse_sqrt \ --ddp-backend=no_c10d \ @@ -137,10 +137,10 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --max-update 30000 --share-all-embeddings --optimizer adam \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 20000 \ --arch lightconv_wmt_en_de_big --save-dir $SAVE \ --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \ @@ -162,10 +162,10 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \ --max-update 30000 --share-all-embeddings --optimizer adam \ --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \ --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ - --min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ + --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \ --ddp-backend=no_c10d --max-tokens 3584 \ --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \ - --lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \ + --lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \ --t-mult 1 --lr-period-updates 70000 \ --arch lightconv_wmt_en_fr_big --save-dir $SAVE \ --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \ diff --git a/examples/pointer_generator/pointer_generator_src/transformer_pg.py b/examples/pointer_generator/pointer_generator_src/transformer_pg.py index 079fdda581..fb40a80836 100644 --- a/examples/pointer_generator/pointer_generator_src/transformer_pg.py +++ b/examples/pointer_generator/pointer_generator_src/transformer_pg.py @@ -185,14 +185,14 @@ def forward(self, src_tokens, src_lengths, **kwargs): `(batch, src_len)` """ encoder_out = super().forward(src_tokens, src_lengths, **kwargs) - return EncoderOut( - encoder_out=encoder_out.encoder_out, # T x B x C - encoder_padding_mask=encoder_out.encoder_padding_mask, # B x T - encoder_embedding=encoder_out.encoder_embedding, # B x T x C - encoder_states=encoder_out.encoder_states, # List[T x B x C] - src_tokens=src_tokens, # B x T - src_lengths=None, - ) + return { + "encoder_out": encoder_out["encoder_out"], # T x B x C + "encoder_padding_mask": encoder_out["encoder_padding_mask"], # B x T + "encoder_embedding": encoder_out["encoder_embedding"], # B x T x C + "encoder_states": encoder_out["encoder_states"], # List[T x B x C] + "src_tokens": [src_tokens], # B x T + "src_lengths": [], + } class TransformerPointerGeneratorDecoder(TransformerDecoder): @@ -284,7 +284,7 @@ def forward( predictors = torch.cat((prev_output_embed, x), 2) p_gens = self.project_p_gens(predictors) p_gens = torch.sigmoid(p_gens) - x = self.output_layer(x, extra["attn"][0], encoder_out.src_tokens, p_gens) + x = self.output_layer(x, extra["attn"][0], encoder_out["src_tokens"][0], p_gens) return x, extra def output_layer(self, features, attn, src_tokens, p_gens, **kwargs): diff --git a/examples/quant_noise/README.md b/examples/quant_noise/README.md index 057ea620ab..7fe301f732 100644 --- a/examples/quant_noise/README.md +++ b/examples/quant_noise/README.md @@ -208,11 +208,11 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --ddp-backend no_c10d \ --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \ --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ - --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 1.0 --t-mult 2.0 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 1.0 --t-mult 2.0 \ --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \ --sample-break-mode none --update-freq 3 \ --warmup-init-lr 1e-07 --warmup-updates 16000 \ - --weight-decay 0 --seed 1 --min-lr 1e-09 \ + --weight-decay 0 --seed 1 --stop-min-lr 1e-09 \ --quant-noise-pq 0.05 --quant-noise-pq-block-size 8 ``` @@ -269,7 +269,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \ --ddp-backend no_c10d \ --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \ --fp16 --keep-last-epochs -1 \ - --lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --min-lr 1e-09 \ + --min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 0.05 --stop-min-lr 1e-09 \ --max-tokens 2944 --tokens-per-sample 2944\ --momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \ --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \ diff --git a/examples/roberta/README.md b/examples/roberta/README.md index ca86131eea..58091b2c7d 100644 --- a/examples/roberta/README.md +++ b/examples/roberta/README.md @@ -8,6 +8,7 @@ RoBERTa iterates on BERT's pretraining procedure, including training the model l ### What's New: +- December 2020: German model (GottBERT) is available: [GottBERT](https://github.com/pytorch/fairseq/tree/master/examples/gottbert). - January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto). - November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/master/examples/camembert). - November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/master/examples/xlmr). diff --git a/examples/simultaneous_translation/README.md b/examples/simultaneous_translation/README.md index e27b65280e..bbc6dacdda 100644 --- a/examples/simultaneous_translation/README.md +++ b/examples/simultaneous_translation/README.md @@ -23,7 +23,7 @@ fairseq-train \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 @@ -44,7 +44,7 @@ fairseq-train \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 @@ -65,7 +65,7 @@ fairseq-train \ --optimizer adam --adam-betas '(0.9, 0.98)' \ --lr-scheduler 'inverse_sqrt' \ --warmup-init-lr 1e-7 --warmup-updates 4000 \ - --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ + --lr 5e-4 --stop-min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\ --dropout 0.3 \ --label-smoothing 0.1\ --max-tokens 3584 diff --git a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py index b3c8f6d53f..761cfe61a1 100644 --- a/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +++ b/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py @@ -14,15 +14,30 @@ class LatencyAugmentedLabelSmoothedCrossEntropyCriterion( LabelSmoothedCrossEntropyCriterion ): - def __init__(self, args, task): - super().__init__(args, task) - self.eps = args.label_smoothing - self.latency_weight_avg = args.latency_weight_avg - self.latency_weight_avg_type = args.latency_weight_avg_type - self.latency_weight_var = args.latency_weight_var - self.latency_weight_var_type = args.latency_weight_var_type - self.mass_preservation = args.mass_preservation - self.average_method = args.average_method + def __init__( + self, + task, + sentence_avg, + label_smoothing, + ignore_prefix_size, + report_accuracy, + latency_weight_avg, + latency_weight_avg_type, + latency_weight_var, + latency_weight_var_type, + mass_preservation, + average_method, + ): + super().__init__( + task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy + ) + self.eps = label_smoothing + self.latency_weight_avg = latency_weight_avg + self.latency_weight_avg_type = latency_weight_avg_type + self.latency_weight_var = latency_weight_var + self.latency_weight_var_type = latency_weight_var_type + self.mass_preservation = mass_preservation + self.average_method = average_method self.latency_train = LatencyTraining( self.latency_weight_avg, self.latency_weight_var, diff --git a/examples/speech_recognition/infer.py b/examples/speech_recognition/infer.py index 68889463f4..ddd3fd6340 100644 --- a/examples/speech_recognition/infer.py +++ b/examples/speech_recognition/infer.py @@ -8,6 +8,7 @@ Run inference for pre-processed data with a trained model. """ +import ast import logging import math import os @@ -18,7 +19,6 @@ import torch from fairseq import checkpoint_utils, options, progress_bar, tasks, utils from fairseq.data.data_utils import post_process -from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging.meters import StopwatchMeter, TimeMeter @@ -178,53 +178,6 @@ def get_res_file(file_prefix): } -def load_models_and_criterions( - filenames, data_path, arg_overrides=None, task=None, model_state=None -): - models = [] - criterions = [] - - if arg_overrides is None: - arg_overrides = {} - - arg_overrides["wer_args"] = None - arg_overrides["data"] = data_path - - if filenames is None: - assert model_state is not None - filenames = [0] - else: - filenames = filenames.split(":") - - for filename in filenames: - if model_state is None: - if not os.path.exists(filename): - raise IOError("Model file not found: {}".format(filename)) - state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides) - else: - state = model_state - - if "cfg" in state: - cfg = state["cfg"] - else: - cfg = convert_namespace_to_omegaconf(state["args"]) - - if task is None: - if hasattr(cfg.task, 'data'): - cfg.task.data = data_path - task = tasks.setup_task(cfg.task) - - model = task.build_model(cfg.model) - model.load_state_dict(state["model"], strict=True) - models.append(model) - - criterion = task.build_criterion(cfg.criterion) - if "criterion" in state: - criterion.load_state_dict(state["criterion"], strict=True) - criterions.append(criterion) - return models, criterions, task - - def optimize_models(args, use_cuda, models): """Optimize ensemble for generation""" for model in models: @@ -266,23 +219,26 @@ def main(args, task=None, model_state=None): logger.info("| decoding with criterion {}".format(args.criterion)) + task = tasks.setup_task(args) + # Load ensemble if args.load_emissions: models, criterions = [], [] - task = tasks.setup_task(args) + task.load_dataset(args.gen_subset) else: logger.info("| loading model(s) from {}".format(args.path)) - models, criterions, task = load_models_and_criterions( - args.path, - data_path=args.data, - arg_overrides=eval(args.model_overrides), # noqa + models, saved_cfg = checkpoint_utils.load_model_ensemble( + utils.split_paths(args.path), + arg_overrides=ast.literal_eval(args.model_overrides), task=task, - model_state=model_state, + suffix=args.checkpoint_suffix, + strict=(args.checkpoint_shard_count == 1), + num_shards=args.checkpoint_shard_count, + state=model_state, ) optimize_models(args, use_cuda, models) + task.load_dataset(args.gen_subset, task_cfg=saved_cfg.task) - # Load dataset splits - task.load_dataset(args.gen_subset) # Set dictionary tgt_dict = task.target_dictionary @@ -295,8 +251,9 @@ def main(args, task=None, model_state=None): # hack to pass transitions to W2lDecoder if args.criterion == "asg_loss": - trans = criterions[0].asg.trans.data - args.asg_transitions = torch.flatten(trans).tolist() + raise NotImplementedError("asg_loss is currently not supported") + # trans = criterions[0].asg.trans.data + # args.asg_transitions = torch.flatten(trans).tolist() # Load dataset (possibly sharded) itr = get_dataset_itr(args, task, models) diff --git a/examples/translation/README.md b/examples/translation/README.md index 3eb8e01310..7b1fcc8de2 100644 --- a/examples/translation/README.md +++ b/examples/translation/README.md @@ -268,7 +268,7 @@ CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \ --arch multilingual_transformer_iwslt_de_en \ --share-decoders --share-decoder-input-output-embed \ --optimizer adam --adam-betas '(0.9, 0.98)' \ - --lr 0.0005 --lr-scheduler inverse_sqrt --min-lr '1e-09' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ --warmup-updates 4000 --warmup-init-lr '1e-07' \ --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \ --dropout 0.3 --weight-decay 0.0001 \ diff --git a/examples/translation_moe/README.md b/examples/translation_moe/README.md index ef7abdb44b..3cc3fb46dc 100644 --- a/examples/translation_moe/README.md +++ b/examples/translation_moe/README.md @@ -24,7 +24,7 @@ fairseq-train --ddp-backend='no_c10d' \ --arch transformer_wmt_en_de --share-all-embeddings \ --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ --lr-scheduler inverse_sqrt --warmup-init-lr 1e-07 --warmup-updates 4000 \ - --lr 0.0007 --min-lr 1e-09 \ + --lr 0.0007 \ --dropout 0.1 --weight-decay 0.0 --criterion cross_entropy \ --max-tokens 3584 ``` diff --git a/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py index 484b6ac912..efc7ae40bf 100644 --- a/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py +++ b/examples/translation_moe/translation_moe_src/mean_pool_gating_network.py @@ -26,15 +26,15 @@ def __init__(self, embed_dim, num_experts, dropout=None): def forward(self, encoder_out): if not ( - hasattr(encoder_out, "encoder_out") - and hasattr(encoder_out, "encoder_padding_mask") - and encoder_out.encoder_out.size(2) == self.embed_dim + "encoder_out" in encoder_out + and "encoder_padding_mask" in encoder_out + and encoder_out["encoder_out"][0].size(2) == self.embed_dim ): raise ValueError("Unexpected format for encoder_out") # mean pooling over time - encoder_padding_mask = encoder_out.encoder_padding_mask # B x T - encoder_out = encoder_out.encoder_out.transpose(0, 1) # B x T x C + encoder_padding_mask = encoder_out["encoder_padding_mask"][0] # B x T + encoder_out = encoder_out["encoder_out"][0].transpose(0, 1) # B x T x C if encoder_padding_mask is not None: encoder_out = encoder_out.clone() # required because of transpose above encoder_out[encoder_padding_mask] = 0 diff --git a/examples/truncated_bptt/README.md b/examples/truncated_bptt/README.md new file mode 100644 index 0000000000..86518c9d5e --- /dev/null +++ b/examples/truncated_bptt/README.md @@ -0,0 +1,70 @@ +# Truncated Backpropagation Through Time (BPTT) + +Truncated BPTT is a useful technique for training language models on very long +sequences. Typically a long sequences is split into chunks and a language model +is trained over the chunks sequentially. The LM may condition on previous +chunks, but gradients only flow through the current chunk. This technique was +the basis for the paper: [Transformer-XL: Attentive Language Models Beyond a +Fixed-Length Context](https://arxiv.org/abs/1901.02860), which achieved +state-of-the-art language modeling results at the time of publication. + +It is slightly tricky to implement Truncated BPTT efficiently in fairseq, since +we need to iterate over the data sequentially and disable any batch shuffling +logic. The code provided in this example illustrates how to implement Truncated +BPTT in fairseq by overriding ``FairseqTask::get_batch_iterator`` to iterate +over the data sequentially. Crucially, this example supports batching and +multi-GPU (data parallel) training. + +##### 0. Setup + +First, see the general [language modeling README](README.md) for instructions on +preprocessing the WikiText-103 data. + +##### 1. Train a Transformer-XL model on WikiText-103 + +We will train a 16-layer Transformer-XL model following the [hyperparameters +used in the original +paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh). + +The following command assumes 4 GPUs, so that the total batch size is 60 +sequences (15 x 4). Training should take ~24 hours on 4 V100 GPUs: +```bash +CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \ + --user-dir examples/truncated_bptt \ + data-bin/wikitext-103/ \ + --task truncated_bptt_lm --tokens-per-sample 150 \ + --batch-size 15 --max-update 200000 \ + --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \ + --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \ + --optimizer adam --clip-norm 0.25 \ + --lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \ + --log-format json --log-interval 25 \ + --fp16 +``` + +If training on a single GPU, set `--update-freq=4` to accumulate 4x gradients +and simulate training on 4 GPUs. + +##### 2. Evaluate + +```bash +fairseq-eval-lm data-bin/wikitext-103/ \ + --path checkpoints/checkpoint_best.pt \ + --user-dir examples/truncated_bptt/ \ + --task truncated_bptt_lm \ + --batch-size 1 --required-batch-size-multiple 1 \ + --model-overrides '{"mem_len":640,"clamp_len":400,"same_length":True}' \ + --tokens-per-sample 64 +# ... | INFO | fairseq_cli.eval_lm | num. model params: 151123537 +# ... | INFO | fairseq_cli.eval_lm | Evaluated 245569 tokens in 83.1s (2956.82 tokens/s) +# ... | INFO | fairseq_cli.eval_lm | Loss (base 2): 4.5668, Perplexity: 23.70 +# Compare to 24.0 test perplexity from the paper +``` + +*Note:* During training the model saw 150 tokens of context +(``--tokens-per-sample=150``) and 150 extra memory tokens (``--mem-len=150``). +During evaluation we measure perplexity on sequences of 64 tokens +(``--tokens-per-sample=64``) and increase the memory length +(``--model-overrides='{"mem_len":640}'``). These settings match the evaluation +settings from [the original +paper](https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/run_wt103_base.sh). diff --git a/examples/truncated_bptt/__init__.py b/examples/truncated_bptt/__init__.py new file mode 100644 index 0000000000..eee484d427 --- /dev/null +++ b/examples/truncated_bptt/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import transformer_xl_model, truncated_bptt_lm_task # noqa diff --git a/examples/truncated_bptt/transformer_xl_model.py b/examples/truncated_bptt/transformer_xl_model.py new file mode 100644 index 0000000000..83b248479e --- /dev/null +++ b/examples/truncated_bptt/transformer_xl_model.py @@ -0,0 +1,151 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import torch +from fairseq.dataclass import FairseqDataclass +from fairseq.models import ( + FairseqIncrementalDecoder, + FairseqLanguageModel, + register_model, +) +from fairseq.modules.checkpoint_activations import checkpoint_wrapper +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class TransformerXLConfig(FairseqDataclass): + # defaults come from the original Transformer-XL code + cutoffs: List[int] = field(default_factory=lambda: [20000, 40000, 200000]) + d_model: int = 500 + n_head: int = 10 + d_head: int = 50 + d_inner: int = 1000 + div_val: int = 1 + n_layer: int = 12 + mem_len: int = 0 + clamp_len: int = -1 + same_length: bool = False + dropout: float = 0.0 + dropatt: float = 0.0 + checkpoint_activations: bool = False + max_target_positions: int = II("task.max_target_positions") + + +@register_model("transformer_xl", dataclass=TransformerXLConfig) +class TransformerXLLanguageModel(FairseqLanguageModel): + @classmethod + def build_model(cls, cfg: TransformerXLConfig, task): + return cls(TransformerXLDecoder(cfg, task)) + + +class TransformerXLDecoder(FairseqIncrementalDecoder): + def __init__(self, cfg, task): + try: + from transformers.models.transfo_xl import ( + TransfoXLConfig, TransfoXLLMHeadModel + ) + except ImportError: + from transformers.configuration_transfo_xl import TransfoXLConfig + from transformers.modeling_transfo_xl import TransfoXLLMHeadModel + + super().__init__(task.target_dictionary) + self.cfg = cfg + + # remove any cutoffs larger than the vocab size + cutoffs = [ + cutoff for cutoff in cfg.cutoffs if cutoff < len(task.target_dictionary) + ] + + config = TransfoXLConfig( + vocab_size=len(task.target_dictionary), + cutoffs=cutoffs, + d_model=cfg.d_model, + d_embed=cfg.d_model, + n_head=cfg.n_head, + d_head=cfg.d_head, + d_inner=cfg.d_inner, + div_val=cfg.div_val, + n_layer=cfg.n_layer, + mem_len=cfg.mem_len, + clamp_len=cfg.clamp_len, + same_length=cfg.same_length, + dropout=cfg.dropout, + dropatt=cfg.dropatt, + ) + logger.info(config) + self.model = TransfoXLLMHeadModel(config) + + # Workaround a bug in huggingface's ``ProjectedAdaptiveLogSoftmax`` + # which adds ``None`` values to an ``nn.ParameterList``, which is not + # supported in PyTorch. Instead we can replace this with an + # ``nn.ModuleList``, which does support ``None`` values. + try: + if all(p is None for p in self.model.crit.out_projs._parameters.values()): + self.model.crit.out_projs = torch.nn.ModuleList( + [None] * len(self.model.crit.out_projs._parameters) + ) + except Exception: + pass + + if cfg.checkpoint_activations: + for i in range(len(self.model.transformer.layers)): + self.model.transformer.layers[i] = checkpoint_wrapper( + self.model.transformer.layers[i] + ) + + self._mems = None + + def forward( + self, + src_tokens, + src_lengths=None, # unused + incremental_state: Optional[Dict[str, List[torch.Tensor]]] = None, + encoder_out=None, + ): + if incremental_state is not None: # used during inference + mems = self.get_incremental_state(incremental_state, "mems") + src_tokens = src_tokens[:, -1:] # only keep the most recent token + else: + mems = self._mems + + output = self.model( + input_ids=src_tokens, + mems=mems, + return_dict=False, + ) + + if len(output) >= 2: + if incremental_state is not None: + self.set_incremental_state(incremental_state, "mems", output[1]) + else: + self._mems = output[1] + + return (output[0],) + + def max_positions(self): + return self.cfg.max_target_positions + + def reorder_incremental_state( + self, + incremental_state: Dict[str, Dict[str, Optional[torch.Tensor]]], + new_order: torch.Tensor, + ): + """Reorder incremental state. + + This will be called when the order of the input has changed from the + previous time step. A typical use case is beam search, where the input + order changes between time steps based on the selection of beams. + """ + mems = self.get_incremental_state(incremental_state, "mems") + if mems is not None: + new_mems = [mems_i.index_select(1, new_order) for mems_i in mems] + self.set_incremental_state(incremental_state, "mems", new_mems) diff --git a/examples/truncated_bptt/truncated_bptt_lm_task.py b/examples/truncated_bptt/truncated_bptt_lm_task.py new file mode 100644 index 0000000000..34c4f03955 --- /dev/null +++ b/examples/truncated_bptt/truncated_bptt_lm_task.py @@ -0,0 +1,280 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from dataclasses import dataclass, field +from typing import List, Optional, Tuple + +import torch +from fairseq import distributed_utils as dist_utils, utils +from fairseq.data import ( + Dictionary, + TokenBlockDataset, + data_utils, + iterators, +) +from fairseq.dataclass import FairseqDataclass +from fairseq.tasks import FairseqTask, register_task +from omegaconf import II + + +logger = logging.getLogger(__name__) + + +@dataclass +class TruncatedBPTTLMConfig(FairseqDataclass): + data: str = field(default="???", metadata={"help": "path to data directory"}) + tokens_per_sample: int = field( + default=1024, + metadata={"help": "max number of tokens per sequence"}, + ) + batch_size: int = II("dataset.batch_size") + # Some models use *max_target_positions* to know how many positional + # embeddings to learn. We use II(...) to make it default to + # *tokens_per_sample*, but in principle there could be more positional + # embeddings than tokens in a single batch. This may also be irrelevant for + # custom model implementations. + max_target_positions: int = II("task.tokens_per_sample") + # these will be populated automatically if not provided + data_parallel_rank: Optional[int] = None + data_parallel_size: Optional[int] = None + + +@register_task("truncated_bptt_lm", dataclass=TruncatedBPTTLMConfig) +class TruncatedBPTTLMTask(FairseqTask): + def __init__(self, cfg: TruncatedBPTTLMConfig): + super().__init__(cfg) + + if cfg.data_parallel_rank is None or cfg.data_parallel_size is None: + if torch.distributed.is_initialized(): + cfg.data_parallel_rank = dist_utils.get_data_parallel_rank() + cfg.data_parallel_size = dist_utils.get_data_parallel_world_size() + else: + cfg.data_parallel_rank = 0 + cfg.data_parallel_size = 1 + + # load the dictionary + paths = utils.split_paths(cfg.data) + assert len(paths) > 0 + self.dictionary = Dictionary.load(os.path.join(paths[0], "dict.txt")) + logger.info("dictionary: {} types".format(len(self.dictionary))) + + def load_dataset(self, split, epoch=1, combine=False, **kwargs): + """Load a given dataset split (e.g., train, valid, test)""" + + # support sharded datasets + paths = utils.split_paths(self.cfg.data) + assert len(paths) > 0 + data_path = paths[(epoch - 1) % len(paths)] + split_path = os.path.join(data_path, split) + + # each element of *data* will be a tensorized line from the original + # text dataset, similar to ``open(split_path).readlines()`` + data = data_utils.load_indexed_dataset( + split_path, self.dictionary, combine=combine + ) + if data is None: + raise FileNotFoundError( + "Dataset not found: {} ({})".format(split, split_path) + ) + + # this is similar to ``data.view(-1).split(tokens_per_sample)`` + data = TokenBlockDataset( + data, + data.sizes, + block_size=self.cfg.tokens_per_sample, + pad=None, # unused + eos=None, # unused + break_mode="none", + ) + + self.datasets[split] = TruncatedBPTTDataset( + data=data, + bsz_per_shard=self.cfg.batch_size, + shard_id=self.cfg.data_parallel_rank, + num_shards=self.cfg.data_parallel_size, + ) + + def dataset(self, split): + return self.datasets[split] + + def get_batch_iterator( + self, dataset, num_workers=0, epoch=1, data_buffer_size=0, **kwargs + ): + return iterators.EpochBatchIterator( + dataset=dataset, + collate_fn=self._collate_fn, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + # we don't use the batching functionality from EpochBatchIterator; + # instead every item in *dataset* is a whole batch + batch_sampler=[[i] for i in range(len(dataset))], + disable_shuffling=True, + ) + + def _collate_fn(self, items: List[List[torch.Tensor]]): + # we don't use fairseq's batching functionality, so we expect a single + # Tensor of type List[torch.Tensor] + assert len(items) == 1 + + # item will have shape B x T (the last batch may have length < T) + id, item = items[0] + item = data_utils.collate_tokens(item, pad_idx=self.source_dictionary.pad()) + B, T = item.size() + + # shift item one position over and append a padding token for the target + target = torch.nn.functional.pad( + item[:, 1:], (0, 1, 0, 0), value=self.target_dictionary.pad() + ) + + # fairseq expects batches to have the following structure + return { + "id": torch.tensor([id]*item.size(0)), + "net_input": { + "src_tokens": item, + }, + "target": target, + "nsentences": item.size(0), + "ntokens": item.numel(), + } + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + eos = self.source_dictionary.eos() + dataset = TokenBlockDataset( + src_tokens, + src_lengths, + block_size=None, # ignored for "eos" break mode + pad=self.source_dictionary.pad(), + eos=eos, + break_mode="eos", + ) + + class Dataset(torch.utils.data.Dataset): + def __getitem__(self, i): + item = dataset[i] + if item[-1] == eos: + # remove eos to support generating with a prefix + item = item[:-1] + return (i, [item]) + + def __len__(self): + return len(dataset) + + return Dataset() + + def inference_step( + self, generator, models, sample, prefix_tokens=None, constraints=None + ): + with torch.no_grad(): + if constraints is not None: + raise NotImplementedError + + # SequenceGenerator doesn't use *src_tokens* directly, we need to + # pass the *prefix_tokens* argument instead. + if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement(): + prefix_tokens = sample["net_input"]["src_tokens"] + + # begin generation with the end-of-sentence token + bos_token = self.source_dictionary.eos() + + return generator.generate( + models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token + ) + + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + context_window: int = 0, + ): + if context_window > 0: + raise NotImplementedError( + "Transformer-XL doesn't need --context-window, try " + "--model-overrides '{\"mem_len\":42}' instead " + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + + @property + def source_dictionary(self): + return self.dictionary + + @property + def target_dictionary(self): + return self.dictionary + + +class TruncatedBPTTDataset(torch.utils.data.Dataset): + def __init__( + self, + data: List[torch.Tensor], # ordered list of items + bsz_per_shard, # number of items processed per GPUs per forward + shard_id, # current GPU ID + num_shards, # number of GPUs + ): + super().__init__() + self.data = data + + def batchify(data, bsz): + # Work out how cleanly we can divide the dataset into bsz parts. + nbatch = data.size(0) // bsz + # Trim off any extra elements that wouldn't cleanly fit (remainders). + data = data.narrow(0, 0, nbatch * bsz) + # Evenly divide the data across the bsz batches. + data = data.view(bsz, -1).contiguous() + return data + + # total number of sequences processed by all GPUs in each forward pass + global_batch_size = bsz_per_shard * num_shards + + """ + With a 16 item dataset, bsz_per_shard=2 and num_shards=3, + *indices* might look like: + + indices = [[0, 1], + [2, 3], + [4, 5], + [6, 7], + [8, 9], + [10, 11]] + + The size of the TruncatedBPTTDataset instance will be 2, + and shard 1 will see items: + + [(0, [data[4], data[6]]), + (1, [data[5], data[7]])] + """ + indices = batchify(torch.arange(len(data)), global_batch_size) + assert indices.size(0) == global_batch_size + + self.my_indices = indices[ + shard_id * bsz_per_shard : (shard_id + 1) * bsz_per_shard + ] + assert self.my_indices.size(0) == bsz_per_shard + + def __len__(self): + return self.my_indices.size(1) + + def __getitem__(self, i) -> Tuple[int, List[torch.Tensor]]: + return (i, [self.data[idx] for idx in self.my_indices[:, i]]) diff --git a/examples/wav2vec/README.md b/examples/wav2vec/README.md index 1da42f388a..a0c95e9c34 100644 --- a/examples/wav2vec/README.md +++ b/examples/wav2vec/README.md @@ -2,6 +2,8 @@ wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). +We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). + We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). ## Pre-trained models @@ -26,6 +28,21 @@ Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https:// \* updated (Oct. 24, 2020) +We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: + +Model | Architecture | Hours | Languages | Datasets | Model +|---|---|---|---|---|--- +XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) + +The XLSR model uses the following datasets for multilingual pretraining: + +* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* + +* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). + +* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* + + ## Training a new model with the CLI tools Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) @@ -53,44 +70,31 @@ separately pre-processed manifest file. This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper -Note that this was tested with pytorch 1.4.0 and the input is expected to be single channel, sampled at 16 kHz +Note that the input is expected to be single channel, sampled at 16 kHz ```shell script -$ python train.py --distributed-world-size 64 --distributed-port $PORT /manifest/path \ ---save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ ---log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ ---conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 --latent-vars 320 \ ---latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce --optimizer adam \ ---adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 400000 \ ---lr 0.0005 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ ---encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ ---loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --num-negatives 100 --cross-sample-negatives 0 \ ---max-sample-size 250000 --min-sample-size 32000 --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \ ---max-tokens 1400000 --max-update 400000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_base_librispeech ``` -Note: you can simulate 64 GPUs by using k GPUs and setting --update-freq 64/k +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k ### Train a wav2vec 2.0 large model: This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper ```shell script -$ python train.py --distributed-world-size 128 --distributed-port $PORT /manifest/path \ ---save-dir /model/path --fp16 --num-workers 6 --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ ---log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ ---conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 768 --latent-vars 320 \ ---latent-groups 2 --latent-temp '(2.0,0.1,0.999995)' --infonce --optimizer adam \ ---adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay --total-num-update 600000 \ ---lr 0.0003 --warmup-updates 32000 --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ ---encoder-layerdrop 0.0 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.03 \ ---loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 --encoder-layers 24 --encoder-embed-dim 1024 \ ---encoder-ffn-embed-dim 4096 --encoder-attention-heads 16 --num-negatives 100 --cross-sample-negatives 0 \ ---max-sample-size 320000 --min-sample-size 32000 --dropout 0.0 --attention-dropout 0.1 --weight-decay 0.01 \ ---max-tokens 1200000 --max-update 600000 --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox ``` -Note: you can simulate 128 GPUs by using k GPUs and setting --update-freq 128/k +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 128/k ### Fine-tune a pre-trained model with CTC: @@ -105,28 +109,22 @@ $ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $sp Fine-tuning on 100h of Librispeech with letter targets: ```shell script -valid_subset=dev_other -python train.py --distributed-world-size 24 --distributed-port $PORT /path/to/training_data --save-dir /model/path --fp16 \ ---wer-args '("/path/to/lm/4-gram.bin","/path/to/lexicon",2,-1)' \ ---post-process letter --valid-subset $valid_subset --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 \ ---max-update 80000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc --w2v-path /path/to/pretrained/model \ ---labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \ ---mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \ ---feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 --optimizer adam \ ---adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 \ ---decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 --activation-dropout 0.1 --criterion ctc \ ---attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json --log-interval 500 --ddp-backend no_c10d +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h ``` -Note: you can simulate 24 GPUs by using k GPUs and setting --update-freq 24/k +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the `--config-name` parameter. -Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). -Alternatively, simply omit the --wer-args flag. +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k -For hyper-parameters to fine-tune other Librispeech splits (10 minutes, 1 hour, etc) please refer to the table in Appendix B in the wav2vec 2.0 paper. -The main changes to make are adjusting --max-update, and then adjusting --warmup-steps, --hold-steps, and --decay steps so that they use 0.1/0.4/0.5 of max-update respectively. You then need to adjust --mask-prob and --mask-channel-prob. This should be set to the mask-length * x where x is the number in the table and mask-length is what you use for --mask-length (10 in this example. Use --mask-channel-length value for --mask-channel-prob). - -For example, for 10 hours, we see in the paper that timestep mask prob should be 0.065, so we set --mask-prob to 10* 0.065 = 0.65. channel mask prob is 0.004, so we set it to 64 * 0.004 = 0.256. then we set --max-updates to 20000 and change --warmup-steps to 20000 * 0.1 = 2000, --hold-steps to 8000 and --decay-steps to 10000. +Decoding with a language model during training requires wav2letter [python bindings](https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. ### Evaluating a CTC model: @@ -162,11 +160,11 @@ Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl #### Example usage: ```python import torch -from fairseq.models.wav2vec import Wav2VecModel +import fairseq -cp = torch.load('/path/to/wav2vec.pt') -model = Wav2VecModel.build_model(cp['args'], task=None) -model.load_state_dict(cp['model']) +cp_path = '/path/to/wav2vec.pt' +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) +model = model[0] model.eval() wav_input_16khz = torch.randn(1,10000) @@ -188,7 +186,7 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ ---arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \ +--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ @@ -219,11 +217,11 @@ Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download] #### Example usage: ```python import torch -from fairseq.models.wav2vec import Wav2VecModel +import fairseq cp = torch.load('/path/to/vq-wav2vec.pt') -model = Wav2VecModel.build_model(cp['args'], task=None) -model.load_state_dict(cp['model']) +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +model = model[0] model.eval() wav_input_16khz = torch.randn(1,10000) @@ -246,8 +244,8 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa ``` $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ ---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --min-lr 1e-09 \ ---optimizer adam --max-lr 1e-05 --lr-scheduler cosine \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ +--optimizer adam --lr 1e-05 --lr-scheduler cosine \ --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ --activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \ diff --git a/examples/wav2vec/config/finetuning/base_100h.yaml b/examples/wav2vec/config/finetuning/base_100h.yaml new file mode 100644 index 0000000000..7d1664a184 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_100h.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 + diff --git a/examples/wav2vec/config/finetuning/base_10h.yaml b/examples/wav2vec/config/finetuning/base_10h.yaml new file mode 100644 index 0000000000..31125947c0 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_10h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.05 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_10m.yaml b/examples/wav2vec/config/finetuning/base_10m.yaml new file mode 100644 index 0000000000..2235504489 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_10m.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_1h.yaml b/examples/wav2vec/config/finetuning/base_1h.yaml new file mode 100644 index 0000000000..2235504489 --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_1h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/base_960h.yaml b/examples/wav2vec/config/finetuning/base_960h.yaml new file mode 100644 index 0000000000..d742c94abf --- /dev/null +++ b/examples/wav2vec/config/finetuning/base_960h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 320000 + lr: [0.00001] + sentence_avg: true + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.1 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 + diff --git a/examples/wav2vec/config/finetuning/vox_100h.yaml b/examples/wav2vec/config/finetuning/vox_100h.yaml new file mode 100644 index 0000000000..8885c78470 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_100h.yaml @@ -0,0 +1,59 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_10h.yaml b/examples/wav2vec/config/finetuning/vox_10h.yaml new file mode 100644 index 0000000000..c0957c0058 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_10m.yaml b/examples/wav2vec/config/finetuning/vox_10m.yaml new file mode 100644 index 0000000000..0d567552d7 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_10m.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_1h.yaml b/examples/wav2vec/config/finetuning/vox_1h.yaml new file mode 100644 index 0000000000..10c45a52d8 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_1h.yaml @@ -0,0 +1,64 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.0003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/finetuning/vox_960h.yaml b/examples/wav2vec/config/finetuning/vox_960h.yaml new file mode 100644 index 0000000000..6212a2e738 --- /dev/null +++ b/examples/wav2vec/config/finetuning/vox_960h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_pretraining + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: no_c10d + distributed_world_size: 24 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 320000 + lr: [0.00003] + sentence_avg: true + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml new file mode 100644 index 0000000000..e2c2b7b0b3 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_base_librispeech.yaml @@ -0,0 +1,55 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 250000 + min_sample_size: 32000 + +dataset: + num_workers: 6 + max_tokens: 1400000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 64 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 10] + +optimization: + max_update: 400000 + lr: [0.0005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + final_dim: 256 + encoder_layerdrop: 0.05 + dropout_input: 0.1 + dropout_features: 0.1 + feature_grad_mult: 0.1 diff --git a/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml new file mode 100644 index 0000000000..0c911b7491 --- /dev/null +++ b/examples/wav2vec/config/pretraining/wav2vec2_large_librivox.yaml @@ -0,0 +1,69 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval_updates: 25000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + +task: + _name: audio_pretraining + data: ??? + max_sample_size: 320000 + min_sample_size: 32000 + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1200000 + skip_invalid_size_inputs_valid_test: true + +distributed_training: + distributed_world_size: 128 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 0] + +optimization: + max_update: 1000000 + lr: [0.005] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 32000 + +model: + _name: wav2vec2 + quantize_targets: true + extractor_mode: layer_norm + layer_norm_first: true + final_dim: 768 + latent_temp: [2.0,0.1,0.999995] + encoder_layerdrop: 0.00 + dropout_input: 0.0 + dropout_features: 0.0 + dropout: 0.0 + attention_dropout: 0.0 + conv_bias: true + + encoder_layers: 24 + encoder_embed_dim: 1024 + encoder_ffn_embed_dim: 4096 + encoder_attention_heads: 16 + + feature_grad_mult: 1.0 + diff --git a/examples/wav2vec/vq-wav2vec_featurize.py b/examples/wav2vec/vq-wav2vec_featurize.py index baabc1d365..1adb52de1c 100644 --- a/examples/wav2vec/vq-wav2vec_featurize.py +++ b/examples/wav2vec/vq-wav2vec_featurize.py @@ -16,8 +16,7 @@ import soundfile as sf import torch -import tqdm -from fairseq.models.wav2vec.wav2vec import Wav2VecModel +import fairseq from torch import nn from torch.utils.data import DataLoader @@ -211,13 +210,11 @@ def load_data(self, fnames): return loader def load_model(self): - cp = torch.load(self.checkpoint, map_location=lambda x, _: x) + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint]) + model = model[0] - model = Wav2VecModel.build_model(cp["args"], None) + self.quantize_location = getattr(cfg.model, "vq", "encoder") - self.quantize_location = getattr(cp["args"], "vq", "encoder") - - model.load_state_dict(cp["model"]) model.eval().float() model.cuda() diff --git a/examples/wav2vec/wav2vec_featurize.py b/examples/wav2vec/wav2vec_featurize.py index 9283930587..b806316e5a 100644 --- a/examples/wav2vec/wav2vec_featurize.py +++ b/examples/wav2vec/wav2vec_featurize.py @@ -18,7 +18,7 @@ import soundfile as sf import torch import tqdm -from fairseq.models.wav2vec.wav2vec import Wav2VecModel +import fairseq from torch import nn @@ -35,10 +35,8 @@ class PretrainedWav2VecModel(nn.Module): def __init__(self, fname): super().__init__() - checkpoint = torch.load(fname) - self.args = checkpoint["args"] - model = Wav2VecModel.build_model(self.args, None) - model.load_state_dict(checkpoint["model"]) + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname]) + model = model[0] model.eval() self.model = model diff --git a/examples/wav2vec/wav2vec_manifest.py b/examples/wav2vec/wav2vec_manifest.py index 1d27f58afc..5417084554 100644 --- a/examples/wav2vec/wav2vec_manifest.py +++ b/examples/wav2vec/wav2vec_manifest.py @@ -47,6 +47,9 @@ def get_parser(): def main(args): assert args.valid_percent >= 0 and args.valid_percent <= 1.0 + if not os.path.exists(args.dest): + os.makedirs(args.dest) + dir_path = os.path.realpath(args.root) search_path = os.path.join(dir_path, "**/*." + args.ext) rand = random.Random(args.seed) diff --git a/examples/wmt20/README.md b/examples/wmt20/README.md new file mode 100644 index 0000000000..b4f2874652 --- /dev/null +++ b/examples/wmt20/README.md @@ -0,0 +1,72 @@ +# WMT 20 + +This page provides pointers to the models of Facebook-FAIR's WMT'20 news translation task submission [(Chen et al., 2020)](https://arxiv.org/abs/2011.08298). + +## Single best MT models (after finetuning on part of WMT20 news dev set) + +Model | Description | Download +---|---|--- +`transformer.wmt20.ta-en` | Ta->En | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz) +`transformer.wmt20.en-ta` | En->Ta | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz) +`transformer.wmt20.iu-en.news` | Iu->En (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz) +`transformer.wmt20.en-iu.news` | En->Iu (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz) +`transformer.wmt20.iu-en.nh` | Iu->En (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz) +`transformer.wmt20.en-iu.nh` | En->Iu (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz) + +## Language models +Model | Description | Download +---|---|--- +`transformer_lm.wmt20.en` | En Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en.tar.gz) +`transformer_lm.wmt20.ta` | Ta Language Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta.tar.gz) +`transformer_lm.wmt20.iu.news` | Iu Language Model (News domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.news.tar.gz) +`transformer_lm.wmt20.iu.nh` | Iu Language Model (Nunavut Hansard domain) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu.nh.tar.gz) + +## Example usage (torch.hub) + +#### Translation + +```python +import torch + +# English to Tamil translation +en2ta = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-ta') +en2ta.translate("Machine learning is great!") # 'இயந்திரக் கற்றல் அருமை!' + +# Tamil to English translation +ta2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.ta-en') +ta2en.translate("இயந்திரக் கற்றல் அருமை!") # 'Machine learning is great!' + +# English to Inuktitut translation +en2iu = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.en-iu.news') +en2iu.translate("machine learning is great!") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!' + +# Inuktitut to English translation +iu2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt20.iu-en.news') +iu2en.translate("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ ᐱᐅᔪᒻᒪᕆᒃ!") # 'Machine learning excellence!' +``` + +#### Language Modeling + +```python +# Sample from the English LM +en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.en') +en_lm.sample("Machine learning is") # 'Machine learning is a type of artificial intelligence that uses machine learning to learn from data and make predictions.' + +# Sample from the Tamil LM +ta_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.ta') +ta_lm.sample("இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின்") # 'இயந்திரக் கற்றல் என்பது செயற்கை நுண்ணறிவின் ஒரு பகுதியாகும்.' + +# Sample from the Inuktitut LM +iu_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt20.iu.news') +iu_lm.sample("ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ") # 'ᖃᒧᑕᐅᔭᓄᑦ ᐃᓕᓐᓂᐊᕐᓂᖅ, ᐊᒻᒪᓗ ᓯᓚᐅᑉ ᐊᓯᙳᖅᐸᓪᓕᐊᓂᖓᓄᑦ ᖃᓄᐃᓕᐅᕈᑎᒃᓴᑦ, ᐃᓚᖃᖅᖢᑎᒃ ᐅᑯᓂᖓ:' +``` + +## Citation +```bibtex +@inproceedings{chen2020facebook + title={Facebook AI's WMT20 News Translation Task Submission}, + author={Peng-Jen Chen and Ann Lee and Changhan Wang and Naman Goyal and Angela Fan and Mary Williamson and Jiatao Gu}, + booktitle={Proc. of WMT}, + year={2020}, +} +``` diff --git a/fairseq/binarizer.py b/fairseq/binarizer.py index 0255c084b5..c736c8754d 100644 --- a/fairseq/binarizer.py +++ b/fairseq/binarizer.py @@ -46,7 +46,13 @@ def replaced_consumer(word, idx): # next(f) breaks f.tell(), hence readline() must be used line = safe_readline(f) while line: - if end > 0 and f.tell() > end: + # f.tell() does not always give the byte position in the file + # sometimes it skips to a very large number + # it is unlikely that through a normal read we go from + # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely + # that the procedure breaks by the undeterministic behavior of + # f.tell() + if end > 0 and f.tell() > end and f.tell() < end + 2**32: break if already_numberized: id_strings = line.strip().split() diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 3038a1ebcc..764de0f5d1 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -5,12 +5,13 @@ import ast import collections +import contextlib import logging import os import re import traceback from collections import OrderedDict -from typing import Optional, Union +from typing import Any, Dict, Optional, Union import torch from fairseq.dataclass.configs import CheckpointConfig, FairseqConfig @@ -46,9 +47,6 @@ def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss): if not trainer.is_data_parallel_master: return - def is_better(a, b): - return a >= b if cfg.maximize_best_checkpoint_metric else a <= b - write_timer = meters.StopwatchMeter() write_timer.start() @@ -56,6 +54,11 @@ def is_better(a, b): end_of_epoch = epoch_itr.end_of_epoch() updates = trainer.get_num_updates() + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if cfg.maximize_best_checkpoint_metric else a <= b + suffix = cfg.checkpoint_suffix or "" checkpoint_conds = collections.OrderedDict() checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( @@ -90,11 +93,18 @@ def is_better(a, b): if len(checkpoints) > 0: trainer.save_checkpoint(checkpoints[0], extra_state) for cp in checkpoints[1:]: - PathManager.copy(checkpoints[0], cp, overwrite=True) + # assert PathManager.copy( + # checkpoints[0], cp, overwrite=True + # ), f"Failed to copy {checkpoints[0]} to {cp}" + try: + os.remove(cp) + except: + pass + os.link(checkpoints[0], cp) write_timer.stop() logger.info( - "saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( checkpoints[0], epoch, updates, val_loss, write_timer.sum ) ) @@ -239,7 +249,13 @@ def load_checkpoint_to_cpu(path, arg_overrides=None): def load_model_ensemble( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, ): """Loads an ensemble of models. @@ -259,21 +275,32 @@ def load_model_ensemble( strict, suffix, num_shards, + state, ) return ensemble, args def load_model_ensemble_and_task( - filenames, arg_overrides=None, task=None, strict=True, suffix="", num_shards=1 + filenames, + arg_overrides: Optional[Dict[str, Any]] = None, + task=None, + strict=True, + suffix="", + num_shards=1, + state=None, ): + assert state is None or len(filenames) == 1 + from fairseq import tasks assert not ( strict and num_shards > 1 ), "Cannot load state dict with strict=True and checkpoint shards > 1" ensemble = [] + cfg = None for filename in filenames: orig_filename = filename + assert num_shards > 0 for shard_idx in range(num_shards): if num_shards == 1: filename = filename.replace(".pt", suffix + ".pt") @@ -282,7 +309,8 @@ def load_model_ensemble_and_task( if not PathManager.exists(filename): raise IOError("Model file not found: {}".format(filename)) - state = load_checkpoint_to_cpu(filename, arg_overrides) + if state is None: + state = load_checkpoint_to_cpu(filename, arg_overrides) if "args" in state and state["args"] is not None: cfg = convert_namespace_to_omegaconf(state["args"]) elif "cfg" in state and state["cfg"] is not None: @@ -299,6 +327,10 @@ def load_model_ensemble_and_task( model = task.build_model(cfg.model) model.load_state_dict(state["model"], strict=strict, model_cfg=cfg.model) + + # reset state so it gets loaded for the next model in ensemble + state = None + ensemble.append(model) return ensemble, cfg, task @@ -385,8 +417,15 @@ def save_state( # keep everything on CPU state_dict = utils.move_to_cpu(state_dict) - with PathManager.open(filename, "wb") as f: - torch_persistent_save(state_dict, f) + if PathManager.supports_rename(filename): + # do atomic save + with PathManager.open(filename + ".tmp", "wb") as f: + torch_persistent_save(state_dict, f) + PathManager.rename(filename + ".tmp", filename) + else: + # fallback to non-atomic save + with PathManager.open(filename, "wb") as f: + torch_persistent_save(state_dict, f) def _upgrade_state_dict(state): @@ -428,6 +467,12 @@ def _upgrade_state_dict(state): # keep track of number of updates if "num_updates" not in state["optimizer_history"][-1]: state["optimizer_history"][-1]["num_updates"] = 0 + # old model checkpoints may not have separate source/target positions + if hasattr(state["args"], "max_positions") and not hasattr( + state["args"], "max_source_positions" + ): + state["args"].max_source_positions = state["args"].max_positions + state["args"].max_target_positions = state["args"].max_positions # use stateful training data iterator if "train_iterator" not in state["extra_state"]: state["extra_state"]["train_iterator"] = { @@ -435,7 +480,6 @@ def _upgrade_state_dict(state): "iterations_in_epoch": state["extra_state"].get("batch_offset", 0), } - # old model checkpoints may not have separate source/target positions # backward compatibility, cfg updates if "args" in state and state["args"] is not None: # default to translation task @@ -451,24 +495,38 @@ def _upgrade_state_dict(state): state["extra_state"]["train_iterator"]["epoch"] = max( state["extra_state"]["train_iterator"].get("epoch", 1), 1 ) - + # --remove-bpe ==> --postprocess if hasattr(state["args"], "remove_bpe"): state["args"].post_process = state["args"].remove_bpe + # --min-lr ==> --stop-min-lr + if hasattr(state["args"], "min_lr"): + state["args"].stop_min_lr = state["args"].min_lr + del state["args"].min_lr + # binary_cross_entropy => wav2vec criterion + if ( + hasattr(state["args"], "criterion") + and state["args"].criterion == "binary_cross_entropy" + ): + state["args"].criterion = "wav2vec" + # speech_pretraining => audio pretraining + if ( + hasattr(state["args"], "task") + and state["args"].task == "speech_pretraining" + ): + state["args"].task = "audio_pretraining" + # audio_cpc => wav2vec + if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc": + state["args"].arch = "wav2vec" + # convert legacy float learning rate to List[float] + if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float): + state["args"].lr = [state["args"].lr] state["cfg"] = convert_namespace_to_omegaconf(state["args"]) if "cfg" in state and state["cfg"] is not None: with open_dict(state["cfg"]): - if state["cfg"].task is not None: - if hasattr(state["cfg"].task, "max_positions") and not hasattr( - state["cfg"].task, "max_source_positions" - ): - state["cfg"].task.max_source_positions = state[ - "cfg" - ].task.max_positions - state["cfg"].task.max_target_positions = state[ - "cfg" - ].task.max_positions + # any upgrades for Hydra-based configs + pass return state @@ -553,8 +611,11 @@ def create_pruning_pass(layers_to_keep, layer_name): # Since layers are now pruned, *_layers_to_keep are no longer needed. # This is more of "It would make it work fix" rather than a proper fix. - - with open_dict(model_cfg): + if isinstance(model_cfg, DictConfig): + context = open_dict(model_cfg) + else: + context = contextlib.ExitStack() + with context: if hasattr(model_cfg, "encoder_layers_to_keep"): model_cfg.encoder_layers_to_keep = None if hasattr(model_cfg, "decoder_layers_to_keep"): diff --git a/fairseq/config/__init__.py b/fairseq/config/__init__.py new file mode 100644 index 0000000000..6264236915 --- /dev/null +++ b/fairseq/config/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/config/config.yaml b/fairseq/config/config.yaml index 039609aece..e20d914b9b 100644 --- a/fairseq/config/config.yaml +++ b/fairseq/config/config.yaml @@ -1,10 +1,15 @@ # @package _group_ + +hydra: + run: + dir: . + defaults: - - task: language_modeling + - task: null - model: null - criterion: cross_entropy - - optimizer: adam - - lr_scheduler: cosine + - optimizer: null + - lr_scheduler: fixed - bpe: null - tokenizer: null - scoring: null diff --git a/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml new file mode 100644 index 0000000000..ee1329bf46 --- /dev/null +++ b/fairseq/config/model/wav2vec/vq_wav2vec_gumbel.yaml @@ -0,0 +1,5 @@ +# @package _group_ +activation: gelu +vq_type: gumbel +vq_depth: 2 +combine_groups: true diff --git a/fairseq/config/model/wav2vec2/wav2vec2_base.yaml b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml new file mode 100644 index 0000000000..ce65499b80 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_base.yaml @@ -0,0 +1,8 @@ +# @package _group_ + +quantize_targets: true +final_dim: 256 +encoder_layerdrop: 0.05 +dropout_input: 0.1 +dropout_features: 0.1 +feature_grad_mult: 0.1 diff --git a/fairseq/config/model/wav2vec2/wav2vec2_large.yaml b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml new file mode 100644 index 0000000000..5846f75243 --- /dev/null +++ b/fairseq/config/model/wav2vec2/wav2vec2_large.yaml @@ -0,0 +1,20 @@ +# @package _group_ + +quantize_targets: true +extractor_mode: layer_norm +layer_norm_first: true +final_dim: 768 +latent_temp: [2.0,0.1,0.999995] +encoder_layerdrop: 0.0 +dropout_input: 0.0 +dropout_features: 0.0 +dropout: 0.0 +attention_dropout: 0.0 +conv_bias: true + +encoder_layers: 24 +encoder_embed_dim: 1024 +encoder_ffn_embed_dim: 4096 +encoder_attention_heads: 16 + +feature_grad_mult: 1.0 diff --git a/fairseq/criterions/ctc.py b/fairseq/criterions/ctc.py index 6b77ce47eb..8cb1331825 100644 --- a/fairseq/criterions/ctc.py +++ b/fairseq/criterions/ctc.py @@ -6,39 +6,92 @@ import math from argparse import Namespace +from dataclasses import dataclass, field +from omegaconf import II +from typing import Optional import torch import torch.nn.functional as F from fairseq import metrics, utils -from fairseq.criterions import LegacyFairseqCriterion, register_criterion +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass from fairseq.data.data_utils import post_process +from fairseq.tasks import FairseqTask from fairseq.logging.meters import safe_round -@register_criterion("ctc") -class CtcCriterion(LegacyFairseqCriterion): - def __init__(self, args, task): - super().__init__(args, task) - self.blank_idx = task.target_dictionary.bos() +@dataclass +class CtcCriterionConfig(FairseqDataclass): + zero_infinity: bool = field( + default=False, + metadata={"help": "zero inf loss when source length <= target length"}, + ) + sentence_avg: bool = II("optimization.sentence_avg") + post_process: str = field( + default="letter", + metadata={ + "help": "how to post process predictions into words. can be letter, " + "wordpiece, BPE symbols, etc. " + "See fairseq.data.data_utils.post_process() for full list of options" + }, + ) + wer_kenlm_model: Optional[str] = field( + default=None, + metadata={ + "help": "if this is provided, use kenlm to compute wer (along with other wer_* args)" + }, + ) + wer_lexicon: Optional[str] = field( + default=None, + metadata={"help": "lexicon to use with wer_kenlm_model"}, + ) + wer_lm_weight: float = field( + default=2.0, + metadata={"help": "lm weight to use with wer_kenlm_model"}, + ) + wer_word_score: float = field( + default=-1.0, + metadata={"help": "lm word score to use with wer_kenlm_model"}, + ) + + wer_args: Optional[str] = field( + default=None, + metadata={ + "help": "DEPRECATED: tuple of (wer_kenlm_model, wer_lexicon, wer_lm_weight, wer_word_score)" + }, + ) + + +@register_criterion("ctc", dataclass=CtcCriterionConfig) +class CtcCriterion(FairseqCriterion): + def __init__(self, cfg: CtcCriterionConfig, task: FairseqTask): + super().__init__(task) + self.blank_idx = task.target_dictionary.index(task.blank_symbol) self.pad_idx = task.target_dictionary.pad() self.eos_idx = task.target_dictionary.eos() - self.post_process = args.post_process if args.post_process else "letter" + self.post_process = cfg.post_process - if args.wer_args is not None: - from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder + if cfg.wer_args is not None: + ( + cfg.wer_kenlm_model, + cfg.wer_lexicon, + cfg.wer_lm_weight, + cfg.wer_word_score, + ) = eval(cfg.wer_args) - wer_compute_kenlm, wer_lexicon, lm_w, ws_w = eval(args.wer_args) + if cfg.wer_kenlm_model is not None: + from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder dec_args = Namespace() dec_args.nbest = 1 dec_args.criterion = "ctc" - dec_args.kenlm_model = wer_compute_kenlm - dec_args.lexicon = wer_lexicon + dec_args.kenlm_model = cfg.wer_kenlm_model + dec_args.lexicon = cfg.wer_lexicon dec_args.beam = 50 dec_args.beam_size_token = min(50, len(task.target_dictionary)) dec_args.beam_threshold = min(50, len(task.target_dictionary)) - dec_args.lm_weight = lm_w - dec_args.word_score = ws_w + dec_args.lm_weight = cfg.wer_lm_weight + dec_args.word_score = cfg.wer_word_score dec_args.unk_weight = -math.inf dec_args.sil_weight = 0 @@ -46,31 +99,8 @@ def __init__(self, args, task): else: self.w2l_decoder = None - self.zero_infinity = args.zero_infinity - self.sentence_avg = args.sentence_avg - - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - parser.add_argument( - "--zero-infinity", action="store_true", help="zero inf loss" - ) - try: - parser.add_argument( - "--post-process", - "--remove-bpe", - default="letter", - help="remove BPE tokens before scoring (can be set to sentencepiece, letter, and more)", - ) - except: - pass # this option might have been added from eval args - parser.add_argument( - "--wer-args", - type=str, - default=None, - help="options for wer computation on valid set using 4 gram lm. this should be a tuple of 4 elements: path to 4-gram lm, \ - path to lexicon, lm score, word score", - ) + self.zero_infinity = cfg.zero_infinity + self.sentence_avg = cfg.sentence_avg def forward(self, model, sample, reduce=True): net_output = model(**sample["net_input"]) @@ -88,7 +118,10 @@ def forward(self, model, sample, reduce=True): sample["target"] != self.eos_idx ) targets_flat = sample["target"].masked_select(pad_mask) - target_lengths = sample["target_lengths"] + if "target_lengths" in sample: + target_lengths = sample["target_lengths"] + else: + target_lengths = pad_mask.sum(-1) with torch.backends.cudnn.flags(enabled=False): loss = F.ctc_loss( diff --git a/fairseq/criterions/fairseq_criterion.py b/fairseq/criterions/fairseq_criterion.py index b2eda1a7e4..ff4beb0250 100644 --- a/fairseq/criterions/fairseq_criterion.py +++ b/fairseq/criterions/fairseq_criterion.py @@ -7,8 +7,8 @@ from typing import Any, Dict, List from fairseq import metrics, utils +from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import gen_parser_from_dataclass -from omegaconf import DictConfig from torch.nn.modules.loss import _Loss @@ -28,7 +28,7 @@ def add_args(cls, parser): gen_parser_from_dataclass(parser, dc()) @classmethod - def build_criterion(cls, cfg: DictConfig, task): + def build_criterion(cls, cfg: FairseqDataclass, task): """Construct a criterion from command-line args.""" # arguments in the __init__. init_args = {} @@ -46,6 +46,8 @@ def build_criterion(cls, cfg: DictConfig, task): if p.name == "task": init_args["task"] = task + elif p.name == "cfg": + init_args["cfg"] = cfg elif hasattr(cfg, p.name): init_args[p.name] = getattr(cfg, p.name) elif p.default != p.empty: diff --git a/fairseq/criterions/model_criterion.py b/fairseq/criterions/model_criterion.py new file mode 100644 index 0000000000..8e366a5d85 --- /dev/null +++ b/fairseq/criterions/model_criterion.py @@ -0,0 +1,138 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass, field +from typing import Dict, List + +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelCriterionConfig(FairseqDataclass): + loss_weights: Dict[str, float] = field( + default_factory=dict, + metadata={"help": "weights for the loss terms"}, + ) + log_keys: List[str] = field( + default_factory=list, + metadata={"help": "additional output keys to log"}, + ) + + +@register_criterion("model", dataclass=ModelCriterionConfig) +class ModelCriterion(FairseqCriterion): + """ + This criterion relies on the model to supply losses. + The losses should be a dictionary of name -> scalar returned by + the model either by including it in the net_output dict or by + implementing a get_losses(net_output, sample) method. The final loss is + a scaled sum of all losses according to weights in loss_weights. + If no weights are provided, then all losses are scaled by 1.0. + + The losses will be automatically logged. Additional keys from + net_output dict can be logged via the log_keys parameter. + """ + + def __init__(self, task, loss_weights=None, log_keys=None): + super().__init__(task) + self.loss_weights = loss_weights + self.log_keys = log_keys + + def forward(self, model, sample, reduce=True): + net_output = model(**sample["net_input"]) + + sample_size = net_output["sample_size"] + scaled_losses = {} + + if hasattr(model, "get_losses"): + losses = model.get_losses(net_output, sample) + elif isinstance(net_output, dict) and "losses" in net_output: + losses = net_output["losses"] + else: + raise Exception("Could not retrieve losses") + + for lk, p in losses.items(): + try: + coef = 1.0 if len(self.loss_weights) == 0 else self.loss_weights[lk] + except KeyError: + logger.error( + f"weight for loss {lk} is not in loss_weights ({self.loss_weights})" + ) + raise + if coef != 0 and p is not None: + scaled_losses[lk] = coef * p.float() + + loss = sum(scaled_losses.values()) + if reduce and loss.numel() > 1: + loss = loss.sum() + + logging_output = { + "loss": loss.data, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + "_world_size": 1, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float(net_output[lk]) + + if len(scaled_losses) > 1: + for lk, l in scaled_losses.items(): + logging_output[f"loss_{lk}"] = l.item() + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + ntokens = utils.item(sum(log.get("ntokens", 0) for log in logging_outputs)) + nsentences = utils.item( + sum(log.get("nsentences", 0) for log in logging_outputs) + ) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + + metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) + metrics.log_scalar("ntokens", ntokens) + metrics.log_scalar("nsentences", nsentences) + + builtin_keys = { + "loss", + "ntokens", + "nsentences", + "sample_size", + "_world_size", + } + + world_size = utils.item( + sum(log.get("_world_size", 0) for log in logging_outputs) + ) + + for k in logging_outputs[0]: + if k not in builtin_keys: + val = sum(log.get(k, 0) for log in logging_outputs) + if k.startswith("loss_"): + metrics.log_scalar(k, val / sample_size, sample_size, round=3) + else: + metrics.log_scalar(k, val / world_size, round=3) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True diff --git a/fairseq/criterions/probe_criterion.py b/fairseq/criterions/probe_criterion.py new file mode 100644 index 0000000000..e8029fedb6 --- /dev/null +++ b/fairseq/criterions/probe_criterion.py @@ -0,0 +1,71 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass +from fairseq.logging.meters import safe_round +from fairseq.models.probed_model import reduce_probe_metrics + +@dataclass +class ProbeCriterionConfig(FairseqDataclass): + pass + + +@register_criterion("probes", dataclass=ProbeCriterionConfig) +class ProbeCriterion(FairseqCriterion): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + sample_size = 1 + + probe_loss, probe_log_outs = model.get_probe_losses(sample) + loss = probe_loss + + logging_output = { + "loss": loss.item(), + "ntokens": 1, + "nsentences": 1, + "sample_size": sample_size, + } + logging_output.update(probe_log_outs) + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = utils.item(sum(log.get("loss", 0) for log in logging_outputs)) + sample_size = utils.item( + sum(log.get("sample_size", 0) for log in logging_outputs) + ) + metrics.log_scalar( + "loss", loss_sum / sample_size, sample_size, round=3 + ) + reduce_probe_metrics(logging_outputs, metrics) + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 6ac7557dcc..f96c56d1d9 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -4,35 +4,44 @@ # LICENSE file in the root directory of this source tree. import math +from dataclasses import dataclass, field +from typing import List, Optional import torch import torch.nn.functional as F from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass from fairseq.logging.meters import safe_round - - -@register_criterion("wav2vec") +from fairseq.models.probed_model import reduce_probe_metrics + +@dataclass +class Wav2VecCriterionConfig(FairseqDataclass): + infonce: bool = field( + default=False, + metadata={ + "help": "if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)" + }, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("wav2vec", dataclass=Wav2VecCriterionConfig) class Wav2vecCriterion(FairseqCriterion): def __init__(self, task, infonce=False, loss_weights=None, log_keys=None): super().__init__(task) self.infonce = infonce - self.loss_weights = None if loss_weights is None else eval(loss_weights) - self.log_keys = [] if log_keys is None else eval(log_keys) + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys - @staticmethod - def add_args(parser): - """Add criterion-specific arguments to the parser.""" - # fmt: off - parser.add_argument('--infonce', action='store_true', - help='if set, uses cross entropy instead of binary cross entropy (i.e. InfoNCE loss)') - parser.add_argument('--loss-weights', type=str, default=None, - help='weights for additional loss terms (not first one)') - parser.add_argument('--log-keys', type=str, default=None, - help='output keys to log') - # fmt: on - - def forward(self, model, sample, reduce=True, log_pred=False): + def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. Returns a tuple with three elements: @@ -100,6 +109,11 @@ def forward(self, model, sample, reduce=True, log_pred=False): for i, l in enumerate(losses): logging_output[f"loss_{i}"] = l.item() + if hasattr(model, 'get_probe_losses'): + probe_loss, probe_log_outs = model.get_probe_losses(sample) + loss += probe_loss + logging_output.update(probe_log_outs) + if self.infonce: with torch.no_grad(): if logits.numel() == 0: @@ -116,9 +130,6 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output["correct"] = corr logging_output["count"] = count - if log_pred: - logging_output["logits"] = logits.cpu().numpy() - logging_output["target"] = target.cpu().numpy() return loss, sample_size, logging_output @staticmethod @@ -164,15 +175,19 @@ def reduce_metrics(logging_outputs) -> None: "count", } + handled_keys = reduce_probe_metrics(logging_outputs, metrics) + builtin_keys.update(handled_keys) + for k in logging_outputs[0]: if k not in builtin_keys: - val = sum(log.get(k, 0) for log in logging_outputs) / len( - logging_outputs - ) + val = sum(log.get(k, 0) for log in logging_outputs) if k.startswith("loss"): - metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) + metrics.log_scalar( + k, val / sample_size / math.log(2), sample_size, round=3 + ) else: - metrics.log_scalar(k, val, round=3) + metrics.log_scalar(k, val / len(logging_outputs), round=3) + @staticmethod def logging_outputs_can_be_summed() -> bool: diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 81f457365a..1efe352dd2 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -346,8 +346,12 @@ def post_process(sentence: str, symbol: str): sentence = sentence.replace(" ", "").replace("|", " ").strip() elif symbol == "_EOW": sentence = sentence.replace(" ", "").replace("_EOW", " ").strip() - elif symbol is not None and symbol != "none": - sentence = (sentence + " ").replace(symbol, "").rstrip() + elif symbol in {"subword_nmt", "@@ "}: + sentence = (sentence + " ").replace("@@ ", "").rstrip() + elif symbol == "none": + pass + elif symbol is not None: + raise NotImplementedError(f"Unknown post_process option: {symbol}") return sentence diff --git a/fairseq/data/dictionary.py b/fairseq/data/dictionary.py index e2df08e092..127d023f4c 100644 --- a/fairseq/data/dictionary.py +++ b/fairseq/data/dictionary.py @@ -69,6 +69,7 @@ def string( escape_unk=False, extra_symbols_to_ignore=None, unk_string=None, + include_eos=False, ): """Helper for converting a tensor of token indices to a string. @@ -76,7 +77,7 @@ def string( """ if torch.is_tensor(tensor) and tensor.dim() == 2: return "\n".join( - self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore) + self.string(t, bpe_symbol, escape_unk, extra_symbols_to_ignore, include_eos=include_eos) for t in tensor ) @@ -334,7 +335,13 @@ def _add_file_to_dictionary_single_worker( for word in tokenize(line): counter.update([word]) counter.update([eos_word]) - if f.tell() > end: + # f.tell() returns only an opaque number which can + # return to the position in the file via f.seek() + # and does not necessarily represent a byte position + # in the file. However, f.tell() is faithful to the + # byte position _most of the time_. Thus we can just + # check against the file size to prevent early exit. + if f.tell() > end and f.tell() < size: break line = f.readline() return counter diff --git a/fairseq/data/encoders/hf_byte_bpe.py b/fairseq/data/encoders/hf_byte_bpe.py index 92d2c3922c..c508578d41 100644 --- a/fairseq/data/encoders/hf_byte_bpe.py +++ b/fairseq/data/encoders/hf_byte_bpe.py @@ -7,6 +7,7 @@ from fairseq.data.encoders import register_bpe from fairseq.dataclass import FairseqDataclass +from fairseq import file_utils @dataclass @@ -28,9 +29,12 @@ def __init__(self, cfg): "Please install huggingface/tokenizers with: " "pip install tokenizers" ) + bpe_vocab = file_utils.cached_path(cfg.bpe_vocab) + bpe_merges = file_utils.cached_path(cfg.bpe_merges) + self.bpe = ByteLevelBPETokenizer( - cfg.bpe_vocab, - cfg.bpe_merges, + bpe_vocab, + bpe_merges, add_prefix_space=cfg.bpe_add_prefix_space, ) diff --git a/fairseq/data/encoders/nltk_tokenizer.py b/fairseq/data/encoders/nltk_tokenizer.py index ee164710a0..0ab92377b3 100644 --- a/fairseq/data/encoders/nltk_tokenizer.py +++ b/fairseq/data/encoders/nltk_tokenizer.py @@ -4,9 +4,10 @@ # LICENSE file in the root directory of this source tree. from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("nltk") +@register_tokenizer("nltk", dataclass=FairseqDataclass) class NLTKTokenizer(object): def __init__(self, *unused): try: diff --git a/fairseq/data/encoders/space_tokenizer.py b/fairseq/data/encoders/space_tokenizer.py index 7c7f644d5c..925ad41b7c 100644 --- a/fairseq/data/encoders/space_tokenizer.py +++ b/fairseq/data/encoders/space_tokenizer.py @@ -6,9 +6,10 @@ import re from fairseq.data.encoders import register_tokenizer +from fairseq.dataclass import FairseqDataclass -@register_tokenizer("space") +@register_tokenizer("space", dataclass=FairseqDataclass) class SpaceTokenizer(object): def __init__(self, *unused): self.space_tok = re.compile(r"\s+") diff --git a/fairseq/data/handwriting/aligner.py b/fairseq/data/handwriting/aligner.py index 79a3e45dc4..3a03f994ae 100644 --- a/fairseq/data/handwriting/aligner.py +++ b/fairseq/data/handwriting/aligner.py @@ -290,6 +290,7 @@ def stretchToLength(self, path_, newSz_, insertSeparators_=False, symbolBoundari symbolBoundaries_[t] = [ newIdx, symbolBoundaries_[t][1] ] # Optional insert of blanks between repeats like 'oo' + # [!] this is actually used like that, and symbol alignment alg when cropping should care for that (and new one should do now) if (symbolBoundaries_ is not None ) and insertSeparators_: self.insertBlanksAsSymbolSeperators(newPath, symbolBoundaries_) diff --git a/fairseq/data/handwriting/raw_handwriting_dataset.py b/fairseq/data/handwriting/raw_handwriting_dataset.py index 6f6eb57519..b710f7eaee 100644 --- a/fairseq/data/handwriting/raw_handwriting_dataset.py +++ b/fairseq/data/handwriting/raw_handwriting_dataset.py @@ -12,6 +12,7 @@ import torch import torch.nn.functional as F +from .utils import num_between from . import scribblelens from .. import FairseqDataset @@ -83,17 +84,18 @@ def postprocess(self, feats, curr_sample_rate): feats = F.layer_norm(feats, feats.shape) return feats - def crop_to_max_size(self, wav, target_size_dim1, alignment=None): + def crop_to_max_size(self, wav, target_size_dim1, alignment=None): # TODO perhaps change to just return indices, a bit cleaner? # if alignment set, cut it too - TODO maybe also mask half a letter etc., also in data! + # but maybe do this on crop labels stuff (mask if less than half of a letter visible or so) size = wav.shape[1] #len(wav) diff = size - target_size_dim1 if diff <= 0: - if alignment: - return wav, alignment + if alignment is not None: + return wav, alignment, (0, size-1) else: - return wav + return wav, (0, size-1) if self.shuffle: start = np.random.randint(0, diff + 1) @@ -101,38 +103,13 @@ def crop_to_max_size(self, wav, target_size_dim1, alignment=None): # Deterministically pick the middle part start = (diff + 1) //2 end = size - diff + start - if alignment: - return wav[:, start:end], alignment[start:end] + if alignment is not None: + return wav[:, start:end], alignment[start:end], (start, end-1) else: - return wav[:, start:end] + return wav[:, start:end], (start, end-1) # end inclusive def collater(self, samples): - # TODO stuff with labels - # collated = self.dataset.collater(samples) - # if len(collated) == 0: - # return collated - # indices = set(collated["id"].tolist()) - # target = [s["label"] for s in samples if s["id"] in indices] - - # if self.batch_targets: - # collated["target_lengths"] = torch.LongTensor([len(t) for t in target]) - # target = data_utils.collate_tokens(target, pad_idx=self.pad, left_pad=False) - # collated["ntokens"] = collated["target_lengths"].sum().item() - # else: - # collated["ntokens"] = sum([len(t) for t in target]) - - # collated["target"] = target - - # if self.add_to_input: - # eos = target.new_full((target.size(0), 1), self.eos) - # collated["target"] = torch.cat([target, eos], dim=-1).long() - # collated["net_input"]["prev_output_tokens"] = torch.cat([eos, target], dim=-1).long() - # collated["ntokens"] += target.size(0) - #return collated - - - samples = [ s for s in samples @@ -167,17 +144,38 @@ def collater(self, samples): if self.labels: collated_labels_nontensor = [] #collated_texts_nontensor = [] # TODO - collated_alignments = samples[0]["alignment"].new_zeros((len(sources), target_size)) + collated_alignments = samples[0]["alignment"].new_full((len(sources), target_size), self.label_pad_idx) + available_labels = torch.BoolTensor(size=(len(samples),)).fill_(False) + available_alignments = torch.BoolTensor(size=(len(samples),)).fill_(False) + labels_texts = [] + alignments_texts = [] + # labels_with_ranges_arr is not a tensor, it just has None where unavailable + + labels_with_ranges_arr = [] for i, (sample, size) in enumerate(zip(samples, sizes)): source = sample["source"] diff = size - target_size if diff == 0: collated_sources[i] = source if self.labels: - collated_labels_nontensor.append(sample["label"]) + if sample["label_available"]: + available_labels[i] = True + collated_labels_nontensor.append(sample["label"]) + labels_texts.append(sample["label_text"]) + else: + collated_labels_nontensor.append(None) + labels_texts.append([]) #collated_texts_nontensor.append(sample["text"]) - collated_alignments[i] = sample["alignment"] + if sample["alignment_available"]: + available_alignments[i] = True + collated_alignments[i] = sample["alignment"] + labels_with_ranges = self.get_letter_ranges(sample["alignment"], sample["alignment_idx"]) + labels_with_ranges_arr.append([(a, (c,d)) for a, _, (c, d) in labels_with_ranges]) + alignments_texts.append(sample["alignment_text"]) + else: + labels_with_ranges_arr.append(None) + alignments_texts.append([]) elif diff < 0: assert self.pad collated_sources[i] = torch.cat( @@ -186,19 +184,52 @@ def collater(self, samples): ) padding_mask[i, :, diff:] = True if self.labels: - collated_alignments[i] = torch.cat([sample["alignment"], sample["alignment"].new_full((-diff,), self.label_pad_idx)]) - coll_labels = sample["label"] #self.collate_labels(collated_alignments[i], sample["label"], sample["text"]) - collated_labels_nontensor.append(coll_labels) - #collated_texts_nontensor.append(coll_text) + if sample["label_available"]: + available_labels[i] = True + collated_labels_nontensor.append(sample["label"]) + labels_texts.append(sample["label_text"]) + else: + collated_labels_nontensor.append(None) + labels_texts.append([]) + if sample["alignment_available"]: + available_alignments[i] = True + collated_alignments[i] = torch.cat([sample["alignment"], sample["alignment"].new_full((-diff,), self.label_pad_idx)]) + labels_with_ranges = self.get_letter_ranges(sample["alignment"], sample["alignment_idx"]) + labels_with_ranges_arr.append([(a, (c,d)) for a, _, (c, d) in labels_with_ranges]) + alignments_texts.append(sample["alignment_text"]) + else: + labels_with_ranges_arr.append(None) + alignments_texts.append([]) else: - # only case with cropping TODO fix case with double letters without space between + # only case with cropping if self.labels: - collated_sources[i], collated_alignments[i] = self.crop_to_max_size(source, target_size, alignment=sample["alignment"]) - coll_labels = self.collate_labels(collated_alignments[i], sample["label"], sample["text"]) - collated_labels_nontensor.append(coll_labels) + collated_sources[i], collated_alignments[i], (start, end) = self.crop_to_max_size(source, target_size, alignment=sample["alignment"]) + + if sample["label_available"]: + available_labels[i] = True + collated_labels_nontensor.append(sample["label"]) # if no allignments, can't do better (or could put "None"); if there are allignments, will amend below + labels_texts.append(sample["label_text"]) # if no allignments, can't do better (or could put "None"); if there are allignments, will amend below + else: + collated_labels_nontensor.append(None) + labels_texts.append([]) + if sample["label_available"] and sample["alignment_available"] and torch.all(torch.eq(sample["label"], sample["alignment_idx"])): + available_alignments[i] = True + # also update alignments - possible padding of halves of letters on the borders etc. + coll_labels, coll_labels_w_spaces, labels_with_ranges, collated_alignments[i], pad_begin = self.collate_labels(sample["alignment"], collated_alignments[i], sample["alignment_idx"], start, end) #, sample["text"]) + if pad_begin is not None: + padding_mask[i, :, pad_begin:] = True + cropped_text = ''.join([sample["label_text"][i] for _, i, _ in labels_with_ranges]) + if sample["label_available"]: + labels_texts[-1] = cropped_text # fix non-collated appended stuff + collated_labels_nontensor[-1] = coll_labels + alignments_texts.append(cropped_text) + labels_with_ranges_arr.append([(a, (c,d)) for a, _, (c, d) in labels_with_ranges]) + else: + labels_with_ranges_arr.append(None) + alignments_texts.append([]) #collated_texts_nontensor.append(coll_text) else: - collated_sources[i] = self.crop_to_max_size(source, target_size) + collated_sources[i], _ = self.crop_to_max_size(source, target_size) input = {"source": collated_sources} if self.pad: @@ -208,34 +239,172 @@ def collater(self, samples): collated_labels = torch.IntTensor(size=(len(collated_labels_nontensor), max([len(i) for i in collated_labels_nontensor]))).fill_(self.label_pad_idx) for i, label in enumerate(collated_labels_nontensor): collated_labels[i][:len(label)] = torch.tensor(label) - # TODO check collate labels to common length in a tensor - # TODO EOS stuff (?) - target_lengths = torch.LongTensor([len(t) for t in collated_labels_nontensor]) + + # TODO EOS stuff (?) maybe rather as an option + + # zeros where None + target_lengths = torch.LongTensor([len(t) if t is not None else 0 for t in collated_labels_nontensor]) + input["alignments"] = collated_alignments + + # [!] stuff with "_available" tells if data "\is actually present in the tensors or are there some defaults or sth return { "id": torch.LongTensor([s["id"] for s in samples]), "net_input": input, - "target_lengths": target_lengths, + # [!] format as for ctc criterion (["target"]), other criterions would need to be updated to use labels "target": collated_labels, # data_utils.collate_tokens(collated_labels_nontensor, pad_idx=self.pad, left_pad=False), - "ntokens": target_lengths.sum().item(), - "alignments": collated_alignments - #"label_texts": collated_texts_nontensor, # TODO? non-collated texts of collated stuff + "target_available": available_labels, + "target_texts": labels_texts, + "target_lengths": target_lengths, # 0 where there are no labels + "ntokens": target_lengths.sum().item(), # only sums tokens where there are labels + "alignments": collated_alignments, + "alignments_available": available_alignments, + "alignments_texts": alignments_texts, + "labels_with_ranges": labels_with_ranges_arr # None where unavailable; (char_id) } else: return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} - def collate_labels(self, collated_alignments, full_label, full_text): # label is a list, text is a string + @staticmethod + def get_chars_ranges_uniform(num_chars, chars_begin, chars_end): + full_len_diff = chars_end - chars_begin + for_1_letter = float(full_len_diff) / float(num_chars) + # assuming will always have at least 1 frame for a letter on average + begin = chars_begin + end = int(round(chars_begin + for_1_letter)) + collated_label = [] + for j in range(num_chars): + collated_label.append((begin, end)) + begin = end + 1 + end = int(round(chars_begin + (j+2)*for_1_letter)) + return collated_label + + def get_letter_ranges(self, full_alignments_original, full_label_original, cut_start_original=None, cut_end_original=None): + + cut_start = 0 if cut_start_original is None else cut_start_original + cut_end = len(full_alignments_original) - 1 if cut_end_original is None else cut_end_original + last_idx = self.label_blank_idx + full_alignments = torch.cat([full_alignments_original, full_alignments_original.new_full((1,), self.label_pad_idx)]) # for no special case #decode_dict = {x.item(): y for x,y in zip(full_label, full_text)} # can zip like that as full_label is already a list - collated_label = [] - # [!] TODO fix case with double letters and stuff - for num in collated_alignments: - if num.item() != last_idx and num.item() != self.label_pad_idx: - last_idx = num - if num.item() != self.label_blank_idx: - collated_label.append(num.item()) - #collated_text = ''.join([decode_dict[x] for x in collated_label]) - return collated_label #, collated_text + naive_collated_label = [] + naive_letter_ranges = [] + letter_begin = 0 + + for i, numTensor in enumerate(full_alignments): + num = numTensor.item() + if num != last_idx: + if last_idx != self.label_pad_idx and last_idx != self.label_eos_idx: + naive_collated_label.append(last_idx) + naive_letter_ranges.append((letter_begin, i-1)) + letter_begin = i + last_idx = num + + full_label = torch.cat([full_label_original, full_label_original.new_full((1,), self.label_pad_idx)]) + # ^ so will always append thing before pad without additional case after the loop + label_qties = [] + last_idx = full_label[0] #self.label_blank_idx + qty = 0 + for numTensor in full_label: + num = numTensor.item() + if num != last_idx and last_idx != self.label_pad_idx and last_idx != self.label_eos_idx: + label_qties.append((last_idx, qty)) + qty = 1 + else: + qty += 1 + last_idx = num + + # TODO if stuff empty, return empty tensor or array or so + + next_ground_id = 0 + next_ground_seen = False + last_idx = self.label_blank_idx + collated_labels_with_ranges = [] + naive_letter_ranges.append((-1, -1)) + naive_collated_label.append(self.label_pad_idx) + ground_ranges = [] + max_end_before = -1 + min_begin_after = 2*len(full_alignments_original) + label_idx = -1 # to track positions from original label + + for char, (begin, end) in list(zip(naive_collated_label, naive_letter_ranges)): + # first append any blanks from ground truth, blanks from naive stuff are omitted later in the loop + # can also do it this way - would ignore some random 1-length blanks in alignments, although that should rather NOT happen + if label_qties[next_ground_id][0] == self.label_blank_idx: # rather can't happen, blanks not in label; TODO? maybe could also treat pad/sth else similarly, but rather not needed + # don't append blanks there, those are not spaces! + next_ground_id += 1 + if next_ground_id >= len(label_qties): + break + if char == self.label_blank_idx: # omit blanks in naive stuff not existent in ground, and calculate repetitions otherwise; blanks calculated from ground, above + continue # TODO? maybe could also treat pad/sth else similarly, but rather not needed + # from here no blanks in both places + if next_ground_seen and char != label_qties[next_ground_id][0]: + if len(ground_ranges) == label_qties[next_ground_id][1]: + #_ = RawHandwritingDataset.get_chars_ranges_uniform(3, begin, end) + ranges = ground_ranges + else: + ranges = RawHandwritingDataset.get_chars_ranges_uniform(label_qties[next_ground_id][1], begin, end) + for a, b in ranges: + mid = (a + b) // 2 + label_idx += 1 + if mid < cut_start: + max_end_before = max(max_end_before, b - cut_start) + continue + elif mid > cut_end: + min_begin_after = min(min_begin_after, a - cut_start) + continue + collated_labels_with_ranges.append([label_qties[next_ground_id][0], label_idx, [a, b]]) # need to update with label_qties[next_ground_id][0], NOT char - char is next + next_ground_id += 1 + next_ground_seen = False + ground_ranges = [] # to be all seen in char == label_qties[next_ground_id][0] case - also in this loop spin + if next_ground_id >= len(label_qties): + break + if char == label_qties[next_ground_id][0]: + next_ground_seen = True + ground_ranges.append((begin, end)) + # TODO else some error/warning or just ignore? could still work with messy alignments then + + # no blanks, no need to span stuff; spaces according to alignment + + if cut_start_original is not None or cut_end_original is not None: + return collated_labels_with_ranges, ((max_end_before if max_end_before >= 0 else None), (min_begin_after if min_begin_after < cut_end - cut_start + 1 else None)) + else: + return collated_labels_with_ranges + + # TODO separate function for getting the ranges of the letters and do this also when not cropping + # modifies initial collated_alignments if needed + def collate_labels(self, full_alignments_original, collated_alignments, full_label_original, cut_start, cut_end): #, full_text): # label is a list, text is a string + + collated_labels_with_ranges, (mask_to, mask_from) = self.get_letter_ranges(full_alignments_original, full_label_original, cut_start_original=cut_start, cut_end_original=cut_end) + + # mask_to & mask_from are indices + + pad_begin = None + # [!] masks before first in collated_labels_with_ranges and after last - if cut_start < first range, similarly with end + # (case when we have there e.g. = start and num <= end \ No newline at end of file diff --git a/fairseq/data/iterators.py b/fairseq/data/iterators.py index ef41fed739..0f55026ef8 100644 --- a/fairseq/data/iterators.py +++ b/fairseq/data/iterators.py @@ -138,6 +138,10 @@ def load_state_dict(self, state_dict): """Copies the state of the iterator from the given *state_dict*.""" raise NotImplementedError + @property + def first_batch(self): + return "DUMMY" + class StreamingEpochBatchIterator(EpochBatchIterating): def __init__( diff --git a/fairseq/data/lm_context_window_dataset.py b/fairseq/data/lm_context_window_dataset.py index 29ad887b7d..39512797bc 100644 --- a/fairseq/data/lm_context_window_dataset.py +++ b/fairseq/data/lm_context_window_dataset.py @@ -11,10 +11,23 @@ class LMContextWindowDataset(FairseqDataset): - """Wraps a MonolingualDataset and provides more context for evaluation.""" - - def __init__(self, dataset, tokens_per_sample, context_window, pad_idx): - assert isinstance(dataset, MonolingualDataset) + """ + Wraps a MonolingualDataset and provides more context for evaluation. + + Each item in the new dataset will have a maximum size of + ``tokens_per_sample + context_window``. + + Args: + dataset: dataset to wrap + tokens_per_sample (int): the max number of tokens in each dataset item + context_window (int): the number of accumulated tokens to add to each + dataset item + pad_idx (int): padding symbol + """ + + def __init__( + self, dataset, tokens_per_sample: int, context_window: int, pad_idx: int + ): assert context_window > 0 self.dataset = dataset self.tokens_per_sample = tokens_per_sample diff --git a/fairseq/data/mask_tokens_dataset.py b/fairseq/data/mask_tokens_dataset.py index 8ea86245f7..b239013c80 100644 --- a/fairseq/data/mask_tokens_dataset.py +++ b/fairseq/data/mask_tokens_dataset.py @@ -39,6 +39,10 @@ class MaskTokensDataset(BaseWrapperDataset): over vocab indices, indicating whether it is the beginning of a word. We will extend any mask to encompass the whole word. bpe: BPE to use for whole-word masking. + mask_multiple_length : repeat each mask index multiple times. Default + value is 1. + mask_stdev : standard deviation of masks distribution in case of + multiple masking. Default value is 0. """ @classmethod @@ -63,11 +67,15 @@ def __init__( random_token_prob: float = 0.1, freq_weighted_replacement: bool = False, mask_whole_words: torch.Tensor = None, + mask_multiple_length: int = 1, + mask_stdev: float = 0.0, ): assert 0.0 < mask_prob < 1.0 assert 0.0 <= random_token_prob <= 1.0 assert 0.0 <= leave_unmasked_prob <= 1.0 assert random_token_prob + leave_unmasked_prob <= 1.0 + assert mask_multiple_length >= 1 + assert mask_stdev >= 0.0 self.dataset = dataset self.vocab = vocab @@ -79,6 +87,8 @@ def __init__( self.leave_unmasked_prob = leave_unmasked_prob self.random_token_prob = random_token_prob self.mask_whole_words = mask_whole_words + self.mask_multiple_length = mask_multiple_length + self.mask_stdev = mask_stdev if random_token_prob > 0.0: if freq_weighted_replacement: @@ -122,10 +132,39 @@ def __getitem__(self, index: int): mask = np.full(sz, False) num_mask = int( # add a random number for probabilistic rounding - self.mask_prob * sz + self.mask_prob * sz / float(self.mask_multiple_length) + np.random.rand() ) - mask[np.random.choice(sz, num_mask, replace=False)] = True + + # multiple masking as described in the vq-wav2vec paper (https://arxiv.org/abs/1910.05453) + mask_idc = np.random.choice(sz, num_mask, replace=False) + if self.mask_stdev > 0.0: + lengths = np.random.normal( + self.mask_multiple_length, self.mask_stdev, size=num_mask + ) + lengths = [max(0, int(round(x))) for x in lengths] + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ], + dtype=np.int64, + ) + else: + mask_idc = np.concatenate( + [mask_idc + i for i in range(self.mask_multiple_length)] + ) + mask_idc = mask_idc[mask_idc < len(mask)] + try: + mask[mask_idc] = True + except: # something wrong + print( + "Assigning mask indexes {} to mask {} failed!".format( + mask_idc, mask + ) + ) + raise if self.return_masked_tokens: # exit early if we're just returning the masked tokens diff --git a/fairseq/data/monolingual_dataset.py b/fairseq/data/monolingual_dataset.py index ec73f1fda8..bf7aa86f6c 100644 --- a/fairseq/data/monolingual_dataset.py +++ b/fairseq/data/monolingual_dataset.py @@ -70,16 +70,16 @@ def __init__( dataset, sizes, src_vocab, - tgt_vocab, - add_eos_for_other_targets, - shuffle, + tgt_vocab=None, + add_eos_for_other_targets=False, + shuffle=False, targets=None, add_bos_token=False, ): self.dataset = dataset self.sizes = np.array(sizes) self.vocab = src_vocab - self.tgt_vocab = tgt_vocab + self.tgt_vocab = tgt_vocab or src_vocab self.add_eos_for_other_targets = add_eos_for_other_targets self.shuffle = shuffle self.add_bos_token = add_bos_token diff --git a/fairseq/data/multilingual/multilingual_data_manager.py b/fairseq/data/multilingual/multilingual_data_manager.py index 8c14f4e3ad..21fb23c047 100644 --- a/fairseq/data/multilingual/multilingual_data_manager.py +++ b/fairseq/data/multilingual/multilingual_data_manager.py @@ -236,7 +236,7 @@ def add_args(parser): ) parser.add_argument( "--virtual-epoch-size", - default=1000000, + default=None, type=int, help="virtual epoch size to speed up data loading", ) @@ -1040,3 +1040,38 @@ def load_sampled_multi_epoch_dataset( ) else: return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_sampled_multi_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + datasets, data_param_list = self.load_split_datasets( + split, training, epoch, combine, shard_epoch=shard_epoch, **kwargs + ) + if training and split == getattr(self.args, "train_subset", None): + sample_ratios = self.get_sampling_ratios(data_param_list, datasets, epoch) + return SampledMultiDataset( + OrderedDict(datasets), + epoch=epoch, + # valid and test datasets will be degerate to concating datasets: + sampling_ratios=sample_ratios, + eval_key=None, + collate_format=CollateFormat.single, + virtual_size=self.args.virtual_data_size, + split=split, + # if not using lang_tok altering, simplified to use the same collater + shared_collater=self._shared_collater(), + ) + else: + return self.load_into_concat_dataset(split, datasets, data_param_list) + + def load_dataset( + self, split, training, epoch=0, combine=False, shard_epoch=None, **kwargs + ): + if self.args.virtual_epoch_size is None: + return self.load_sampled_multi_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) + else: + return self.load_sampled_multi_epoch_dataset( + split, training, epoch, combine, shard_epoch, **kwargs + ) diff --git a/fairseq/data/plasma_utils.py b/fairseq/data/plasma_utils.py index 2b12646783..f4bb6472d7 100644 --- a/fairseq/data/plasma_utils.py +++ b/fairseq/data/plasma_utils.py @@ -60,7 +60,7 @@ def start_server(self): def client(self): if self._client is None: assert self.path is not None - self._client = self.plasma.connect(self.path) + self._client = self.plasma.connect(self.path, num_retries=200) return self._client def __getstate__(self): diff --git a/fairseq/data/token_block_utils_fast.pyx b/fairseq/data/token_block_utils_fast.pyx index 5a2f16ec34..08af4f3061 100644 --- a/fairseq/data/token_block_utils_fast.pyx +++ b/fairseq/data/token_block_utils_fast.pyx @@ -170,7 +170,7 @@ cdef class DatasetSearcher(object): self.current_offset += to_consume self.current_i += to_consume else: - assert remaining > 0 + assert remaining >= 0 self.current_i += remaining self.current_index += 1 self.current_offset = 0 diff --git a/fairseq/dataclass/configs.py b/fairseq/dataclass/configs.py index ec921a41d7..caf4a7a2b8 100644 --- a/fairseq/dataclass/configs.py +++ b/fairseq/dataclass/configs.py @@ -17,6 +17,7 @@ GENERATION_DECODING_FORMAT_CHOICES, LOG_FORMAT_CHOICES, PIPELINE_CHECKPOINT_CHOICES, + PRINT_ALIGNMENT_CHOICES, ZERO_SHARDING_CHOICES, ) @@ -108,6 +109,12 @@ class CommonConfig(FairseqDataclass): "help": "Weights and Biases project name to use for logging" }, ) + azureml_logging: Optional[bool] = field( + default=False, + metadata={ + "help": "Log scalars to AzureML context" + }, + ) seed: int = field( default=1, metadata={"help": "pseudo random number generator seed"} ) @@ -173,6 +180,12 @@ class CommonConfig(FairseqDataclass): profile: bool = field( default=False, metadata={"help": "enable autograd profiler emit_nvtx"} ) + reset_logging: bool = field( + default=True, + metadata={ + "help": "when using Hydra, reset the logging at the beginning of training" + }, + ) @dataclass @@ -203,10 +216,6 @@ class DistributedTrainingConfig(FairseqDataclass): }, ) device_id: int = field( - default=0, - metadata={"help": "which GPU to use (usually configured automatically)"}, - ) - local_rank: int = field( default=0, metadata={ "help": "which GPU to use (usually configured automatically)", @@ -243,6 +252,13 @@ class DistributedTrainingConfig(FairseqDataclass): default=False, metadata={"help": "[deprecated] this is now defined per Criterion"}, ) + heartbeat_timeout: int = field( + default=-1, + metadata={ + "help": "kill the job if no progress is made in N seconds; " + "set to -1 to disable" + } + ) broadcast_buffers: bool = field( default=False, metadata={ @@ -400,14 +416,14 @@ class DatasetConfig(FairseqDataclass): default=False, metadata={"help": "disable validation"} ) max_tokens_valid: Optional[int] = field( - default=None, + default=II("dataset.max_tokens"), metadata={ "help": "maximum number of tokens in a validation batch" " (defaults to --max-tokens)" }, ) batch_size_valid: Optional[int] = field( - default=None, + default=II("dataset.batch_size"), metadata={ "help": "batch size of the validation batch (defaults to --batch-size)", "argparse_alias": "--max-sentences-valid", @@ -463,7 +479,7 @@ class OptimizationConfig(FairseqDataclass): " (note: this may be interpreted differently depending on --lr-scheduler)" }, ) - min_lr: float = field( + stop_min_lr: float = field( default=-1.0, metadata={"help": "stop training when the learning rate reaches this minimum"}, ) @@ -581,6 +597,13 @@ class CheckpointConfig(FairseqDataclass): "the checkpoint" }, ) + load_checkpoint_on_all_dp_ranks: bool = field( + default=False, + metadata={ + "help": "load checkpoints on all data parallel devices " + "(default: only load on rank 0 and broadcast to other devices)" + }, + ) model_parallel_size: int = II("common.model_parallel_size") distributed_rank: int = II("distributed_training.distributed_rank") @@ -728,10 +751,12 @@ class GenerationConfig(FairseqDataclass): default=-1.0, metadata={"help": "strength of diversity penalty for Diverse Siblings Search"}, ) - print_alignment: bool = field( - default=False, + print_alignment: Optional[PRINT_ALIGNMENT_CHOICES] = field( + default=None, metadata={ - "help": "if set, uses attention feedback to compute and print alignment to source tokens" + "help": "if set, uses attention feedback to compute and print alignment to source tokens " + "(valid options are: hard, soft, otherwise treated as hard alignment)", + "argparse_const": "hard", }, ) print_step: bool = field( @@ -813,9 +838,11 @@ class CommonEvalConfig(FairseqDataclass): post_process: Optional[str] = field( default=None, metadata={ - "help": "post-process text by removing pre-processing such as BPE, letter segmentation, etc " - "(valid options are: sentencepiece, wordpiece, letter, _EOW, none, otherwise treated as BPE symbol)", - "argparse_const": "@@ ", + "help": ( + "post-process text by removing BPE, letter segmentation, etc. " + "Valid options can be found in fairseq.data.utils.post_process." + ), + "argparse_const": "subword_nmt", "argparse_alias": "--remove-bpe", }, ) @@ -874,7 +901,7 @@ class InteractiveConfig(FairseqDataclass): @dataclass -class FairseqConfig(object): +class FairseqConfig(FairseqDataclass): common: CommonConfig = CommonConfig() common_eval: CommonEvalConfig = CommonEvalConfig() distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() diff --git a/fairseq/dataclass/constants.py b/fairseq/dataclass/constants.py index fad04f3482..46881786a8 100644 --- a/fairseq/dataclass/constants.py +++ b/fairseq/dataclass/constants.py @@ -3,11 +3,19 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from enum import Enum +from enum import Enum, EnumMeta from typing import List -class StrEnum(Enum): +class StrEnumMeta(EnumMeta): + # this is workaround for submitit pickling leading to instance checks failing in hydra for StrEnum, see + # https://github.com/facebookresearch/hydra/issues/1156 + @classmethod + def __instancecheck__(cls, other): + return "enum" in str(type(other)) + + +class StrEnum(Enum, metaclass=StrEnumMeta): def __str__(self): return self.value @@ -36,3 +44,4 @@ def ChoiceEnum(choices: List[str]): ) ZERO_SHARDING_CHOICES = ChoiceEnum(["none", "os"]) PIPELINE_CHECKPOINT_CHOICES = ChoiceEnum(["always", "never", "except_last"]) +PRINT_ALIGNMENT_CHOICES = ChoiceEnum(["hard", "soft"]) diff --git a/fairseq/dataclass/utils.py b/fairseq/dataclass/utils.py index d73af240b9..45e7ed9170 100644 --- a/fairseq/dataclass/utils.py +++ b/fairseq/dataclass/utils.py @@ -4,19 +4,23 @@ # LICENSE file in the root directory of this source tree. import ast +import inspect +import logging import os import re from argparse import ArgumentError, ArgumentParser, Namespace from dataclasses import _MISSING_TYPE, MISSING from enum import Enum -import inspect from typing import Any, Dict, List, Tuple, Type from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.configs import FairseqConfig +from hydra.core.global_hydra import GlobalHydra from hydra.experimental import compose, initialize from omegaconf import DictConfig, OmegaConf, open_dict +logger = logging.getLogger(__name__) + def eval_str_list(x, x_type=float): if x is None: @@ -153,7 +157,11 @@ def get_kwargs_from_dc( if isinstance(kwargs["default"], str) and kwargs["default"].startswith( "${" ): - continue + if kwargs["help"] is None: + # this is a field with a name that will be added elsewhere + continue + else: + del kwargs["default"] if delete_default: del kwargs["default"] try: @@ -210,7 +218,9 @@ def get_default(f): isinstance(val, str) and not val.startswith("${") # not interpolation and field_type != str - and not issubclass(field_type, Enum) # not choices enum + and ( + not inspect.isclass(field_type) or not issubclass(field_type, Enum) + ) # not choices enum ): # upgrade old models that stored complex parameters as string val = ast.literal_eval(val) @@ -218,20 +228,35 @@ def get_default(f): if isinstance(val, tuple): val = list(val) - if getattr(v.type, "__origin__", None) is List: + v_type = getattr(v.type, "__origin__", None) + if ( + (v_type is List or v_type is list) + # skip interpolation + and not (isinstance(val, str) and val.startswith("${")) + ): # if type is int but val is float, then we will crash later - try to convert here t_args = v.type.__args__ if len(t_args) == 1: val = list(map(t_args[0], val)) + elif val is not None and (field_type is int or field_type is bool or field_type is float): + try: + val = field_type(val) + except: + pass # ignore errors here, they are often from interpolation args if val is None: overrides.append("{}.{}=null".format(sub_node, k)) elif val == "": overrides.append("{}.{}=''".format(sub_node, k)) elif isinstance(val, str): + val = val.replace("'", r"\'") overrides.append("{}.{}='{}'".format(sub_node, k, val)) elif isinstance(val, FairseqDataclass): overrides += _override_attr(f"{sub_node}.{k}", type(val), args) + elif isinstance(val, Namespace): + sub_overrides, _ = override_module_args(val) + for so in sub_overrides: + overrides.append(f"{sub_node}.{k}.{so}") else: overrides.append("{}.{}={}".format(sub_node, k, val)) @@ -320,8 +345,15 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: # configs will be in fairseq/config after installation config_path = os.path.join("..", "config") + GlobalHydra.instance().clear() + with initialize(config_path=config_path): - composed_cfg = compose("config", overrides=overrides, strict=False) + try: + composed_cfg = compose("config", overrides=overrides, strict=False) + except: + logger.error("Error when composing. Overrides: " + str(overrides)) + raise + for k in deletes: composed_cfg[k] = None @@ -373,7 +405,8 @@ def convert_namespace_to_omegaconf(args: Namespace) -> DictConfig: def populate_dataclass( - args: Namespace, dataclass: FairseqDataclass + dataclass: FairseqDataclass, + args: Namespace, ) -> FairseqDataclass: for k in dataclass.__dataclass_fields__.keys(): if k.startswith("_"): @@ -382,7 +415,7 @@ def populate_dataclass( if hasattr(args, k): setattr(dataclass, k, getattr(args, k)) - return dataclass + return dataclass def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): @@ -395,6 +428,9 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): # "k in cfg" will return false if its a "mandatory value (e.g. ???)" if k in cfg and isinstance(cfg[k], DictConfig): overwrite_args_by_name(cfg[k], overrides) + elif k in cfg and isinstance(cfg[k], Namespace): + for override_key, val in overrides.items(): + setattr(cfg[k], override_key, val) elif k in overrides: if ( k in REGISTRIES @@ -409,9 +445,8 @@ def overwrite_args_by_name(cfg: DictConfig, overrides: Dict[str, any]): cfg[k] = overrides[k] -def merge_with_parent(dc: FairseqDataclass, cfg: DictConfig): - dc_instance = DictConfig(dc) - dc_instance.__dict__["_parent"] = cfg.__dict__["_parent"] - cfg = OmegaConf.merge(dc_instance, cfg) - OmegaConf.set_struct(cfg, True) - return cfg +def merge_with_parent(dc: FairseqDataclass, cfg: FairseqDataclass): + merged_cfg = OmegaConf.merge(dc, cfg) + merged_cfg.__dict__["_parent"] = cfg.__dict__["_parent"] + OmegaConf.set_struct(merged_cfg, True) + return merged_cfg diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 9059d8aa2b..8f98ac88f9 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -14,6 +14,7 @@ import warnings from argparse import Namespace from collections import OrderedDict +from dataclasses import dataclass from typing import Any, Dict, List, Mapping, Optional import torch @@ -160,8 +161,9 @@ def infer_init_method(cfg: DistributedTrainingConfig, force_distributed=False): elif cfg.distributed_world_size > 1 or force_distributed: # fallback for single node with multiple GPUs - assert cfg.distributed_world_size <= torch.cuda.device_count(), \ - f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" + assert ( + cfg.distributed_world_size <= torch.cuda.device_count() + ), f"world size is {cfg.distributed_world_size} but have {torch.cuda.device_count()} available devices" port = random.randint(10000, 20000) cfg.distributed_init_method = "tcp://localhost:{port}".format(port=port) @@ -375,8 +377,10 @@ def get_world_size(group): assert group[0] == "tpu" my_group = _find_my_group(group[1]) return len(my_group) - else: + elif torch.distributed.is_initialized(): return dist.get_world_size(group=group) + else: + return 1 def get_global_group(): @@ -415,6 +419,7 @@ def get_data_parallel_group(): global _USE_MEGATRON if _USE_MEGATRON: from fairseq.model_parallel.megatron import mpu + return mpu.get_data_parallel_group() else: return get_global_group() @@ -434,6 +439,7 @@ def get_model_parallel_group(): global _USE_MEGATRON if _USE_MEGATRON: from fairseq.model_parallel.megatron import mpu + return mpu.get_model_parallel_group() else: return None @@ -637,44 +643,127 @@ def get_from_stack(key): return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) -# From fairscale/optim/utils.py +def broadcast_tensors( + tensors: Optional[List[torch.Tensor]], + src_rank: int, + group: object, + dist_device: Optional[torch.device] = None, +) -> List[torch.Tensor]: + """ + Broadcasts a list of tensors without other (non-src) ranks needing to know + the dtypes/shapes of the tensors. + """ + if dist_device is None: + if torch.distributed.get_backend(group) == "nccl": + dist_device = torch.device("cuda") + else: + dist_device = torch.device("cpu") + + # share metadata first to simplify transfer + is_src_rank = (get_rank(group) == src_rank) + if is_src_rank: + metadata = [ + {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors + ] + metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + else: + metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + out_tensors = [] + for i, meta in enumerate(metadata): + if is_src_rank: + tensor = tensors[i] + broadcast(tensors[i].to(dist_device), src=src_rank, group=group) + else: + tensor = torch.zeros( + [meta["size"].numel()], dtype=meta["dtype"], device=dist_device + ) + broadcast(tensor, src=src_rank, group=group) + tensor = tensor.view(meta["size"]).to(meta["device"]) + out_tensors.append(tensor) + return out_tensors + + def broadcast_object( obj: Any, src_rank: int, group: object, dist_device: Optional[torch.device] = None, - dist_length_dtype: Optional[torch.dtype] = torch.long, - dist_dtype: Optional[torch.dtype] = torch.uint8, ) -> Any: - """ - Either broadcast from master to the fleet (default), - or use the src setting as the original rank. - """ + """Broadcast an arbitrary Python object to other workers.""" if dist_device is None: if torch.distributed.get_backend(group) == "nccl": dist_device = torch.device("cuda") else: dist_device = torch.device("cpu") + if get_rank(group) == src_rank: + # split the tensors from the non-tensors so we can broadcast them + # directly, avoiding unnecessary serialization/deserialization + tensors = [] + obj = _split_tensors_from_obj(obj, tensors) + obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + else: + obj = _broadcast_object_slow(None, src_rank, group, dist_device) + tensors = broadcast_tensors(None, src_rank, group, dist_device) + return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, src_rank: int, group: object, dist_device: torch.device, +) -> Any: if get_rank(group) == src_rank: # Emit data buffer = io.BytesIO() torch.save(obj, buffer) - data = bytearray(buffer.getbuffer()) - length_tensor = torch.tensor( - [len(data)], dtype=dist_length_dtype, device=dist_device - ) - broadcast(length_tensor, src=src_rank, group=group) - data_send_tensor = torch.tensor(data, dtype=dist_dtype, device=dist_device) - broadcast(data_send_tensor, src=src_rank, group=group) + buffer = torch.ByteTensor(buffer.getbuffer()).to(dist_device) + length = torch.LongTensor([len(buffer)]).to(dist_device) + broadcast(length, src=src_rank, group=group) + broadcast(buffer, src=src_rank, group=group) else: # Fetch from the source - length_tensor = torch.tensor([0], dtype=dist_length_dtype, device=dist_device) - broadcast(length_tensor, src=src_rank, group=group) - data_recv_tensor = torch.zeros( - [int(length_tensor.item())], dtype=dist_dtype, device=dist_device - ) - broadcast(data_recv_tensor, src=src_rank, group=group) - buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + length = torch.LongTensor([0]).to(dist_device) + broadcast(length, src=src_rank, group=group) + buffer = torch.ByteTensor(int(length.item())).to(dist_device) + broadcast(buffer, src=src_rank, group=group) + buffer = io.BytesIO(buffer.cpu().numpy()) obj = torch.load(buffer, map_location="cpu") return obj + + +@dataclass(frozen=True) +class _TensorPlaceholder: + index: int + + +def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if torch.is_tensor(obj): + placeholder = _TensorPlaceholder(index=len(tensors)) + tensors.append(obj) + return placeholder + elif isinstance(obj, dict): + return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_split_tensors_from_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_split_tensors_from_obj(v, tensors) for v in obj} + else: + return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: + if isinstance(obj, _TensorPlaceholder): + return tensors[obj.index] + elif isinstance(obj, dict): + return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + elif isinstance(obj, list): + return [_put_tensors_in_obj(v, tensors) for v in obj] + elif isinstance(obj, tuple): + return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + elif isinstance(obj, set): + return {_put_tensors_in_obj(v, tensors) for v in obj} + else: + return obj diff --git a/fairseq/file_io.py b/fairseq/file_io.py index d667256922..7d6c28dccd 100644 --- a/fairseq/file_io.py +++ b/fairseq/file_io.py @@ -5,14 +5,30 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import logging import os import shutil from typing import List, Optional +logger = logging.getLogger(__file__) + + try: from fvcore.common.file_io import PathManager as FVCorePathManager + try: + # [FB only - for now] AWS PathHandler for PathManager + from .fb_pathhandlers import S3PathHandler + + FVCorePathManager.register_handler(S3PathHandler()) + except KeyError: + logging.warning("S3PathHandler already registered.") + except ImportError: + logging.debug( + "S3PathHandler couldn't be imported. Either missing fb-only files, or boto3 module." + ) + except ImportError: FVCorePathManager = None @@ -97,7 +113,7 @@ def rm(path: str) -> None: @staticmethod def chmod(path: str, mode: int) -> None: - if "manifold" not in path: + if not PathManager.path_requires_pathmanager(path): os.chmod(path, mode) @staticmethod @@ -114,3 +130,21 @@ def copy_from_local( local_path=local_path, dst_path=dst_path, overwrite=overwrite, **kwargs ) return shutil.copyfile(local_path, dst_path) + + @staticmethod + def path_requires_pathmanager(path: str) -> bool: + """Do we require PathManager to access given path?""" + if FVCorePathManager: + for p in FVCorePathManager._path_handlers.keys(): + if path.startswith(p): + return True + return False + + @staticmethod + def supports_rename(path: str) -> bool: + # PathManager doesn't yet support renames + return not PathManager.path_requires_pathmanager(path) + + @staticmethod + def rename(src: str, dst: str): + os.rename(src, dst) diff --git a/fairseq/hub_utils.py b/fairseq/hub_utils.py index 3be7078b7a..7de2e2b0d4 100644 --- a/fairseq/hub_utils.py +++ b/fairseq/hub_utils.py @@ -60,6 +60,8 @@ def from_pretrained( "code": "bpe_codes", "bpecodes": "bpe_codes", "sentencepiece.bpe.model": "sentencepiece_model", + "merges.txt": "bpe_merges", + "vocab.json": "bpe_vocab", }.items(): path = os.path.join(model_path, file) if os.path.exists(path): @@ -157,7 +159,7 @@ def generate( )[0] # build generator using current args as well as any kwargs - gen_args = copy.copy(self.cfg) + gen_args = copy.deepcopy(self.cfg.generation) with open_dict(gen_args): gen_args.beam = beam for k, v in kwargs.items(): @@ -180,7 +182,7 @@ def generate( if verbose: def getarg(name, default): - return getattr(gen_args, name, getattr(self.args, name, default)) + return getattr(gen_args, name, getattr(self.cfg, name, default)) for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs): src_str_with_unk = self.string(source_tokens) diff --git a/fairseq/logging/progress_bar.py b/fairseq/logging/progress_bar.py index 3183d2f476..e2a1711121 100644 --- a/fairseq/logging/progress_bar.py +++ b/fairseq/logging/progress_bar.py @@ -34,6 +34,8 @@ def progress_bar( tensorboard_logdir: Optional[str] = None, default_log_format: str = "tqdm", wandb_project: Optional[str] = None, + wandb_run_name: Optional[str] = None, + azureml_logging: Optional[bool] = False, ): if log_format is None: log_format = default_log_format @@ -62,7 +64,10 @@ def progress_bar( bar = TensorboardProgressBarWrapper(bar, tensorboard_logdir) if wandb_project: - bar = WandBProgressBarWrapper(bar, wandb_project) + bar = WandBProgressBarWrapper(bar, wandb_project, run_name=wandb_run_name) + + if azureml_logging: + bar = AzureMLProgressBarWrapper(bar) return bar @@ -356,6 +361,8 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): writer.add_scalar(key, stats[key].val, step) elif isinstance(stats[key], Number): writer.add_scalar(key, stats[key], step) + elif torch.is_tensor(stats[key]) and stats[key].numel() == 1: + writer.add_scalar(key, stats[key].item(), step) writer.flush() @@ -368,15 +375,15 @@ def _log_to_tensorboard(self, stats, tag=None, step=None): class WandBProgressBarWrapper(BaseProgressBar): """Log to Weights & Biases.""" - def __init__(self, wrapped_bar, wandb_project): + def __init__(self, wrapped_bar, wandb_project, run_name=None): self.wrapped_bar = wrapped_bar if wandb is None: - logger.warning('wandb not found, pip install wandb') + logger.warning("wandb not found, pip install wandb") return # reinit=False to ensure if wandb.init() is called multiple times # within one process it still references the same run - wandb.init(project=wandb_project, reinit=False) + wandb.init(project=wandb_project, reinit=False, name=run_name) def __iter__(self): return iter(self.wrapped_bar) @@ -394,13 +401,63 @@ def print(self, stats, tag=None, step=None): def _log_to_wandb(self, stats, tag=None, step=None): if wandb is None: return + if step is None: + step = stats["num_updates"] + + prefix = "" if tag is None else tag + "/" + + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + wandb.log({prefix + key: stats[key].val}, step=step) + elif isinstance(stats[key], Number): + wandb.log({prefix + key: stats[key]}, step=step) + + +try: + from azureml.core import Run +except ImportError: + Run = None + + +class AzureMLProgressBarWrapper(BaseProgressBar): + """Log to Azure ML""" + + def __init__(self, wrapped_bar): + self.wrapped_bar = wrapped_bar + if Run is None: + logger.warning("azureml.core not found, pip install azureml-core") + return + self.run = Run.get_context() + + def __exit__(self, *exc): + if Run is not None: + self.run.complete() + return False + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to AzureML""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats""" + self._log_to_azureml(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def _log_to_azureml(self, stats, tag=None, step=None): + if Run is None: + return if step is None: step = stats['num_updates'] prefix = '' if tag is None else tag + '/' for key in stats.keys() - {'num_updates'}: + name = prefix + key if isinstance(stats[key], AverageMeter): - wandb.log({prefix + key: stats[key].val}, step=step) + self.run.log_row(name=name, **{'step': step, key: stats[key].val}) elif isinstance(stats[key], Number): - wandb.log({prefix + key: stats[key]}, step=step) + self.run.log_row(name=name, **{'step': step, key: stats[key]}) diff --git a/fairseq/model_parallel/modules/multihead_attention.py b/fairseq/model_parallel/modules/multihead_attention.py index 4164bf9131..8eb9d09dad 100644 --- a/fairseq/model_parallel/modules/multihead_attention.py +++ b/fairseq/model_parallel/modules/multihead_attention.py @@ -93,11 +93,6 @@ def __init__( embed_dim, embed_dim, bias=bias, input_is_parallel=True ) - self.tpu = False - - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def forward( self, query, @@ -123,6 +118,8 @@ def forward( assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] + is_tpu = query.device.type == "xla" + if incremental_state is not None: saved_state = self._get_input_buffer(incremental_state) if saved_state is not None and "prev_key" in saved_state: @@ -250,7 +247,7 @@ def forward( attn_weights = attn_weights.view( bsz, self.num_heads_partition, tgt_len, src_len ) - if not self.tpu: + if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), diff --git a/fairseq/models/__init__.py b/fairseq/models/__init__.py index b987966749..135530d5c0 100644 --- a/fairseq/models/__init__.py +++ b/fairseq/models/__init__.py @@ -8,11 +8,9 @@ import importlib import os -import fairseq from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent +from fairseq.dataclass.utils import merge_with_parent, populate_dataclass from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig, OmegaConf from .composite_encoder import CompositeEncoder from .distributed_fairseq_model import DistributedFairseqModel @@ -65,7 +63,11 @@ def build_model(cfg: FairseqDataclass, task): cfg = cfg[model_type] else: raise Exception( - "Could not infer model type from directory. Please add _name field to indicate model type" + "Could not infer model type from directory. Please add _name field to indicate model type. " + "Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type ) if model_type in ARCH_MODEL_REGISTRY: @@ -78,9 +80,18 @@ def build_model(cfg: FairseqDataclass, task): if model_type in MODEL_DATACLASS_REGISTRY: # set defaults from dataclass. note that arch name and model name can be the same dc = MODEL_DATACLASS_REGISTRY[model_type] - cfg = merge_with_parent(dc(), cfg) - - assert model is not None, f"Could not infer model type from {cfg}" + if isinstance(cfg, argparse.Namespace): + cfg = populate_dataclass(dc(), cfg) + else: + cfg = merge_with_parent(dc(), cfg) + + assert model is not None, ( + f"Could not infer model type from {cfg}. " + f"Available models: " + + str(MODEL_DATACLASS_REGISTRY.keys()) + + " Requested model type: " + + model_type + ) return model.build_model(cfg, task) diff --git a/fairseq/models/bart/model.py b/fairseq/models/bart/model.py index e105d6fc46..44f03b0162 100644 --- a/fairseq/models/bart/model.py +++ b/fairseq/models/bart/model.py @@ -6,6 +6,7 @@ BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension """ +from typing import Optional import logging @@ -24,6 +25,8 @@ @register_model("bart") class BARTModel(TransformerModel): + __jit_unused_properties__ = ["supported_targets"] + @classmethod def hub_models(cls): return { @@ -41,6 +44,8 @@ def __init__(self, args, encoder, decoder): self.apply(init_bert_params) self.classification_heads = nn.ModuleDict() + if hasattr(self.encoder, "dictionary"): + self.eos: int = self.encoder.dictionary.eos() @staticmethod def add_args(parser): @@ -71,10 +76,12 @@ def forward( src_tokens, src_lengths, prev_output_tokens, - features_only=False, - classification_head_name=None, - token_embeddings=None, - **kwargs, + features_only: bool = False, + classification_head_name: Optional[str] = None, + token_embeddings: Optional[torch.Tensor] = None, + return_all_hiddens: bool = True, + alignment_layer: Optional[int] = None, + alignment_heads: Optional[int] = None, ): if classification_head_name is not None: features_only = True @@ -83,22 +90,27 @@ def forward( src_tokens, src_lengths=src_lengths, token_embeddings=token_embeddings, - **kwargs, + return_all_hiddens=return_all_hiddens ) x, extra = self.decoder( prev_output_tokens, encoder_out=encoder_out, features_only=features_only, - **kwargs, + alignment_layer=alignment_layer, + alignment_heads=alignment_heads, + src_lengths=src_lengths, + return_all_hiddens=return_all_hiddens, ) - + eos: int = self.eos if classification_head_name is not None: sentence_representation = x[ - src_tokens.eq(self.encoder.dictionary.eos()), : + src_tokens.eq(eos), : ].view(x.size(0), -1, x.size(-1))[:, -1, :] - x = self.classification_heads[classification_head_name]( - sentence_representation - ) + for k, head in self.classification_heads.items(): + # for torch script only supports iteration + if k == classification_head_name: + x = head(sentence_representation) + break return x, extra @classmethod diff --git a/fairseq/models/distributed_fairseq_model.py b/fairseq/models/distributed_fairseq_model.py index b78a0125e3..909b3757b2 100644 --- a/fairseq/models/distributed_fairseq_model.py +++ b/fairseq/models/distributed_fairseq_model.py @@ -4,6 +4,10 @@ # LICENSE file in the root directory of this source tree. import inspect +import logging +import os +import signal +import threading import torch import torch.nn as nn @@ -12,6 +16,9 @@ from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel +logger = logging.getLogger(__name__) + + _GOSSIP_DISABLED = False try: import gossip @@ -97,12 +104,41 @@ def DistributedFairseqModel(args, model, process_group): else: raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + heartbeat_timeout = getattr(args, "heartbeat_timeout", -1) + class _DistributedFairseqModel(ddp_class): - """Extend DistributedDataParallel to check for missing - attributes in the wrapped module.""" + """ + Extend DistributedDataParallel to check for missing attributes in the + wrapped module and to add a timeout to kill the job if no progress is + made (--heartbeat-timeout). + """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self._heartbeat_timeout = heartbeat_timeout + if self._heartbeat_timeout > 0: + self._heartbeat = threading.Event() + self._heartbeat_thread = threading.Thread( + target=self._check_heartbeat, + args=(os.getpid(),), + daemon=True, + ) + self._heartbeat_thread.start() + else: + self._heartbeat = None + + def _check_heartbeat(self, parent_pid): + self._heartbeat.wait() # wait for the first forward pass + while True: + self._heartbeat.clear() + success = self._heartbeat.wait(timeout=self._heartbeat_timeout) + if not success: + logger.error(( + "Killing job for not making progress in {} seconds. " + "Set --heartbeat-timeout=-1 to disable this timeout." + ).format(int(self._heartbeat_timeout))) + os.kill(parent_pid, signal.SIGKILL) + return def __getattr__(self, name): wrapped_module = super().__getattr__("module") @@ -110,6 +146,11 @@ def __getattr__(self, name): return getattr(wrapped_module, name) return super().__getattr__(name) + def forward(self, *args, **kwargs): + if self._heartbeat is not None: + self._heartbeat.set() + return super().forward(*args, **kwargs) + return _DistributedFairseqModel(**init_kwargs) diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 926d952f77..244cbc0c66 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -223,21 +223,6 @@ def apply_prepare_for_onnx_export_(module): self.apply(apply_prepare_for_onnx_export_) - def prepare_for_tpu_(self, **kwargs): - """Optionally modify model for use on TPUs.""" - seen = set() - - def apply_prepare_for_tpu_(module): - if ( - module != self - and hasattr(module, "prepare_for_tpu_") - and module not in seen - ): - seen.add(module) - module.prepare_for_tpu_(**kwargs) - - self.apply(apply_prepare_for_tpu_) - @classmethod def from_pretrained( cls, diff --git a/fairseq/models/huggingface/transformers b/fairseq/models/huggingface/transformers deleted file mode 160000 index 839f8a563c..0000000000 --- a/fairseq/models/huggingface/transformers +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 839f8a563cefcb7f2048b310024c217e7829a198 diff --git a/fairseq/models/nat/fairseq_nat_model.py b/fairseq/models/nat/fairseq_nat_model.py index 1dbc29d0f4..b09394112f 100644 --- a/fairseq/models/nat/fairseq_nat_model.py +++ b/fairseq/models/nat/fairseq_nat_model.py @@ -18,18 +18,23 @@ def ensemble_encoder(func): def wrapper(self, *args, **kwargs): if self.ensemble_models is None or len(self.ensemble_models) == 1: return func(self, *args, **kwargs) - encoder_outs = [func(model, *args, **kwargs) for model in self.ensemble_models] - _encoder_out = encoder_outs[0] + encoder_outs = [func(model, *args, **kwargs, return_all_hiddens=True) for model in self.ensemble_models] + _encoder_out = encoder_outs[0].copy() def stack(key): - outs = [getattr(e, key) for e in encoder_outs] - return torch.stack(outs, -1) if outs[0] is not None else None + outs = [e[key][0] for e in encoder_outs] + return [torch.stack(outs, -1) if outs[0] is not None else None] - return _encoder_out._replace( - encoder_out=stack("encoder_out"), - encoder_embedding=stack("encoder_embedding"), - encoder_states=stack("encoder_states"), - ) + _encoder_out["encoder_out"] = stack("encoder_out") + _encoder_out["encoder_embedding"] = stack("encoder_embedding") + + num_layers = len(_encoder_out["encoder_states"]) + if num_layers > 0: + _encoder_out["encoder_states"] = [ + torch.stack([e["encoder_states"][i] for e in encoder_outs], -1) + for i in range(num_layers) + ] + return _encoder_out return wrapper @@ -41,12 +46,18 @@ def wrapper(self, normalize=False, encoder_out=None, *args, **kwargs): self, normalize=normalize, encoder_out=encoder_out, *args, **kwargs ) + def _replace(encoder_out, new_val): + new_encoder_out = encoder_out.copy() + new_encoder_out["encoder_out"] = [new_val] + return new_encoder_out + action_outs = [ func( model, normalize=normalize, - encoder_out=encoder_out._replace( - encoder_out=encoder_out.encoder_out[:, :, :, i] + encoder_out=_replace( + encoder_out, + encoder_out["encoder_out"][0][:, :, :, i] ), *args, **kwargs diff --git a/fairseq/models/nat/levenshtein_transformer.py b/fairseq/models/nat/levenshtein_transformer.py index f7a3f003ca..9377c3c7f5 100644 --- a/fairseq/models/nat/levenshtein_transformer.py +++ b/fairseq/models/nat/levenshtein_transformer.py @@ -149,11 +149,11 @@ def forward_decoder( if max_ratio is None: max_lens = torch.zeros_like(output_tokens).fill_(255) else: - if encoder_out.encoder_padding_mask is None: - max_src_len = encoder_out.encoder_out.size(0) - src_lens = encoder_out.encoder_out.new(bsz).fill_(max_src_len) + if not encoder_out["encoder_padding_mask"]: + max_src_len = encoder_out["encoder_out"].size(0) + src_lens = encoder_out["encoder_out"].new(bsz).fill_(max_src_len) else: - src_lens = (~encoder_out.encoder_padding_mask).sum(1) + src_lens = (~encoder_out["encoder_padding_mask"][0]).sum(1) max_lens = (src_lens * max_ratio).clamp(min=10).long() # delete words @@ -256,7 +256,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens): initial_output_scores = initial_output_tokens.new_zeros( *initial_output_tokens.size() - ).type_as(encoder_out.encoder_out) + ).type_as(encoder_out["encoder_out"][0]) return DecoderOut( output_tokens=initial_output_tokens, @@ -357,8 +357,15 @@ def extract_features( for _, layer in enumerate(layers[:early_exit]): x, attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) diff --git a/fairseq/models/nat/nonautoregressive_ensembles.py b/fairseq/models/nat/nonautoregressive_ensembles.py index 46bb8aac43..705a04fb49 100644 --- a/fairseq/models/nat/nonautoregressive_ensembles.py +++ b/fairseq/models/nat/nonautoregressive_ensembles.py @@ -83,14 +83,13 @@ def forward_decoder( if max_ratio is None: max_lens = output_tokens.new().fill_(255) else: - if encoder_outs[0].encoder_padding_mask is None: + if not encoder_outs[0]["encoder_padding_mask"]: src_lens = ( - encoder_outs[0] - .encoder_out.new(bsz) - .fill_(encoder_outs[0].encoder_out.size(1)) + encoder_outs[0]["encoder_out"][0].new(bsz) + .fill_(encoder_outs[0]["encoder_out"][0].size(1)) ) else: - src_lens = (~encoder_outs[0].encoder_padding_mask).sum(1) + src_lens = (~encoder_outs[0]["encoder_padding_mask"][0]).sum(1) max_lens = (src_lens * max_ratio).clamp(min=10).long() # delete words diff --git a/fairseq/models/nat/nonautoregressive_transformer.py b/fairseq/models/nat/nonautoregressive_transformer.py index 735297fc29..d114202d25 100644 --- a/fairseq/models/nat/nonautoregressive_transformer.py +++ b/fairseq/models/nat/nonautoregressive_transformer.py @@ -163,7 +163,7 @@ def initialize_output_tokens(self, encoder_out, src_tokens): initial_output_scores = initial_output_tokens.new_zeros( *initial_output_tokens.size() - ).type_as(encoder_out.encoder_out) + ).type_as(encoder_out["encoder_out"][0]) return DecoderOut( output_tokens=initial_output_tokens, @@ -233,8 +233,11 @@ def forward(self, normalize, encoder_out, prev_output_tokens, step=0, **unused): @ensemble_decoder def forward_length(self, normalize, encoder_out): - enc_feats = encoder_out.encoder_out # T x B x C - src_masks = encoder_out.encoder_padding_mask # B x T or None + enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None enc_feats = _mean_pooling(enc_feats, src_masks) if self.sg_length_pred: enc_feats = enc_feats.detach() @@ -264,8 +267,11 @@ def extract_features( """ # embedding if embedding_copy: - src_embd = encoder_out.encoder_embedding - src_mask = encoder_out.encoder_padding_mask + src_embd = encoder_out["encoder_embedding"][0] + if len(encoder_out["encoder_padding_mask"]) > 0: + src_mask = encoder_out["encoder_padding_mask"][0] + else: + src_mask = None src_mask = ( ~src_mask if src_mask is not None @@ -297,8 +303,15 @@ def extract_features( x, attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, self_attn_mask=None, self_attn_padding_mask=decoder_padding_mask, ) @@ -353,8 +366,11 @@ def forward_copying_source(self, src_embeds, src_masks, tgt_masks): return copied_embedding def forward_length_prediction(self, length_out, encoder_out, tgt_tokens=None): - enc_feats = encoder_out.encoder_out # T x B x C - src_masks = encoder_out.encoder_padding_mask # B x T or None + enc_feats = encoder_out["encoder_out"][0] # T x B x C + if len(encoder_out["encoder_padding_mask"]) > 0: + src_masks = encoder_out["encoder_padding_mask"][0] # B x T + else: + src_masks = None if self.pred_length_offset: if src_masks is None: src_lengs = enc_feats.new_ones(enc_feats.size(1)).fill_( diff --git a/fairseq/models/probed_model.py b/fairseq/models/probed_model.py new file mode 100644 index 0000000000..5dbff35c76 --- /dev/null +++ b/fairseq/models/probed_model.py @@ -0,0 +1,187 @@ +import torch.nn +import torch.nn.functional as F +import logging + +logger = logging.getLogger(__name__) + + +def _pick_nth(tensor_or_sequence, which=0): + if isinstance(tensor_or_sequence, (list, tuple)): + tensor_or_sequence = tensor_or_sequence[which] + else: + if which > 0: + raise ValueError("Requested output not present") + return tensor_or_sequence + + +def _detach(tensor_or_iterable): + if isinstance(tensor_or_iterable, (list, tuple)): + return [_detach(elem) for elem in tensor_or_iterable] + elif isinstance(tensor_or_iterable, dict): + return {k: _detach(v) for k, v in tensor_or_iterable.items()} + else: + return tensor_or_iterable.detach() + + +def _compile_selector(selector, default): + if selector is None: + return default + elif isinstance(selector, str): + return eval(selector) + else: + return selector + + +class Probe(torch.nn.Module): + def __init__( + self, + model, + module_name, + backprop_to_main=False, + output_selector=None, + target_selector=None, + loss_weigth=1.0, + ): + super().__init__() + self._saved_tensor = None + self._target_selector = _compile_selector( + target_selector, default=lambda x: {"target": x} + ) + self._loss_weigth = loss_weigth + + output_selector = _compile_selector( + output_selector, default=lambda x: {"output": x} + ) + hook_fn = self._get_hook(output_selector, backprop_to_main) + self._attach(model, module_name, hook_fn) + if backprop_to_main: + logger.info("Registered an auxiliary loss at %s: %s", module_name, self) + else: + logger.info("Registered a probe at %s: %s", module_name, self) + + def _get_hook(self, output_selector, backprop_to_main): + def hook_fn(mod, unused_inputs, outputs): + outputs = output_selector(outputs) + if backprop_to_main: + self._saved_tensor = outputs + else: + self._saved_tensor = _detach(outputs) + + return hook_fn + + def _attach(self, model, module_name, hook_fn): + module = dict(model.named_modules())[module_name] + module.register_forward_hook(hook_fn) + + def compute_loss(self, minibatch): + self._saved_tensor.update(self._target_selector(minibatch)) + ret = self(**self._saved_tensor) + self._saved_tensor = None + return ret + + +class FeedForwardProbe(Probe): + def __init__( + self, + layer_dims, + activation="torch.nn.ReLU", + loss="torch.nn.CrossEntropyLoss", + **kwargs, + ): + super().__init__(**kwargs) + activation = eval(activation) + in_dim, last_dim, *rest = layer_dims + modules = [torch.nn.Linear(in_dim, last_dim)] + for dim in rest: + modules.append(activation()) + modules.append(torch.nn.Linear(last_dim, dim)) + last_dim = dim + self.layers = torch.nn.Sequential(*modules) + self.loss = eval(loss)() + + def forward(self, output, target): + output = self.layers(output) + return self.loss(output, target) + + +class Conv1DProbe(Probe): + def __init__(self, layer_dims, kernel_size=1, activation="torch.nn.ReLU", **kwargs): + super().__init__(**kwargs) + activation = eval(activation) + in_dim, last_dim, *rest = layer_dims + assert kernel_size % 2 == 1 + modules = [ + torch.nn.Conv1d(in_dim, last_dim, kernel_size, padding=kernel_size // 2) + ] + for dim in rest: + modules.append(activation()) + modules.append( + torch.nn.Conv1d(last_dim, dim, kernel_size, padding=kernel_size // 2) + ) + last_dim = dim + self.layers = torch.nn.Sequential(*modules) + self.loss = torch.nn.CrossEntropyLoss() + + def forward(self, output, target, padding_mask): + N, Cin, L = output.shape + Nm, Cpad, Lm = padding_mask.shape + assert Cpad == 1 + assert N == Nm + output = F.interpolate(output, scale_factor=Lm // L) + output = self.layers(output) + padding_mask = padding_mask.float().squeeze(1) + neg_mask = 1.0 - padding_mask + target = (target * neg_mask + padding_mask * self.loss.ignore_index).long() + loss = self.loss(output, target) + weigth = neg_mask.sum() + acc = (neg_mask * (torch.argmax(output, 1) == target).float()).sum() / weigth + probe_logs = { + "loss": loss.item(), + "loss_weigth": weigth.item(), + "acc": acc.item(), + "acc_weigth": weigth.item(), + } + # logging.info("Probe logs: %s", probe_logs) + return loss * self._loss_weigth, probe_logs + + +class ProbedModel: + """A model which can attach small probes to analyze model behavior.""" + + def _build_probe(self, cls, **kwargs): + cls = eval(cls) + return cls(model=self, **kwargs) + + def attach_probes(self, probe_defs): + if not probe_defs: + return + self._probes = torch.nn.ModuleDict( + { + probe_name: self._build_probe(**probe_def) + for probe_name, probe_def in probe_defs.items() + } + ) + + def get_probe_losses(self, minibatch): + loss = 0.0 + extra_log_keys = {} + for probe_name, probe in self._probes.items(): + probe_loss, probe_log_keys = probe.compute_loss(minibatch) + loss += probe_loss * probe._loss_weigth + for k, v in probe_log_keys.items(): + extra_log_keys[f"probe_{probe_name}_{k}"] = v + return loss, extra_log_keys + +def reduce_probe_metrics(logging_outputs, metrics): + handled_keys = set() + def get_v(k): + handled_keys.add(k) + return sum(log.get(k, 0) for log in logging_outputs) + for k in logging_outputs[0]: + if k.startswith("probe_"): + if k.endswith("_weigth"): + continue + v = get_v(k) + weigth = get_v(f'{k}_weigth') + metrics.log_scalar(k, v, weigth, round=3) + return handled_keys \ No newline at end of file diff --git a/fairseq/models/roberta/__init__.py b/fairseq/models/roberta/__init__.py index 56579e5915..cf16914fbc 100644 --- a/fairseq/models/roberta/__init__.py +++ b/fairseq/models/roberta/__init__.py @@ -6,4 +6,5 @@ from .hub_interface import * # noqa from .model import * # noqa from .model_camembert import * # noqa +from .model_gottbert import * # noqa from .model_xlmr import * # noqa diff --git a/fairseq/models/roberta/model.py b/fairseq/models/roberta/model.py index 0f6efe5b33..96a7b9c8a2 100644 --- a/fairseq/models/roberta/model.py +++ b/fairseq/models/roberta/model.py @@ -496,7 +496,6 @@ def base_architecture(args): args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None) args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) args.untie_weights_roberta = getattr(args, "untie_weights_roberta", False) args.spectral_norm_classification_head = getattr( args, "spectral_norm_classification_head", False diff --git a/fairseq/models/roberta/model_gottbert.py b/fairseq/models/roberta/model_gottbert.py new file mode 100644 index 0000000000..2e8c66354a --- /dev/null +++ b/fairseq/models/roberta/model_gottbert.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +GottBERT: a pure German Language Model +""" + +from fairseq.models import register_model + +from .hub_interface import RobertaHubInterface +from .model import RobertaModel + + +@register_model('gottbert') +class GottbertModel(RobertaModel): + + @classmethod + def hub_models(cls): + return { + 'gottbert-base': 'https://dl.gottbert.de/fairseq/models/gottbert-base.tar.gz', + } + + @classmethod + def from_pretrained(cls, + model_name_or_path, + checkpoint_file='model.pt', + data_name_or_path='.', + bpe='hf_byte_bpe', + bpe_vocab='vocab.json', + bpe_merges='merges.txt', + bpe_add_prefix_space=False, + **kwargs + ): + from fairseq import hub_utils + + x = hub_utils.from_pretrained( + model_name_or_path, + checkpoint_file, + data_name_or_path, + archive_map=cls.hub_models(), + bpe=bpe, + load_checkpoint_heads=True, + bpe_vocab=bpe_vocab, + bpe_merges=bpe_merges, + bpe_add_prefix_space=bpe_add_prefix_space, + **kwargs, + ) + return RobertaHubInterface(x['args'], x['task'], x['models'][0]) diff --git a/fairseq/models/transformer.py b/fairseq/models/transformer.py index 7614c33f74..fa4c29855b 100644 --- a/fairseq/models/transformer.py +++ b/fairseq/models/transformer.py @@ -16,7 +16,6 @@ register_model, register_model_architecture, ) -from fairseq.models.fairseq_encoder import EncoderOut from fairseq.modules import ( AdaptiveSoftmax, FairseqDropout, @@ -72,6 +71,13 @@ def moses_fastbpe(path): 'bpe': 'fastbpe', } + def spm(path): + return { + 'path': path, + 'bpe': 'sentencepiece', + 'tokenizer': 'space', + } + return { 'transformer.wmt14.en-fr': moses_subword('https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2'), 'transformer.wmt16.en-de': 'https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2', @@ -84,6 +90,12 @@ def moses_fastbpe(path): 'transformer.wmt19.en-ru.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.single_model.tar.gz'), 'transformer.wmt19.de-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.single_model.tar.gz'), 'transformer.wmt19.ru-en.single_model': moses_fastbpe('https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.single_model.tar.gz'), + 'transformer.wmt20.en-ta': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-ta.single.tar.gz'), + 'transformer.wmt20.en-iu.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.news.single.tar.gz'), + 'transformer.wmt20.en-iu.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.en-iu.nh.single.tar.gz'), + 'transformer.wmt20.ta-en': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.ta-en.single.tar.gz'), + 'transformer.wmt20.iu-en.news': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.news.single.tar.gz'), + 'transformer.wmt20.iu-en.nh': spm('https://dl.fbaipublicfiles.com/fairseq/models/wmt20.iu-en.nh.single.tar.gz'), } # fmt: on @@ -390,7 +402,7 @@ def forward_embedding( def forward( self, src_tokens, - src_lengths, + src_lengths: Optional[torch.Tensor] = None, return_all_hiddens: bool = False, token_embeddings: Optional[torch.Tensor] = None, ): @@ -406,7 +418,7 @@ def forward( default `None` will recompute embeddings Returns: - namedtuple: + dict: - **encoder_out** (Tensor): the last encoder layer's output of shape `(src_len, batch, embed_dim)` - **encoder_padding_mask** (ByteTensor): the positions of @@ -425,7 +437,7 @@ def forward( # compute padding mask encoder_padding_mask = src_tokens.eq(self.padding_idx) - encoder_states = [] if return_all_hiddens else None + encoder_states = [] # encoder layers for layer in self.layers: @@ -437,17 +449,21 @@ def forward( if self.layer_norm is not None: x = self.layer_norm(x) - return EncoderOut( - encoder_out=x, # T x B x C - encoder_padding_mask=encoder_padding_mask, # B x T - encoder_embedding=encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=None, - src_lengths=None, - ) + # The Pytorch Mobile lite interpreter does not supports returning NamedTuple in + # `foward` so we use a dictionary instead. + # TorchScript does not support mixed values so the values are all lists. + # The empty list is equivalent to None. + return { + "encoder_out": [x], # T x B x C + "encoder_padding_mask": [encoder_padding_mask], # B x T + "encoder_embedding": [encoder_embedding], # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": [], + "src_lengths": [], + } @torch.jit.export - def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): + def reorder_encoder_out(self, encoder_out: Dict[str, List[Tensor]], new_order): """ Reorder encoder output according to *new_order*. @@ -458,50 +474,46 @@ def reorder_encoder_out(self, encoder_out: EncoderOut, new_order): Returns: *encoder_out* rearranged according to *new_order* """ - """ - Since encoder_padding_mask and encoder_embedding are both of type - Optional[Tensor] in EncoderOut, they need to be copied as local - variables for Torchscript Optional refinement - """ - encoder_padding_mask: Optional[Tensor] = encoder_out.encoder_padding_mask - encoder_embedding: Optional[Tensor] = encoder_out.encoder_embedding + if len(encoder_out["encoder_out"]) == 0: + new_encoder_out = [] + else: + new_encoder_out = [encoder_out["encoder_out"][0].index_select(1, new_order)] + if len(encoder_out["encoder_padding_mask"]) == 0: + new_encoder_padding_mask = [] + else: + new_encoder_padding_mask = [ + encoder_out["encoder_padding_mask"][0].index_select(0, new_order) + ] + if len(encoder_out["encoder_embedding"]) == 0: + new_encoder_embedding = [] + else: + new_encoder_embedding = [ + encoder_out["encoder_embedding"][0].index_select(0, new_order) + ] - new_encoder_out = ( - encoder_out.encoder_out - if encoder_out.encoder_out is None - else encoder_out.encoder_out.index_select(1, new_order) - ) - new_encoder_padding_mask = ( - encoder_padding_mask - if encoder_padding_mask is None - else encoder_padding_mask.index_select(0, new_order) - ) - new_encoder_embedding = ( - encoder_embedding - if encoder_embedding is None - else encoder_embedding.index_select(0, new_order) - ) - src_tokens = encoder_out.src_tokens - if src_tokens is not None: - src_tokens = src_tokens.index_select(0, new_order) + if len(encoder_out["src_tokens"]) == 0: + src_tokens = [] + else: + src_tokens = [(encoder_out["src_tokens"][0]).index_select(0, new_order)] - src_lengths = encoder_out.src_lengths - if src_lengths is not None: - src_lengths = src_lengths.index_select(0, new_order) + if len(encoder_out["src_lengths"]) == 0: + src_lengths = [] + else: + src_lengths = [(encoder_out["src_lengths"][0]).index_select(0, new_order)] - encoder_states = encoder_out.encoder_states - if encoder_states is not None: + encoder_states = encoder_out["encoder_states"] + if len(encoder_states) > 0: for idx, state in enumerate(encoder_states): encoder_states[idx] = state.index_select(1, new_order) - return EncoderOut( - encoder_out=new_encoder_out, # T x B x C - encoder_padding_mask=new_encoder_padding_mask, # B x T - encoder_embedding=new_encoder_embedding, # B x T x C - encoder_states=encoder_states, # List[T x B x C] - src_tokens=src_tokens, # B x T - src_lengths=src_lengths, # B x 1 - ) + return { + "encoder_out": new_encoder_out, # T x B x C + "encoder_padding_mask": new_encoder_padding_mask, # B x T + "encoder_embedding": new_encoder_embedding, # B x T x C + "encoder_states": encoder_states, # List[T x B x C] + "src_tokens": src_tokens, # B x T + "src_lengths": src_lengths, # B x 1 + } def max_positions(self): """Maximum input length supported by the encoder.""" @@ -664,7 +676,7 @@ def build_decoder_layer(self, args, no_encoder_attn=False): def forward( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]] = None, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, features_only: bool = False, full_context_alignment: bool = False, @@ -706,7 +718,7 @@ def forward( def extract_features( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -723,14 +735,14 @@ def extract_features( """ A scriptable subclass of this class has an extract_features method and calls - super().extract_features, but super() is not supported in torchscript. Aa copy of + super().extract_features, but super() is not supported in torchscript. A copy of this function is made to be used in the subclass instead. """ def extract_features_scriptable( self, prev_output_tokens, - encoder_out: Optional[EncoderOut] = None, + encoder_out: Optional[Dict[str, List[Tensor]]], incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None, full_context_alignment: bool = False, alignment_layer: Optional[int] = None, @@ -807,8 +819,15 @@ def extract_features_scriptable( x, layer_attn, _ = layer( x, - encoder_out.encoder_out if encoder_out is not None else None, - encoder_out.encoder_padding_mask if encoder_out is not None else None, + encoder_out["encoder_out"][0] + if (encoder_out is not None and len(encoder_out["encoder_out"]) > 0) + else None, + encoder_out["encoder_padding_mask"][0] + if ( + encoder_out is not None + and len(encoder_out["encoder_padding_mask"]) > 0 + ) + else None, incremental_state, self_attn_mask=self_attn_mask, self_attn_padding_mask=self_attn_padding_mask, diff --git a/fairseq/models/transformer_lm.py b/fairseq/models/transformer_lm.py index 35bfa6eb6f..d86b68b508 100644 --- a/fairseq/models/transformer_lm.py +++ b/fairseq/models/transformer_lm.py @@ -171,6 +171,8 @@ class TransformerLanguageModel(FairseqLanguageModel): def hub_models(cls): def moses_fastbpe(path): return {"path": path, "tokenizer": "moses", "bpe": "fastbpe"} + def spm(path): + return {"path": path, "tokenizer": "space", "bpe": "sentencepiece"} return { "transformer_lm.gbw.adaptive_huge": "https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2", @@ -184,6 +186,18 @@ def moses_fastbpe(path): "transformer_lm.wmt19.ru": moses_fastbpe( "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.bz2" ), + "transformer_lm.wmt20.en": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.en.tar.gz" + ), + "transformer_lm.wmt20.ta": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.ta.tar.gz" + ), + "transformer_lm.wmt20.iu.news": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.news.tar.gz" + ), + "transformer_lm.wmt20.iu.nh": spm( + "https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt20.iu.nh.tar.gz" + ), } def __init__(self, decoder): diff --git a/fairseq/models/wav2vec/wav2vec.py b/fairseq/models/wav2vec/wav2vec.py index 772995b526..83b6461129 100644 --- a/fairseq/models/wav2vec/wav2vec.py +++ b/fairseq/models/wav2vec/wav2vec.py @@ -3,14 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field import logging import math +from typing import Optional, Tuple +from omegaconf import II import sys import torch import torch.nn as nn import torch.nn.functional as F -from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model from fairseq.modules import ( Fp32GroupNorm, Fp32LayerNorm, @@ -18,264 +22,208 @@ KmeansVectorQuantizer, TransposeLast, ) +from fairseq.tasks import FairseqTask from fairseq.utils import buffered_arange logger = logging.getLogger(__name__) -@register_model("wav2vec") -class Wav2VecModel(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - parser.add_argument( - "--prediction-steps", - type=int, - metavar="N", - help="number of steps ahead to predict", - ) - parser.add_argument( - "--sample-distance", - type=int, - metavar="N", - help="sample distance from target. does not work properly with cross-sampling", - ) - parser.add_argument( - "--cross-sample-negatives", - type=int, - metavar="N", - help="num of cross sampled negatives", - ) - parser.add_argument( - "--num-negatives", type=int, metavar="N", help="number of negative examples" - ) - parser.add_argument( - "--conv-feature-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - parser.add_argument( - "--conv-aggregator-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout to apply within the model", - ) - parser.add_argument( - "--dropout-features", - type=float, - metavar="D", - help="dropout to apply to the features", - ) - parser.add_argument( - "--dropout-agg", - type=float, - metavar="D", - help="dropout to apply after aggregation step", - ) - parser.add_argument( - "--encoder", type=str, choices=["cnn"], help="type of encoder to use" - ) - parser.add_argument( - "--aggregator", - type=str, - choices=["cnn", "gru"], - help="type of aggregator to use", - ) - parser.add_argument( - "--gru-dim", type=int, metavar="N", help="GRU dimensionality" - ) - - parser.add_argument( - "--no-conv-bias", - action="store_true", - help="if set, does not learn bias for conv layers", - ) - parser.add_argument( - "--agg-zero-pad", - action="store_true", - help="if set, zero pads in aggregator instead of repl pad", - ) +AGGREGATOR_CHOICES = ChoiceEnum(["cnn", "gru"]) +PROJECT_FEATURES_CHOICES = ChoiceEnum(["none", "same", "new"]) +ACTIVATION_CHOICES = ChoiceEnum(["relu", "gelu"]) +VQ_TYPE_CHOICES = ChoiceEnum(["none", "gumbel", "kmeans"]) - parser.add_argument( - "--skip-connections-feat", - action="store_true", - help="if set, adds skip connections to the feature extractor", - ) - parser.add_argument( - "--skip-connections-agg", - action="store_true", - help="if set, adds skip connections to the aggregator", - ) - parser.add_argument( - "--residual-scale", - type=float, - metavar="D", - help="scales residual by sqrt(value)", - ) - - parser.add_argument( - "--log-compression", - action="store_true", - help="if set, adds a log compression to feature extractor", - ) - - parser.add_argument( - "--balanced-classes", - action="store_true", - help="if set, loss is scaled to balance for number of negatives", - ) - parser.add_argument( - "--project-features", - choices=["none", "same", "new"], - help="if not none, features are projected using the (same or new) aggregator", - ) - - parser.add_argument( - "--non-affine-group-norm", - action="store_true", - help="if set, group norm is not affine", - ) - - parser.add_argument( - "--offset", - help="if set, introduces an offset from target to predictions. " - 'if set to "auto", it is computed automatically from the receptive field', - ) - - parser.add_argument( - "--activation", - type=str, - choices=["relu", "gelu"], - help="which activation function to use", - ) +@dataclass +class Wav2VecConfig(FairseqDataclass): + prediction_steps: int = field( + default=12, metadata={"help": "number of steps ahead to predict"} + ) + sample_distance: Optional[int] = field( + default=None, + metadata={ + "help": "sample distance from target. does not work properly with cross-sampling" + }, + ) + cross_sample_negatives: int = field( + default=0, metadata={"help": "num of cross sampled negatives"} + ) + num_negatives: int = field( + default=10, metadata={"help": "num of cross sampled negatives"} + ) + conv_feature_layers: str = field( + default="[(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)]", + metadata={ + "help": "convolutional feature extraction layers [(dim, kernel_size, stride), ...]" + }, + ) + conv_aggregator_layers: str = field( + default="[(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)]", + metadata={ + "help": "convolutional aggregator layers [(dim, kernel_size, stride), ...]" + }, + ) + dropout: float = field( + default=0.0, metadata={"help": "dropout to apply within the model"} + ) + dropout_features: float = field( + default=0.0, metadata={"help": "dropout to apply to the features"} + ) + dropout_agg: float = field( + default=0.0, metadata={"help": "dropout to apply after aggregation step"} + ) + aggregator: AGGREGATOR_CHOICES = field( + default="cnn", metadata={"help": "type of aggregator to use"} + ) + gru_dim: int = field(default=512, metadata={"help": "GRU dimensionality"}) + no_conv_bias: bool = field( + default=False, metadata={"help": "if set, does not learn bias for conv layers"} + ) + agg_zero_pad: bool = field( + default=False, + metadata={"help": "if set, zero pads in aggregator instead of repl pad"}, + ) + skip_connections_feat: bool = field( + default=False, + metadata={"help": "if set, adds skip connections to the feature extractor"}, + ) + skip_connections_agg: bool = field( + default=True, + metadata={"help": "if set, adds skip connections to the aggregator"}, + ) + residual_scale: float = field( + default=0.5, metadata={"help": "scales residual by sqrt(value)"} + ) + log_compression: bool = field( + default=True, + metadata={"help": "if set, adds a log compression to feature extractor"}, + ) + balanced_classes: bool = field( + default=False, + metadata={"help": "if set, loss is scaled to balance for number of negatives"}, + ) + project_features: PROJECT_FEATURES_CHOICES = field( + default="none", + metadata={ + "help": "if not none, features are projected using the (same or new) aggregator" + }, + ) + non_affine_group_norm: bool = field( + default=False, metadata={"help": "if set, group norm is not affine"} + ) + offset: str = field( + default="auto", + metadata={ + "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" + }, + ) + activation: ACTIVATION_CHOICES = field( + default="relu", + metadata={ + "help": "if set to 'auto', it is computed automatically from the receptive field, else set to int value" + }, + ) + vq_type: VQ_TYPE_CHOICES = field( + default="none", metadata={"help": "which type of quantizer to use"} + ) + vq_vars: int = field( + default=320, + metadata={"help": "project to this many vector quantized variables per group"}, + ) + vq_groups: int = field( + default=2, metadata={"help": "number of groups of latent variables"} + ) + vq_dim: int = field( + default=0, + metadata={ + "help": "uses this dimensionality for quantized vectors. 0 to use model dim // groups" + }, + ) + vq_depth: int = field( + default=1, metadata={"help": "number of layers for vq weight projection"} + ) + combine_groups: bool = field( + default=False, metadata={"help": "if set, variables are shared among groups"} + ) + vq_temp: Tuple[float, float, float] = field( + default=(2.0, 0.5, 0.999995), + metadata={ + "help": "temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)" + }, + ) + vq_gamma: float = field( + default=0.25, + metadata={"help": "gamma parameter for kmeans style vector quantization"}, + ) + infonce: bool = II("criterion.infonce") - parser.add_argument( - "--vq-type", - type=str, - choices=["none", "gumbel", "kmeans"], - help="which type of quantizer to use", - ) - parser.add_argument( - "--vq-vars", - type=int, - metavar="N", - help="if set, project to this many vector quantized variables per group", - ) - parser.add_argument( - "--vq-groups", - type=int, - metavar="N", - help="number of groups of latent variables", - ) - parser.add_argument( - "--vq-dim", - type=int, - metavar="N", - help="uses this dimensionality for quantized vectors", - ) - parser.add_argument( - "--vq-depth", - type=int, - metavar="N", - help="number of layers for vq weight projection", - ) - parser.add_argument( - "--combine-groups", - action="store_true", - help="if set, variables are shared among groups", - ) - parser.add_argument( - "--vq-temp", - type=str, - metavar="TEMP", - help="temperature for latent variable sampling with gumbel softmax. should be a tuple of 3 values (start, end, decay)", - ) - parser.add_argument( - "--vq-gamma", - type=float, - metavar="D", - help="gamma parameter for kmeans style vector quantization", - ) +@register_model("wav2vec", dataclass=Wav2VecConfig) +class Wav2VecModel(BaseFairseqModel): @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2VecConfig, task: FairseqTask): """Build a new model instance.""" - # make sure all arguments are present in older models - base_wav2vec_architecture(args) - - model = Wav2VecModel(args) + model = Wav2VecModel(cfg) logger.info(model) return model - def __init__(self, args): + def __init__(self, cfg: Wav2VecConfig): super().__init__() - self.prediction_steps = args.prediction_steps - offset = args.offset + self.prediction_steps = cfg.prediction_steps + offset = cfg.offset - if args.activation == "relu": + if cfg.activation == "relu": activation = nn.ReLU() - elif args.activation == "gelu": + elif cfg.activation == "gelu": activation = nn.GELU() else: - raise Exception("unknown activation " + args.activation) - - if args.encoder == "cnn": - feature_enc_layers = eval(args.conv_feature_layers) - self.feature_extractor = ConvFeatureExtractionModel( - conv_layers=feature_enc_layers, - dropout=0.0, - log_compression=args.log_compression, - skip_connections=args.skip_connections_feat, - residual_scale=args.residual_scale, - non_affine_group_norm=args.non_affine_group_norm, - activation=activation, - ) - embed = feature_enc_layers[-1][0] - else: - raise Exception("unknown encoder type " + args.encoder) + raise Exception("unknown activation " + cfg.activation) + + feature_enc_layers = eval(cfg.conv_feature_layers) + self.feature_extractor = ConvFeatureExtractionModel( + conv_layers=feature_enc_layers, + dropout=0.0, + log_compression=cfg.log_compression, + skip_connections=cfg.skip_connections_feat, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + activation=activation, + ) + embed = feature_enc_layers[-1][0] self.vector_quantizer = None - if args.vq_type == "gumbel": + if cfg.vq_type == "gumbel": self.vector_quantizer = GumbelVectorQuantizer( dim=embed, - num_vars=args.vq_vars, - temp=eval(args.vq_temp), - groups=args.vq_groups, - combine_groups=args.combine_groups, - vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + num_vars=cfg.vq_vars, + temp=cfg.vq_temp, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, time_first=False, activation=activation, - weight_proj_depth=args.vq_depth, + weight_proj_depth=cfg.vq_depth, weight_proj_factor=2, ) - elif args.vq_type == "kmeans": + elif cfg.vq_type == "kmeans": self.vector_quantizer = KmeansVectorQuantizer( dim=embed, - num_vars=args.vq_vars, - groups=args.vq_groups, - combine_groups=args.combine_groups, - vq_dim=args.vq_dim if args.vq_dim > 0 else embed, + num_vars=cfg.vq_vars, + groups=cfg.vq_groups, + combine_groups=cfg.combine_groups, + vq_dim=cfg.vq_dim if cfg.vq_dim > 0 else embed, time_first=False, - gamma=args.vq_gamma, + gamma=cfg.vq_gamma, ) else: assert ( - args.vq_type == "none" or args.vq_type is None + cfg.vq_type == "none" or cfg.vq_type is None ), "Unknown quantizer type" - if args.offset == "auto": - assert args.encoder == "cnn" + if cfg.offset == "auto": jin = 0 rin = 0 for _, k, stride in feature_enc_layers: @@ -291,34 +239,34 @@ def __init__(self, args): offset = int(offset) def make_aggregator(): - if args.aggregator == "cnn": - agg_layers = eval(args.conv_aggregator_layers) + if cfg.aggregator == "cnn": + agg_layers = eval(cfg.conv_aggregator_layers) agg_dim = agg_layers[-1][0] feature_aggregator = ConvAggegator( conv_layers=agg_layers, embed=embed, - dropout=args.dropout, - skip_connections=args.skip_connections_agg, - residual_scale=args.residual_scale, - non_affine_group_norm=args.non_affine_group_norm, - conv_bias=not args.no_conv_bias, - zero_pad=args.agg_zero_pad, + dropout=cfg.dropout, + skip_connections=cfg.skip_connections_agg, + residual_scale=cfg.residual_scale, + non_affine_group_norm=cfg.non_affine_group_norm, + conv_bias=not cfg.no_conv_bias, + zero_pad=cfg.agg_zero_pad, activation=activation, ) - elif args.aggregator == "gru": - agg_dim = args.gru_dim + elif cfg.aggregator == "gru": + agg_dim = cfg.gru_dim feature_aggregator = nn.Sequential( TransposeLast(), nn.GRU( input_size=embed, hidden_size=agg_dim, num_layers=1, - dropout=args.dropout, + dropout=cfg.dropout, ), TransposeLast(deconstruct_idx=0), ) else: - raise Exception("unknown aggregator type " + args.aggregator) + raise Exception("unknown aggregator type " + cfg.aggregator) return feature_aggregator, agg_dim @@ -327,24 +275,24 @@ def make_aggregator(): self.wav2vec_predictions = Wav2VecPredictionsModel( in_dim=agg_dim, out_dim=embed, - prediction_steps=args.prediction_steps, - n_negatives=args.num_negatives, - cross_sample_negatives=args.cross_sample_negatives, - sample_distance=args.sample_distance, - dropout=args.dropout, + prediction_steps=cfg.prediction_steps, + n_negatives=cfg.num_negatives, + cross_sample_negatives=cfg.cross_sample_negatives, + sample_distance=cfg.sample_distance, + dropout=cfg.dropout, offset=offset, - balanced_classes=args.balanced_classes, - infonce=args.infonce, + balanced_classes=cfg.balanced_classes, + infonce=cfg.infonce, ) - self.dropout_feats = nn.Dropout(p=args.dropout_features) - self.dropout_agg = nn.Dropout(p=args.dropout_agg) + self.dropout_feats = nn.Dropout(p=cfg.dropout_features) + self.dropout_agg = nn.Dropout(p=cfg.dropout_agg) - if args.project_features == "none": + if cfg.project_features == "none": self.project_features = None - elif args.project_features == "same": + elif cfg.project_features == "same": self.project_features = self.feature_aggregator - elif args.project_features == "new": + elif cfg.project_features == "new": self.project_features, _ = make_aggregator() def forward(self, source): @@ -680,56 +628,3 @@ def forward(self, x, y): labels = (labels, weights) return predictions, labels - - -@register_model_architecture("wav2vec", "wav2vec") -def base_wav2vec_architecture(args): - conv_feature_layers = "[(512, 10, 5)]" - conv_feature_layers += " + [(512, 8, 4)]" - conv_feature_layers += " + [(512, 4, 2)] * 3" - args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) - - args.conv_aggregator_layers = getattr( - args, "conv_aggregator_layers", "[(512, 3, 1)] * 9" - ) - - args.prediction_steps = getattr(args, "prediction_steps", 12) - args.num_negatives = getattr(args, "num_negatives", 1) - args.sample_distance = getattr(args, "sample_distance", None) - args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) - - args.dropout = getattr(args, "dropout", 0.0) - args.dropout_features = getattr(args, "dropout_features", 0.0) - args.dropout_agg = getattr(args, "dropout_agg", 0.0) - args.encoder = getattr(args, "encoder", "cnn") - args.aggregator = getattr(args, "aggregator", "cnn") - - args.skip_connections_feat = getattr(args, "skip_connections_feat", False) - args.skip_connections_agg = getattr(args, "skip_connections_agg", False) - args.residual_scale = getattr(args, "residual_scale", 0.5) - - args.gru_dim = getattr(args, "gru_dim", 512) - - args.no_conv_bias = getattr(args, "no_conv_bias", False) - args.agg_zero_pad = getattr(args, "agg_zero_pad", False) - - args.log_compression = getattr(args, "log_compression", False) - - args.balanced_classes = getattr(args, "balanced_classes", False) - args.infonce = getattr(args, "infonce", False) - args.project_features = getattr(args, "project_features", "none") - - args.non_affine_group_norm = getattr(args, "non_affine_group_norm", False) - - args.offset = getattr(args, "offset", "auto") - - args.activation = getattr(args, "activation", "relu") - - args.vq_type = getattr(args, "vq_type", "none") - args.vq_vars = getattr(args, "vq_vars", 320) - args.vq_groups = getattr(args, "vq_groups", 2) - args.vq_dim = getattr(args, "vq_dim", 0) - args.vq_depth = getattr(args, "vq_depth", 1) - args.combine_groups = getattr(args, "combine_groups", False) - args.vq_temp = getattr(args, "vq_temp", "(2.0, 0.5, 0.999995)") - args.vq_gamma = getattr(args, "vq_gamma", 0.25) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index e6fecdd4fe..783ebcfe6b 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -14,7 +14,7 @@ from fairseq import utils from fairseq.data.data_utils import compute_mask_indices from fairseq.dataclass import ChoiceEnum, FairseqDataclass -from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.models import BaseFairseqModel, register_model from fairseq.modules import ( Fp32GroupNorm, Fp32LayerNorm, @@ -92,7 +92,7 @@ class Wav2Vec2Config(FairseqDataclass): default=False, metadata={"help": "apply layernorm first in the transformer"} ) conv_feature_layers: str = field( - default="[(512, 10, 5), (512, 8, 4)] + [(512, 4, 2)] * 3 + [(512, 1, 1)]", + default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]", metadata={ "help": "string describing convolutional feature extraction layers in form of a python list that contains " "[(dim, kernel_size, stride), ...]" @@ -147,7 +147,7 @@ class Wav2Vec2Config(FairseqDataclass): default=0, metadata={ "help": "secondary mask argument (used for more complex distributions), " - "see help in compute_mask_indicesh" + "see help in compute_mask_indices" }, ) no_mask_overlap: bool = field( diff --git a/fairseq/models/wav2vec/wav2vec2_asr.py b/fairseq/models/wav2vec/wav2vec2_asr.py index f62ec633b4..790b0a8ad1 100644 --- a/fairseq/models/wav2vec/wav2vec2_asr.py +++ b/fairseq/models/wav2vec/wav2vec2_asr.py @@ -7,166 +7,145 @@ import contextlib import copy import math - import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +from dataclasses import dataclass, field +from omegaconf import MISSING, II, open_dict +from typing import Any + from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.tasks import FairseqTask from fairseq.models import ( BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, FairseqIncrementalDecoder, register_model, - register_model_architecture, ) +from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES from fairseq.modules import LayerNorm, PositionalEmbedding, TransformerDecoderLayer -def add_common_args(parser): - parser.add_argument("--w2v-path", help="path to wav2vec 2.0 model") - parser.add_argument( - "--no-pretrained-weights", - action="store_true", - help="if true, does not load pretrained weights", +@dataclass +class Wav2Vec2AsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to wav2vec 2.0 model"} ) - parser.add_argument( - "--dropout-input", - type=float, - metavar="D", - help="dropout to apply to the input (after feat extr)", + no_pretrained_weights: bool = field( + default=False, metadata={"help": "if true, does not load pretrained weights"} ) - parser.add_argument( - "--final-dropout", - type=float, - metavar="D", - help="dropout after transformer and before final projection", + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, ) - parser.add_argument( - "--apply-mask", action="store_true", help="apply masking during fine-tuning" + final_dropout: float = field( + default=0.0, + metadata={"help": "dropout after transformer and before final projection"}, ) - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout probability inside wav2vec 2.0 model", + dropout: float = field( + default=0.0, metadata={"help": "dropout probability inside wav2vec 2.0 model"} ) - parser.add_argument( - "--attention-dropout", - type=float, - metavar="D", - help="dropout probability for attention weights inside wav2vec 2.0 model", + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside wav2vec 2.0 model" + }, ) - parser.add_argument( - "--activation-dropout", - "--relu-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN inside wav2vec 2.0 model", + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside wav2vec 2.0 model" + }, ) - parser.add_argument( - "--mask-length", type=int, help="repeat the mask indices multiple times" + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} ) - - parser.add_argument( - "--mask-prob", type=float, help="probability of replacing a token with mask" + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask (normalized by length)" + }, ) - - parser.add_argument( - "--mask-other", - type=float, - help="stdev of the mask length in case of 'normal' selection strategy", + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} ) - - parser.add_argument( - "--no-mask-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-channel-length", type=int, help="repeat the mask indices multiple times" + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indices" + }, ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - help="probability of replacing a token with mask", + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} ) - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", + # channel masking + mask_channel_length: int = field( + default=10, metadata={"help": "length of the mask for features (channels)"} ) - - parser.add_argument( - "--mask-channel-other", - type=float, - help="stdev of the mask length in case of 'normal' selection strategy", + mask_channel_prob: float = field( + default=0.0, metadata={"help": "probability of replacing a feature with 0"} ) - - parser.add_argument( - "--no-mask-channel-overlap", - action="store_true", - help="whether to allow masks to overlap", + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, ) - - parser.add_argument( - "--freeze-finetune-updates", + mask_channel_other: float = field( default=0, - type=int, - help="dont finetune wav2vec for this many updates", + metadata={ + "help": "secondary mask argument (used for more complex distributions), " + "see help in compute_mask_indicesh" + }, ) - - parser.add_argument( - "--feature-grad-mult", - default=None, - type=float, - help="reset feature grad mult in wav2vec 2.0 to this", + no_mask_channel_overlap: bool = field( + default=False, metadata={"help": "whether to allow channel masks to overlap"} ) - - parser.add_argument( - "--layerdrop", - default=0.0, - type=float, - help="probability of dropping a layer in wav2vec 2.0", + freeze_finetune_updates: int = field( + default=0, metadata={"help": "dont finetune wav2vec for this many updates"} + ) + feature_grad_mult: float = field( + default=0.0, metadata={"help": "reset feature grad mult in wav2vec 2.0 to this"} ) + layerdrop: float = field( + default=0.0, metadata={"help": "probability of dropping a layer in wav2vec 2.0"} + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + # this holds the loaded wav2vec args + w2v_args: Any = None -@register_model("wav2vec_ctc") -class Wav2VecCtc(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - add_common_args(parser) +@dataclass +class Wav2Vec2CtcConfig(Wav2Vec2AsrConfig): + pass - def __init__(self, w2v_encoder, args): + +@register_model("wav2vec_ctc", dataclass=Wav2Vec2CtcConfig) +class Wav2VecCtc(BaseFairseqModel): + def __init__(self, cfg: Wav2Vec2CtcConfig, w2v_encoder: BaseFairseqModel): super().__init__() + self.cfg = cfg self.w2v_encoder = w2v_encoder - self.args = args def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) return state_dict @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2Vec2CtcConfig, task: FairseqTask): """Build a new model instance.""" - base_architecture(args) - w2v_encoder = Wav2VecEncoder(args, task.target_dictionary) - return cls(w2v_encoder, args) + w2v_encoder = Wav2VecEncoder(cfg, task.target_dictionary) + return cls(cfg, w2v_encoder) def get_normalized_probs(self, net_output, log_probs): """Get normalized probabilities (or log probs) from a net's output.""" @@ -181,96 +160,67 @@ def forward(self, **kwargs): x = self.w2v_encoder(**kwargs) return x - # def max_positions(self): - # return None +@dataclass +class Wav2Vec2Seq2SeqConfig(Wav2Vec2AsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field(default=6, metadata={"help": "num of decoder layers"}) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, metadata={"help": "apply layernorm before each decoder block"} + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings (outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, metadata={"help": "share decoder input and output embeddings"} + ) -@register_model("wav2vec_seq2seq") -class TransformerModel(FairseqEncoderDecoderModel): - def __init__(self, args, encoder, decoder): - super().__init__(encoder, decoder) - - @staticmethod - def add_args(parser): - add_common_args(parser) - - parser.add_argument( - "--decoder-embed-dim", - type=int, - metavar="N", - help="decoder embedding dimension", - ) - parser.add_argument( - "--decoder-ffn-embed-dim", - type=int, - metavar="N", - help="decoder embedding dimension for FFN", - ) - parser.add_argument( - "--decoder-layers", type=int, metavar="N", help="num decoder layers" - ) - parser.add_argument( - "--decoder-layerdrop", - type=float, - metavar="D", - help="decoder layerdrop chance", - ) - parser.add_argument( - "--decoder-attention-heads", - type=int, - metavar="N", - help="num decoder attention heads", - ) - parser.add_argument( - "--decoder-learned-pos", - action="store_true", - help="use learned positional embeddings in the decoder", - ) - parser.add_argument( - "--decoder-normalize-before", - action="store_true", - help="apply layernorm before each decoder block", - ) - parser.add_argument( - "--no-token-positional-embeddings", - default=False, - action="store_true", - help="if set, disables positional embeddings (outside self attention)", - ) - - parser.add_argument( - "--decoder-dropout", - type=float, - metavar="D", - help="dropout probability in the decoder", - ) - parser.add_argument( - "--decoder-attention-dropout", - type=float, - metavar="D", - help="dropout probability for attention weights inside the decoder", - ) - parser.add_argument( - "--decoder-activation-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN inside the decoder", - ) - # fmt: on +@register_model("wav2vec_seq2seq", dataclass=Wav2Vec2Seq2SeqConfig) +class Wav2Vec2Seq2SeqModel(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder): + super().__init__(encoder, decoder) @classmethod - def build_model(cls, args, task): + def build_model(cls, cfg: Wav2Vec2Seq2SeqConfig, task: FairseqTask): """Build a new model instance.""" - # make sure all arguments are present in older models - base_architecture(args) - - if not hasattr(args, "max_source_positions"): - args.max_source_positions = 2048 - if not hasattr(args, "max_target_positions"): - args.max_target_positions = 2048 - src_dict, tgt_dict = task.source_dictionary, task.target_dictionary def build_embedding(dictionary, embed_dim): @@ -279,19 +229,20 @@ def build_embedding(dictionary, embed_dim): emb = Embedding(num_embeddings, embed_dim, padding_idx) return emb - decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim) + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + + encoder = cls.build_encoder(cfg) + decoder = cls.build_decoder(cfg, tgt_dict, decoder_embed_tokens) - encoder = cls.build_encoder(args) - decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens) - return TransformerModel(args, encoder, decoder) + return Wav2Vec2Seq2SeqModel(encoder, decoder) @classmethod - def build_encoder(cls, args): - return Wav2VecEncoder(args) + def build_encoder(cls, cfg: Wav2Vec2AsrConfig): + return Wav2VecEncoder(cfg) @classmethod - def build_decoder(cls, args, tgt_dict, embed_tokens): - return TransformerDecoder(args, tgt_dict, embed_tokens) + def build_decoder(cls, cfg: Wav2Vec2Seq2SeqConfig, tgt_dict, embed_tokens): + return TransformerDecoder(cfg, tgt_dict, embed_tokens) def forward(self, **kwargs): encoder_out = self.encoder(tbc=False, **kwargs) @@ -304,52 +255,50 @@ def upgrade_state_dict_named(self, state_dict, name): class Wav2VecEncoder(FairseqEncoder): - def __init__(self, args, tgt_dict=None): - self.apply_mask = args.apply_mask + def __init__(self, cfg: Wav2Vec2AsrConfig, tgt_dict=None): + self.apply_mask = cfg.apply_mask arg_overrides = { - "dropout": args.dropout, - "activation_dropout": args.activation_dropout, - "dropout_input": args.dropout_input, - "attention_dropout": args.attention_dropout, - "mask_length": args.mask_length, - "mask_prob": args.mask_prob, - "mask_selection": args.mask_selection, - "mask_other": args.mask_other, - "no_mask_overlap": args.no_mask_overlap, - "mask_channel_length": args.mask_channel_length, - "mask_channel_prob": args.mask_channel_prob, - "mask_channel_selection": args.mask_channel_selection, - "mask_channel_other": args.mask_channel_other, - "no_mask_channel_overlap": args.no_mask_channel_overlap, - "encoder_layerdrop": args.layerdrop, - "feature_grad_mult": args.feature_grad_mult, + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, } - if getattr(args, "w2v_args", None) is None: - state = checkpoint_utils.load_checkpoint_to_cpu( - args.w2v_path, arg_overrides - ) + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides) w2v_args = state.get("cfg", None) if w2v_args is None: w2v_args = convert_namespace_to_omegaconf(state["args"]) - args.w2v_args = w2v_args + cfg.w2v_args = w2v_args else: state = None - w2v_args = args.w2v_args + w2v_args = cfg.w2v_args if isinstance(w2v_args, Namespace): - args.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args) - assert ( - args.normalize == w2v_args.task.normalize - ), "Fine-tuning works best when data normalization is the same. " \ - "Please check that --normalize is set or unset for both" + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for both pre-training and here" + ) - w2v_args.task.data = args.data + w2v_args.task.data = cfg.data task = tasks.setup_task(w2v_args.task) model = task.build_model(w2v_args.model) - if state is not None and not args.no_pretrained_weights: + if state is not None and not cfg.no_pretrained_weights: model.load_state_dict(state["model"], strict=True) model.remove_pretraining_modules() @@ -360,14 +309,14 @@ def __init__(self, args, tgt_dict=None): self.w2v_model = model - self.final_dropout = nn.Dropout(args.final_dropout) - self.freeze_finetune_updates = args.freeze_finetune_updates + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates self.num_updates = 0 if tgt_dict is not None: self.proj = Linear(d, len(tgt_dict)) - elif getattr(args, "decoder_embed_dim", d) != d: - self.proj = Linear(d, args.decoder_embed_dim) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) else: self.proj = None @@ -436,21 +385,26 @@ class TransformerDecoder(FairseqIncrementalDecoder): (default: False). """ - def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): + def __init__( + self, + cfg: Wav2Vec2Seq2SeqConfig, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): super().__init__(dictionary) - self.dropout = args.decoder_dropout - self.share_input_output_embed = args.share_decoder_input_output_embed + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed input_embed_dim = embed_tokens.embedding_dim - embed_dim = args.decoder_embed_dim - self.output_embed_dim = args.decoder_embed_dim - args.encoder_embed_dim = embed_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim - self.layerdrop = args.decoder_layerdrop + self.layerdrop = cfg.decoder_layerdrop padding_idx = embed_tokens.padding_idx - self.max_target_positions = args.max_target_positions + self.max_target_positions = cfg.max_target_positions self.embed_tokens = embed_tokens self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim @@ -463,25 +417,31 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): self.embed_positions = ( PositionalEmbedding( - args.max_target_positions, + cfg.max_target_positions, embed_dim, padding_idx, - learned=args.decoder_learned_pos, + learned=cfg.decoder_learned_pos, ) - if not args.no_token_positional_embeddings + if not cfg.no_token_positional_embeddings else None ) - args = copy.deepcopy(args) - args.dropout = args.decoder_dropout - args.attention_dropout = args.decoder_attention_dropout - args.activation_dropout = args.decoder_activation_dropout + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) self.layers = nn.ModuleList([]) self.layers.extend( [ - TransformerDecoderLayer(args, no_encoder_attn) - for _ in range(args.decoder_layers) + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) ] ) @@ -491,9 +451,7 @@ def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False): ) nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) - if args.decoder_normalize_before and not getattr( - args, "no_decoder_final_norm", False - ): + if transformer_cfg.decoder_normalize_before: self.layer_norm = LayerNorm(embed_dim) else: self.layer_norm = None @@ -633,51 +591,3 @@ def Linear(in_features, out_features, bias=True): if bias: nn.init.constant_(m.bias, 0.0) return m - - -@register_model_architecture("wav2vec_ctc", "wav2vec_ctc") -def base_architecture(args): - args.no_pretrained_weights = getattr(args, "no_pretrained_weights", False) - args.dropout_input = getattr(args, "dropout_input", 0) - args.final_dropout = getattr(args, "final_dropout", 0) - args.apply_mask = getattr(args, "apply_mask", False) - args.dropout = getattr(args, "dropout", 0) - args.attention_dropout = getattr(args, "attention_dropout", 0) - args.activation_dropout = getattr(args, "activation_dropout", 0) - - args.mask_length = getattr(args, "mask_length", 10) - args.mask_prob = getattr(args, "mask_prob", 0.5) - args.mask_selection = getattr(args, "mask_selection", "static") - args.mask_other = getattr(args, "mask_other", 0) - args.no_mask_overlap = getattr(args, "no_mask_overlap", False) - args.mask_channel_length = getattr(args, "mask_channel_length", 10) - args.mask_channel_prob = getattr(args, "mask_channel_prob", 0.5) - args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") - args.mask_channel_other = getattr(args, "mask_channel_other", 0) - args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) - - args.freeze_finetune_updates = getattr(args, "freeze_finetune_updates", 0) - args.feature_grad_mult = getattr(args, "feature_grad_mult", 0) - args.layerdrop = getattr(args, "layerdrop", 0.0) - - -@register_model_architecture("wav2vec_seq2seq", "wav2vec_seq2seq") -def seq2seq_architecture(args): - args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1024) - args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 4096) - args.decoder_layers = getattr(args, "decoder_layers", 10) - args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0) - args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16) - args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False) - args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False) - args.no_token_positional_embeddings = getattr( - args, "no_token_positional_embeddings", False - ) - args.decoder_dropout = getattr(args, "decoder_dropout", 0) - args.decoder_attention_dropout = getattr(args, "decoder_attention_dropout", 0) - args.decoder_activation_dropout = getattr(args, "decoder_activation_dropout", 0) - args.share_decoder_input_output_embed = getattr( - args, "share_decoder_input_output_embed", False - ) - - base_architecture(args) diff --git a/fairseq/models/wav2vec/wav2vec2_scribblelens.py b/fairseq/models/wav2vec/wav2vec2_scribblelens.py index eb256879e3..62d3374b73 100644 --- a/fairseq/models/wav2vec/wav2vec2_scribblelens.py +++ b/fairseq/models/wav2vec/wav2vec2_scribblelens.py @@ -3,19 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import logging import math -import numpy as np +from dataclasses import dataclass, field +from typing import List, Optional, Tuple, Dict, Any +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from typing import List, Tuple - from fairseq import utils from fairseq.data.data_utils import compute_mask_indices -from fairseq.models import BaseFairseqModel, register_model, register_model_architecture +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model, probed_model from fairseq.modules import ( Fp32GroupNorm, Fp32LayerNorm, @@ -29,339 +28,74 @@ from fairseq.modules.transformer_sentence_encoder import init_bert_params from fairseq.utils import buffered_arange -@register_model("wav2vec2_scribblelens") -class Wav2Vec2ModelSL(BaseFairseqModel): - @staticmethod - def add_args(parser): - """Add model-specific arguments to the parser.""" - - parser.add_argument( - "--extractor-mode", - choices=["default", "layer_norm"], - help="mode for feature extractor. default has a single group norm with d groups in the first conv block, whereas layer_norm has layer norms in every block (meant to use with --normalize)", - ) - - parser.add_argument( - "--encoder-layers", - type=int, - metavar="L", - help="num encoder layers in the transformer", - ) - parser.add_argument( - "--encoder-embed-dim", - type=int, - metavar="H", - help="encoder embedding dimension", - ) - parser.add_argument( - "--encoder-ffn-embed-dim", - type=int, - metavar="F", - help="encoder embedding dimension for FFN", - ) - parser.add_argument( - "--encoder-attention-heads", - type=int, - metavar="A", - help="num encoder attention heads", - ) - parser.add_argument( - "--activation-fn", - choices=utils.get_available_activation_fns(), - help="activation function to use", - ) - - parser.add_argument( - "--dropout", - type=float, - metavar="D", - help="dropout probability for the transformer", - ) - - parser.add_argument( - "--attention-dropout", - type=float, - metavar="D", - help="dropout probability for attention weights", - ) - - parser.add_argument( - "--activation-dropout", - type=float, - metavar="D", - help="dropout probability after activation in FFN", - ) - - parser.add_argument( - "--final-dim", - type=int, - metavar="D", - help="project final representations and targets to this many dimensions", - ) - - parser.add_argument( - "--layer-norm-first", - action="store_true", - help="apply layernorm first in the transformer", - ) - - parser.add_argument( - "--encoder-layerdrop", - type=float, - help="probability of dropping a tarnsformer layer", - ) - - parser.add_argument( - "--conv-feature-layers", - type=str, - metavar="EXPR", - help="convolutional feature extraction layers [(dim, kernel_size, stride), ...]", - ) - - parser.add_argument( - "--logit-temp", type=float, help="temperature to divide logits by" - ) - - parser.add_argument( - "--quantize-targets", action="store_true", help="use quantized targets" - ) - - parser.add_argument( - "--quantize-input", action="store_true", help="use quantized inputs" - ) - - parser.add_argument( - "--same-quantizer", - action="store_true", - help="use same quantizer for inputs and targets", - ) - - parser.add_argument( - "--feature-grad-mult", - type=float, - help="multiply feature extractor var grads by this", - ) - - parser.add_argument( - "--latent-vars", - type=int, - metavar="N", - help="number of latent variables V in each group of the codebook", - ) - - parser.add_argument( - "--latent-groups", - type=int, - metavar="N", - help="number of groups G of latent variables in the codebook", - ) - - parser.add_argument( - "--latent-dim", - type=int, - metavar="N", - help="if set, uses this dimensionality for latent variables. otherwise uses final_dim / latent_groups", - ) - - parser.add_argument("--mask-length", type=int, help="mask length") - - parser.add_argument( - "--mask-prob", type=float, help="probability of replacing a token with mask" - ) - - parser.add_argument( - "--mask-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", - ) - - parser.add_argument( - "--mask-other", - type=float, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", - ) - - parser.add_argument( - "--no-mask-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-min-space", - type=int, - help="min space between spans (if no overlap is enabled)", - ) - - parser.add_argument( - "--mask-channel-length", - type=int, - help="repeat the mask indices multiple times", - ) - - parser.add_argument( - "--mask-channel-prob", - type=float, - help="probability of replacing a token with mask", - ) - - parser.add_argument( - "--mask-channel-selection", - type=str, - choices=["static", "uniform", "normal", "poisson"], - help="how to choose masks", - ) - - parser.add_argument( - "--mask-channel-other", - type=float, - help="secondary mask argument (used for more complex distributions), see help in compute_mask_indices", - ) - - parser.add_argument( - "--no-mask-channel-overlap", - action="store_true", - help="whether to allow masks to overlap", - ) - - parser.add_argument( - "--mask-channel-min-space", - type=int, - help="min space between spans (if no overlap is enabled)", - ) - - parser.add_argument( - "--dropout-input", - type=float, - metavar="D", - help="dropout to apply to the input (after feat extr)", - ) - - parser.add_argument( - "--dropout-features", - type=float, - metavar="D", - help="dropout to apply to the features (after feat extr)", - ) - - parser.add_argument( - "--num-negatives", type=int, metavar="N", help="number of negative examples" - ) - - parser.add_argument( - "--negatives-from-everywhere", - action="store_true", - help="sample negatives from everywhere, not just masked states", - ) - - parser.add_argument( - "--cross-sample-negatives", - type=int, - metavar="N", - help="num of cross sampled negatives", - ) - - parser.add_argument( - "--codebook-negatives", - type=int, - metavar="N", - help="num of codebook sampled negatives", - ) - - parser.add_argument( - "--conv-pos", - type=int, - metavar="N", - help="number of filters for convolutional positional embeddings", - ) - - parser.add_argument( - "--conv-pos-groups", - type=int, - metavar="N", - help="number of groups for convolutional positional embedding", - ) - - parser.add_argument( - "--latent-temp", - type=str, - metavar="D", - help="temperature for latent variable sampling. can be tuple of 3 values (start, end, decay)", - ) - - parser.add_argument( - "--target-glu", action="store_true", help="adds projection + glu to targets" - ) +from .wav2vec2 import Wav2Vec2Config - parser.add_argument( - "--conv-bias", action="store_true", help="include bias in conv encoder" - ) - parser.add_argument( - "--compute-alignment-metrics", - action="store_true", - help="compute mutual info and rand scores", - ) +@dataclass +class Wav2Vec2SLConfig(Wav2Vec2Config): + probe_defs: Optional[Dict[str, Any]] = field(default=None, metadata={"help": "probes"}) + compute_alignment_metrics: bool = field(default=False, metadata={"help": "compute mutual info and rand scores"}) - def __init__(self, args): +@register_model("wav2vec2_scribblelens", dataclass=Wav2Vec2SLConfig) +class Wav2Vec2ModelSL(BaseFairseqModel, probed_model.ProbedModel): + def __init__(self, cfg: Wav2Vec2Config): super().__init__() - self.args = args + self.cfg = cfg - feature_enc_layers = eval(args.conv_feature_layers) + feature_enc_layers = eval(cfg.conv_feature_layers) self.embed = feature_enc_layers[-1][0] self.feature_extractor = ConvFeatureExtractionModel( conv_layers=feature_enc_layers, dropout=0.0, - mode=args.extractor_mode, - conv_bias=args.conv_bias, + mode=cfg.extractor_mode, + conv_bias=cfg.conv_bias, ) self.post_extract_proj = ( - nn.Linear(self.embed, args.encoder_embed_dim) - if self.embed != args.encoder_embed_dim and not args.quantize_input + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim and not cfg.quantize_input else None ) - self.mask_prob = args.mask_prob - self.mask_selection = args.mask_selection - self.mask_other = args.mask_other - self.mask_length = args.mask_length - self.no_mask_overlap = args.no_mask_overlap - self.mask_min_space = args.mask_min_space + self.mask_prob = cfg.mask_prob + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length = cfg.mask_length + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space - self.mask_channel_prob = args.mask_channel_prob - self.mask_channel_selection = args.mask_channel_selection - self.mask_channel_other = args.mask_channel_other - self.mask_channel_length = args.mask_channel_length - self.no_mask_channel_overlap = args.no_mask_channel_overlap - self.mask_channel_min_space = args.mask_channel_min_space + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space - self.dropout_input = nn.Dropout(args.dropout_input) - self.dropout_features = nn.Dropout(args.dropout_features) + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) - self.feature_grad_mult = args.feature_grad_mult + self.feature_grad_mult = cfg.feature_grad_mult self.quantizer = None self.input_quantizer = None - self.n_negatives = args.num_negatives - self.cross_sample_negatives = args.cross_sample_negatives - self.codebook_negatives = args.codebook_negatives - self.negatives_from_everywhere = args.negatives_from_everywhere + self.n_negatives = cfg.num_negatives + self.cross_sample_negatives = cfg.cross_sample_negatives + self.codebook_negatives = cfg.codebook_negatives + self.negatives_from_everywhere = cfg.negatives_from_everywhere - self.logit_temp = args.logit_temp + self.logit_temp = cfg.logit_temp - final_dim = args.final_dim if args.final_dim > 0 else args.encoder_embed_dim + final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim - if args.quantize_targets: - vq_dim = args.latent_dim if args.latent_dim > 0 else final_dim + if cfg.quantize_targets: + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else final_dim self.quantizer = GumbelVectorQuantizer( dim=self.embed, - num_vars=args.latent_vars, - temp=eval(args.latent_temp), - groups=args.latent_groups, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, @@ -370,39 +104,39 @@ def __init__(self, args): else: self.project_q = nn.Linear(self.embed, final_dim) - if args.quantize_input: - if args.same_quantizer and self.quantizer is not None: + if cfg.quantize_input: + if cfg.same_quantizer and self.quantizer is not None: vq_dim = final_dim self.input_quantizer = self.quantizer else: - vq_dim = ( - args.latent_dim if args.latent_dim > 0 else args.encoder_embed_dim - ) + vq_dim = cfg.latent_dim if cfg.latent_dim > 0 else cfg.encoder_embed_dim self.input_quantizer = GumbelVectorQuantizer( dim=self.embed, - num_vars=args.latent_vars, - temp=eval(args.latent_temp), - groups=args.latent_groups, + num_vars=cfg.latent_vars, + temp=cfg.latent_temp, + groups=cfg.latent_groups, combine_groups=False, vq_dim=vq_dim, time_first=True, ) - self.project_inp = nn.Linear(vq_dim, args.encoder_embed_dim) + self.project_inp = nn.Linear(vq_dim, cfg.encoder_embed_dim) self.mask_emb = nn.Parameter( - torch.FloatTensor(args.encoder_embed_dim).uniform_() + torch.FloatTensor(cfg.encoder_embed_dim).uniform_() ) - self.encoder = TransformerEncoder(args) + self.encoder = TransformerEncoder(cfg) self.layer_norm = LayerNorm(self.embed) self.target_glu = None - if args.target_glu: + if cfg.target_glu: self.target_glu = nn.Sequential( nn.Linear(final_dim, final_dim * 2), nn.GLU() ) - self.final_proj = nn.Linear(args.encoder_embed_dim, final_dim) + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + self.attach_probes(cfg.probe_defs) def upgrade_state_dict_named(self, state_dict, name): super().upgrade_state_dict_named(state_dict, name) @@ -410,13 +144,10 @@ def upgrade_state_dict_named(self, state_dict, name): return state_dict @classmethod - def build_model(cls, args, task=None): + def build_model(cls, cfg: Wav2Vec2SLConfig, task=None): """Build a new model instance.""" - # make sure all arguments are present - base_architecture(args) - - return cls(args) + return cls(cfg) def apply_mask(self, x, padding_mask): B, T, C = x.shape @@ -532,8 +263,6 @@ def compute_preds(self, x, y, negatives): return logits def forward(self, source, padding_mask=None, mask=True, features_only=False, alignments=None): - # padding_mask = None # JCh: padding_mask prob need to be True where the data is padded. mask=True => data invalid - if self.feature_grad_mult > 0: features = self.feature_extractor(source) if self.feature_grad_mult != 1.0: @@ -542,9 +271,8 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False, ali with torch.no_grad(): features = self.feature_extractor(source) - compute_alignment_metrics = self.args.compute_alignment_metrics and alignments is not None + compute_alignment_metrics = self.cfg.compute_alignment_metrics and alignments is not None - # features = torch.squeeze(features) # TODO check if this makes sense; also seems length reduction is too big for this input features_pen = features.float().pow(2).mean() features = features.transpose(1, 2) @@ -1007,73 +735,3 @@ def forward( x = self.final_layer_norm(x) return x, attn - - -@register_model_architecture("wav2vec2_scribblelens", "wav2vec2_scribblelens") -def base_architecture(args): - args.extractor_mode = getattr(args, "extractor_mode", "default") - - args.encoder_layers = getattr(args, "encoder_layers", 12) - args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768) - args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072) - args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12) - - args.activation_fn = getattr(args, "activation_fn", "gelu") - - args.dropout = getattr(args, "dropout", 0.1) - args.attention_dropout = getattr(args, "attention_dropout", 0.1) - args.activation_dropout = getattr(args, "activation_dropout", 0.0) - - args.final_dim = getattr(args, "final_dim", 0) - - args.layer_norm_first = getattr(args, "layer_norm_first", False) - args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0) - - conv_feature_layers = "[(512, 10, 5)]" - conv_feature_layers += " + [(512, 8, 4)]" - conv_feature_layers += " + [(512, 4, 2)] * 3" - conv_feature_layers += " + [(512, 1, 1)]" - args.conv_feature_layers = getattr(args, "conv_feature_layers", conv_feature_layers) - - args.logit_temp = getattr(args, "logit_temp", 0.1) - - args.quantize_targets = getattr(args, "quantize_targets", False) - args.quantize_input = getattr(args, "quantize_input", False) - args.same_quantizer = getattr(args, "same_quantizer", False) - - args.feature_grad_mult = getattr(args, "feature_grad_mult", 1.0) - - args.latent_vars = getattr(args, "latent_vars", 320) - args.latent_groups = getattr(args, "latent_groups", 2) - args.latent_dim = getattr(args, "latent_dim", 0) - - args.mask_length = getattr(args, "mask_length", 10) - args.mask_prob = getattr(args, "mask_prob", 0.65) - args.mask_selection = getattr(args, "mask_selection", "static") - args.mask_other = getattr(args, "mask_other", 0) - args.no_mask_overlap = getattr(args, "no_mask_overlap", False) - args.mask_min_space = getattr(args, "mask_min_space", 1) - - args.mask_channel_length = getattr(args, "mask_channel_length", 10) - args.mask_channel_prob = getattr(args, "mask_channel_prob", 0) - args.mask_channel_selection = getattr(args, "mask_channel_selection", "static") - args.mask_channel_other = getattr(args, "mask_channel_other", 0) - args.no_mask_channel_overlap = getattr(args, "no_mask_channel_overlap", False) - args.mask_channel_min_space = getattr(args, "mask_channel_min_space", 1) - - args.dropout_input = getattr(args, "dropout_input", 0) - args.dropout_features = getattr(args, "dropout_features", 0) - - args.num_negatives = getattr(args, "num_negatives", 100) - args.negatives_from_everywhere = getattr(args, "negatives_from_everywhere", False) - args.cross_sample_negatives = getattr(args, "cross_sample_negatives", 0) - args.codebook_negatives = getattr(args, "codebook_negatives", 0) - - args.conv_pos = getattr(args, "conv_pos", 128) - args.conv_pos_groups = getattr(args, "conv_pos_groups", 16) - - args.latent_temp = getattr(args, "latent_temp", "(2,0.5,0.999995)") - - args.target_glu = getattr(args, "target_glu", False) - - args.conv_bias = getattr(args, "conv_bias", False) diff --git a/fairseq/modules/checkpoint_activations.py b/fairseq/modules/checkpoint_activations.py index 1f99c24ca1..e0e5679c5a 100644 --- a/fairseq/modules/checkpoint_activations.py +++ b/fairseq/modules/checkpoint_activations.py @@ -6,6 +6,7 @@ from typing import Any, Dict, List, Tuple, Union import torch +import torch.utils.checkpoint as checkpoint from fairseq import utils @@ -133,7 +134,7 @@ class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(ctx, run_function, parent_ctx_dict, kwarg_keys, *args): if torch.is_grad_enabled(): # grad may be disabled, e.g., during validation - torch.utils.checkpoint.check_backward_validity(args) + checkpoint.check_backward_validity(args) ctx.run_function = run_function ctx.kwarg_keys = kwarg_keys @@ -165,7 +166,7 @@ def backward(ctx, *args): ) tensor_inputs = ctx.saved_tensors - tensor_inputs = torch.utils.checkpoint.detach_variable(tensor_inputs) + tensor_inputs = checkpoint.detach_variable(tensor_inputs) inputs = unpack_non_tensors(tensor_inputs, ctx.packed_non_tensor_inputs) # Store the current states. diff --git a/fairseq/modules/cross_entropy.py b/fairseq/modules/cross_entropy.py index 0d2beb44bb..6f33c24cb5 100644 --- a/fairseq/modules/cross_entropy.py +++ b/fairseq/modules/cross_entropy.py @@ -26,12 +26,14 @@ def _cross_entropy_pytorch(logits, target, ignore_index=None, reduction="mean"): import xentropy_cuda from apex.contrib import xentropy - logger.info("using fused cross entropy") - def cross_entropy(logits, target, ignore_index=-100, reduction="mean"): if logits.device == torch.device("cpu"): return _cross_entropy_pytorch(logits, target, ignore_index, reduction) else: + if not getattr(cross_entropy, "_has_logged_once", False): + logger.info("using fused cross entropy") + cross_entropy._has_logged_once = True + half_to_float = logits.dtype == torch.half losses = xentropy.SoftmaxCrossEntropyLoss.apply( logits, diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 47657bb0ab..7113438888 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -73,7 +73,10 @@ def block(input_dim, output_dim): nn.init.normal_(self.weight_proj.weight, mean=0, std=1) nn.init.zeros_(self.weight_proj.bias) - assert len(temp) == 3, temp + if isinstance(temp, str): + import ast + temp = ast.literal_eval(temp) + assert len(temp) == 3, f"{temp}, {len(temp)}" self.max_temp, self.min_temp, self.temp_decay = temp self.curr_temp = self.max_temp diff --git a/fairseq/modules/linearized_convolution.py b/fairseq/modules/linearized_convolution.py index 09a8f201c0..b36cea91fa 100644 --- a/fairseq/modules/linearized_convolution.py +++ b/fairseq/modules/linearized_convolution.py @@ -38,6 +38,7 @@ def upgrade_state_dict_named(self, state_dict, name): if prefix + "_linearized_weight" in state_dict: del state_dict[prefix + "_linearized_weight"] + @torch.jit.ignore def forward(self, input, incremental_state=None): """ Args: diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 99f95deb5f..6ab86245d2 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -87,14 +87,10 @@ def __init__( self.reset_parameters() self.onnx_trace = False - self.tpu = False def prepare_for_onnx_export_(self): self.onnx_trace = True - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def reset_parameters(self): if self.qkv_same_dim: # Empirically observed the convergence to be much better with @@ -148,13 +144,15 @@ def forward( if need_head_weights: need_weights = True + is_tpu = query.device.type == "xla" + tgt_len, bsz, embed_dim = query.size() assert embed_dim == self.embed_dim assert list(query.size()) == [tgt_len, bsz, embed_dim] if ( not self.onnx_trace - and not self.tpu # don't use PyTorch version on TPUs + and not is_tpu # don't use PyTorch version on TPUs and incremental_state is None and not static_kv # A workaround for quantization to work. Otherwise JIT compilation @@ -337,7 +335,7 @@ def forward( if key_padding_mask is not None: # don't attend to padding symbols attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - if not self.tpu: + if not is_tpu: attn_weights = attn_weights.masked_fill( key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf"), diff --git a/fairseq/modules/same_pad.py b/fairseq/modules/same_pad.py index b46f94d635..4c04990ea6 100644 --- a/fairseq/modules/same_pad.py +++ b/fairseq/modules/same_pad.py @@ -8,11 +8,14 @@ class SamePad(nn.Module): - def __init__(self, kernel_size): + def __init__(self, kernel_size, causal=False): super().__init__() - self.remove = kernel_size % 2 == 0 + if causal: + self.remove = kernel_size - 1 + else: + self.remove = 1 if kernel_size % 2 == 0 else 0 def forward(self, x): - if self.remove: - x = x[:, :, :-1] + if self.remove > 0: + x = x[:, :, : -self.remove] return x diff --git a/fairseq/modules/transformer_sentence_encoder.py b/fairseq/modules/transformer_sentence_encoder.py index 208488f562..7a5dcbdde3 100644 --- a/fairseq/modules/transformer_sentence_encoder.py +++ b/fairseq/modules/transformer_sentence_encoder.py @@ -113,7 +113,6 @@ def __init__( self.apply_bert_init = apply_bert_init self.learned_pos_embedding = learned_pos_embedding self.traceable = traceable - self.tpu = False # whether we're on TPU self.embed_tokens = self.build_embedding( self.vocab_size, self.embedding_dim, self.padding_idx @@ -220,9 +219,6 @@ def build_transformer_sentence_encoder_layer( qn_block_size=qn_block_size, ) - def prepare_for_tpu_(self, **kwargs): - self.tpu = True - def forward( self, tokens: torch.Tensor, @@ -231,10 +227,11 @@ def forward( positions: Optional[torch.Tensor] = None, token_embeddings: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + is_tpu = tokens.device.type == "xla" # compute padding mask. This is needed for multi-head attention padding_mask = tokens.eq(self.padding_idx) - if not self.traceable and not self.tpu and not padding_mask.any(): + if not self.traceable and not is_tpu and not padding_mask.any(): padding_mask = None if token_embeddings is not None: diff --git a/fairseq/optim/adagrad.py b/fairseq/optim/adagrad.py index a79b6c39da..4f539541c1 100644 --- a/fairseq/optim/adagrad.py +++ b/fairseq/optim/adagrad.py @@ -37,4 +37,4 @@ def optimizer_config(self): @property def supports_flat_params(self): - return True + return False diff --git a/fairseq/optim/adam.py b/fairseq/optim/adam.py index 9b8ddffd7e..1a4f213707 100644 --- a/fairseq/optim/adam.py +++ b/fairseq/optim/adam.py @@ -5,7 +5,7 @@ import logging import math -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List @@ -95,7 +95,7 @@ def average_params(self): class Adam(torch.optim.Optimizer): - """Implements Adam algorithm. + r"""Implements Adam algorithm. This implementation is modified from torch.optim.Adam based on: `Fixed Weight Decay Regularization in Adam` diff --git a/fairseq/optim/composite.py b/fairseq/optim/composite.py new file mode 100644 index 0000000000..51e6999368 --- /dev/null +++ b/fairseq/optim/composite.py @@ -0,0 +1,183 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Dict, Any, List, Optional + +import torch.optim +from fairseq.dataclass import FairseqDataclass +from fairseq.optim import FairseqOptimizer, register_optimizer, _build_optimizer +from fairseq.optim.lr_scheduler import FairseqLRScheduler, build_lr_scheduler +from omegaconf import II, open_dict + + +logger = logging.getLogger(__name__) + + +@dataclass +class OptimizerAndSchedulerConfig(FairseqDataclass): + optimizer: Any = None + lr_scheduler: Optional[Any] = None + lr: List[float] = II("optimization.lr") + + +@dataclass +class CompositeOptimizerConfig(FairseqDataclass): + groups: Dict[str, OptimizerAndSchedulerConfig] = field( + default_factory=lambda: {}, + metadata={ + "help": "optimizer name -> optimizer OptimizerAndSchedulerConfig. " + "Configures a different optimizer and (optionally) lr scheduler for each parameter group" + }, + ) + + +@register_optimizer("composite", dataclass=CompositeOptimizerConfig) +class FairseqCompositeOptimizer(FairseqOptimizer): + + optimizers: Dict[str, FairseqOptimizer] = {} + lr_schedulers: Dict[str, FairseqLRScheduler] = {} + lr_scheduler: FairseqLRScheduler = None + _optimizer: torch.optim.Optimizer + + def __init__(self, cfg: CompositeOptimizerConfig, params): + super().__init__(cfg) + + assert ( + len(params) > 1 + ), "Composite optimizer only works when there are multiple parameter groups (try fp16_no_flatten_grads: true)" + + groupped_params = defaultdict(list) + for p in params: + group = getattr(p, "param_group", "default") + groupped_params[group].append(p) + + assert groupped_params.keys() == cfg.groups.keys(), ( + f"Parameter groups {groupped_params.keys()} and optimizer groups {cfg.groups.keys()} are not the same! " + "Try setting 'param_group' on your parameters in the model." + ) + + for group, group_params in groupped_params.items(): + group_cfg = cfg.groups[group] + with open_dict(group_cfg): + group_cfg.optimizer.lr = group_cfg.lr + group_cfg.lr_scheduler.lr = group_cfg.lr + self.optimizers[group] = _build_optimizer(group_cfg.optimizer, group_params) + if group_cfg.lr_scheduler is not None: + self.lr_schedulers[group] = build_lr_scheduler( + group_cfg.lr_scheduler, self.optimizers[group] + ) + + if len(self.lr_schedulers) > 0: + assert len(self.lr_schedulers) == len(self.optimizers), ( + f"Please provide an lr scheduler for each optimizer to use pass_through scheduler. " + f"Optimizers: {self.optimizers}; Lr scheds: {self.lr_schedulers}" + ) + self.lr_scheduler = CompositeLRScheduler(self.lr_schedulers) + + self._optimizer = CompositeOptimizer(self.optimizers) + + @property + def supports_groups(self): + return True + + @property + def param_groups(self): + for opt in self.optimizers.values(): + for group in opt.param_groups: + yield group + + def get_lr(self): + """Return the current learning rate.""" + k = ( + "default" + if "default" in self.optimizers + else next(iter(self.optimizers.keys())) + ) + return self.optimizers[k].param_groups[0]["lr"] + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.optimizers.items()} + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + if k not in self.optimizers: + # skip extra keys like "loss_scale" added by fp16 optimizer + continue + + overrides = ( + optimizer_overrides[k] + if isinstance(optimizer_overrides, dict) and k in optimizer_overrides + else None + ) + self.optimizers[k].load_state_dict(state, optimizer_overrides=overrides) + + +class CompositeOptimizer(torch.optim.Optimizer): + def __init__(self, optimizers: Dict[str, FairseqOptimizer]): + self.optimizers = optimizers + + @property + def supports_memory_efficient_fp16(self): + return all(o.supports_memory_efficient_fp16 for o in self.optimizers.values()) + + @property + def supports_flat_params(self): + return all(o.supports_flat_params for o in self.optimizers.values()) + + def step(self, closure=None, groups=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for k, opt in self.optimizers.items(): + if groups is None or k in groups: + opt.step() + + return loss + + def zero_grad(self): + for opt in self.optimizers.values(): + opt.zero_grad() + + +class CompositeLRScheduler(FairseqLRScheduler): + def __init__(self, lr_schedulers): + super().__init__(None, None) + + self.lr_schedulers = lr_schedulers + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {k: s.state_dict() for k, s in self.lr_schedulers.items()} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + for k, state in state_dict.items(): + self.lr_schedulers[k].load_state_dict(state) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step_begin_epoch(epoch) + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + for s in self.lr_schedulers.values(): + s.step(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return {k: s.step_update(num_updates) for k, s in self.lr_schedulers.items()} diff --git a/fairseq/optim/fairseq_optimizer.py b/fairseq/optim/fairseq_optimizer.py index f9864533b6..a1c1d219a0 100644 --- a/fairseq/optim/fairseq_optimizer.py +++ b/fairseq/optim/fairseq_optimizer.py @@ -109,14 +109,20 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): """Clips gradient norm.""" return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) - def step(self, closure=None, scale=1.0): + def step(self, closure=None, scale=1.0, groups=None): """Performs a single optimization step.""" if self.supports_step_with_scale: - self.optimizer.step(closure, scale=scale) + if self.supports_groups: + self.optimizer.step(closure, scale=scale, groups=groups) + else: + self.optimizer.step(closure, scale=scale) else: if scale != 1.0: self.multiply_grads(1.0 / scale) - self.optimizer.step(closure) + if self.supports_groups: + self.optimizer.step(closure, groups=groups) + else: + self.optimizer.step(closure) def zero_grad(self): """Clears the gradients of all optimized parameters.""" @@ -136,6 +142,12 @@ def supports_step_with_scale(self): return self.optimizer.supports_step_with_scale return False + @property + def supports_groups(self): + if hasattr(self.optimizer, "supports_groups"): + return self.optimizer.supports_groups + return False + @property def supports_flat_params(self): """ diff --git a/fairseq/optim/fp16_optimizer.py b/fairseq/optim/fp16_optimizer.py index 4457023527..a0da4948c8 100644 --- a/fairseq/optim/fp16_optimizer.py +++ b/fairseq/optim/fp16_optimizer.py @@ -65,6 +65,8 @@ def build_fp32_params(cls, args, params, flatten=True): for p in params: p32 = torch.nn.Parameter(p.data.float()) p32.grad = torch.zeros_like(p32.data) + if hasattr(p, "param_group"): + p32.param_group = p.param_group fp32_params.append(p32) return fp32_params @@ -198,15 +200,15 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): return grad_norm - def step(self, closure=None): + def step(self, closure=None, groups=None): """Performs a single optimization step.""" self._sync_fp16_grads_to_fp32() if getattr(self, "supports_step_with_scale", False): - self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + self.fp32_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) else: self._unscale_grads() - self.fp32_optimizer.step(closure) + self.fp32_optimizer.step(closure, groups=groups) if self.scaler is not None: self.scaler.update() @@ -303,6 +305,10 @@ def optimizer(self): def optimizer(self, optimizer): self.fp32_optimizer.optimizer = optimizer + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + @property def optimizer_config(self): return self.fp32_optimizer.optimizer_config @@ -416,14 +422,14 @@ def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): return grad_norm - def step(self, closure=None): + def step(self, closure=None, groups=None): """Performs a single optimization step.""" if getattr(self, "supports_step_with_scale", False): # NOTE(msb) optimizer divides by scale factor - self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor)) + self.wrapped_optimizer.step(closure, scale=(1.0 / self._multiply_factor), groups=groups) else: self._unscale_grads() - self.wrapped_optimizer.step(closure) + self.wrapped_optimizer.step(closure, groups=groups) if self.scaler is not None: self.scaler.update() @@ -514,6 +520,10 @@ def optimizer(self, optimizer): def optimizer_config(self): return self.wrapped_optimizer.optimizer_config + @property + def lr_scheduler(self): + return getattr(self.wrapped_optimizer, "lr_scheduler", None) + def get_lr(self): return self.wrapped_optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py index 646ac66be9..38b57fe54c 100644 --- a/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/cosine_lr_scheduler.py @@ -8,14 +8,14 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass import FairseqDataclass -from omegaconf import II, DictConfig +from omegaconf import II -from . import FairseqLRScheduler, register_lr_scheduler +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass -class CosineConfig(FairseqDataclass): +class CosineLRScheduleConfig(FairseqDataclass): warmup_updates: int = field( default=0, metadata={"help": "warmup the learning rate linearly for the first N updates"}, @@ -23,12 +23,14 @@ class CosineConfig(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) - max_lr: float = field( - default=1.0, metadata={"help": "max learning rate, must be more than args.lr"} + lr: List[float] = field( + default=II("optimization.lr"), + metadata={"help": "max learning rate, must be more than cfg.min_lr"}, ) + min_lr: float = field(default=0.0, metadata={"help": "min learning rate"}) t_mult: float = field( default=1.0, metadata={"help": "factor to grow the length of each period"} ) @@ -38,38 +40,35 @@ class CosineConfig(FairseqDataclass): lr_shrink: float = field( default=0.1, metadata={"help": "shrink factor for annealing"} ) - # TODO common var for parent class - lr: List[float] = II("optimization.lr") + # This is not required, but is for convenience in inferring lr_period_updates max_update: int = II("optimization.max_update") -@register_lr_scheduler("cosine", dataclass=CosineConfig) -class CosineSchedule(FairseqLRScheduler): +@register_lr_scheduler("cosine", dataclass=CosineLRScheduleConfig) +class CosineLRSchedule(FairseqLRScheduler): """Assign LR based on a cyclical schedule that follows the cosine function. See https://arxiv.org/pdf/1608.03983.pdf for details. We also support a warmup phase where we linearly increase the learning rate from some initial learning rate (``--warmup-init-lr``) until the configured - max learning rate (``--max-lr``). + max learning rate (``--lr``). During warmup:: - lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) lr = lrs[update_num] After warmup:: - lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i)) + lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(t_curr / t_i)) where ``t_curr`` is current percentage of updates within the current period range and ``t_i`` is the current period range, which is scaled by ``t_mul`` after every iteration. """ - def __init__( - self, cfg: DictConfig, fairseq_optimizer - ): + def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer): super().__init__(cfg, fairseq_optimizer) if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( @@ -77,33 +76,27 @@ def __init__( f" Consider --lr-scheduler=fixed instead. ({cfg.lr})" ) - warmup_end_lr = cfg.max_lr - lr = ( - cfg.lr[0] - if isinstance(cfg.lr, Collection) - else cfg.lr - ) - if cfg.warmup_init_lr < 0: - cfg.warmup_init_lr = lr + self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr + assert ( + self.max_lr > cfg.min_lr + ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})" - self.min_lr = lr - self.max_lr = cfg.max_lr - assert self.max_lr > self.min_lr, "max_lr must be more than lr" + warmup_end_lr = self.max_lr + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = cfg.min_lr self.t_mult = cfg.t_mult self.period = cfg.lr_period_updates if self.period <= 0: assert ( - cfg.max_update >= 0 + cfg.max_update > 0 ), "Either --max_update or --lr-period-updates must be set" self.period = cfg.max_update - cfg.warmup_updates if cfg.warmup_updates > 0: - # linearly warmup for the first args.warmup_updates - self.lr_step = ( - warmup_end_lr - cfg.warmup_init_lr - ) / cfg.warmup_updates + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates else: self.lr_step = 1 @@ -143,7 +136,7 @@ def step_update(self, num_updates): t_curr = curr_updates - (self.period * i) lr_shrink = self.lr_shrink ** i - min_lr = self.min_lr * lr_shrink + min_lr = self.cfg.min_lr * lr_shrink max_lr = self.max_lr * lr_shrink self.lr = min_lr + 0.5 * (max_lr - min_lr) * ( diff --git a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py index d0ac115829..6c12fa56b8 100644 --- a/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/fairseq_lr_scheduler.py @@ -6,14 +6,13 @@ from argparse import Namespace from fairseq.dataclass.utils import gen_parser_from_dataclass - -from .. import FairseqOptimizer +from fairseq.optim import FairseqOptimizer class FairseqLRScheduler(object): def __init__(self, cfg, optimizer): super().__init__() - if not isinstance(optimizer, FairseqOptimizer): + if optimizer is not None and not isinstance(optimizer, FairseqOptimizer): raise ValueError("optimizer must be an instance of FairseqOptimizer") self.cfg = cfg self.optimizer = optimizer diff --git a/fairseq/optim/lr_scheduler/fixed_schedule.py b/fairseq/optim/lr_scheduler/fixed_schedule.py index e91ba86f8c..d0e7e14b7e 100644 --- a/fairseq/optim/lr_scheduler/fixed_schedule.py +++ b/fairseq/optim/lr_scheduler/fixed_schedule.py @@ -3,37 +3,44 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import LegacyFairseqLRScheduler, register_lr_scheduler +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("fixed") -class FixedSchedule(LegacyFairseqLRScheduler): - """Decay the LR on a fixed schedule.""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) +@dataclass +class FixedLRScheduleConfig(FairseqDataclass): + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + lr_shrink: float = field( + default=0.1, + metadata={"help": "shrink factor for annealing, lr_new = (lr * lr_shrink)"}, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + lr: List[float] = II("optimization.lr") + - # set defaults - args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 +@register_lr_scheduler("fixed", dataclass=FixedLRScheduleConfig) +class FixedLRSchedule(FairseqLRScheduler): + """Decay the LR on a fixed schedule.""" - self.lr = args.lr[0] - if args.warmup_updates > 0: - self.warmup_factor = 1.0 / args.warmup_updates + def __init__(self, cfg: FixedLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates else: self.warmup_factor = 1 - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', - help='force annealing at specified epoch (epochs start at 1)') - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing, lr_new = (lr * lr_shrink)') - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - # fmt: on - def state_dict(self): return {"lr": self.lr} @@ -42,14 +49,14 @@ def load_state_dict(self, state_dict): self.lr = state_dict["lr"] def get_next_lr(self, epoch): - lrs = self.args.lr - if self.args.force_anneal is None or epoch < self.args.force_anneal: + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: # use fixed LR schedule next_lr = lrs[min(epoch - 1, len(lrs) - 1)] else: # annneal based on lr_shrink - next_lr = lrs[-1] * self.args.lr_shrink ** ( - epoch + 1 - self.args.force_anneal + next_lr = lrs[-1] * self.cfg.lr_shrink ** ( + epoch + 1 - self.cfg.force_anneal ) return next_lr @@ -61,8 +68,8 @@ def step_begin_epoch(self, epoch): def step_update(self, num_updates): """Update the learning rate after each update.""" - if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: - self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) + if self.cfg.warmup_updates > 0 and num_updates < self.cfg.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.cfg.warmup_updates) self.optimizer.set_lr(self.warmup_factor * self.lr) else: self.optimizer.set_lr(self.lr) diff --git a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py index c42e090677..d9321577bb 100644 --- a/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py +++ b/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py @@ -7,14 +7,14 @@ from dataclasses import dataclass, field from typing import List -from fairseq.dataclass import FairseqDataclass -from omegaconf import II, DictConfig +from omegaconf import II -from . import FairseqLRScheduler, register_lr_scheduler +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler @dataclass -class InverseSquareRootScheduleConfig(FairseqDataclass): +class InverseSquareRootLRScheduleConfig(FairseqDataclass): warmup_updates: int = field( default=4000, metadata={"help": "warmup the learning rate linearly for the first N updates"}, @@ -22,14 +22,13 @@ class InverseSquareRootScheduleConfig(FairseqDataclass): warmup_init_lr: float = field( default=-1, metadata={ - "help": "initial learning rate during warmup phase; default is args.lr" + "help": "initial learning rate during warmup phase; default is cfg.lr" }, ) - # TODO common vars at parent class lr: List[float] = II("optimization.lr") -@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootScheduleConfig) +@register_lr_scheduler("inverse_sqrt", dataclass=InverseSquareRootLRScheduleConfig) class InverseSquareRootSchedule(FairseqLRScheduler): """Decay the LR based on the inverse square root of the update number. @@ -40,36 +39,28 @@ class InverseSquareRootSchedule(FairseqLRScheduler): During warmup:: - lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lrs = torch.linspace(cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates) lr = lrs[update_num] After warmup:: - decay_factor = args.lr * sqrt(args.warmup_updates) + decay_factor = cfg.lr * sqrt(cfg.warmup_updates) lr = decay_factor / sqrt(update_num) """ - def __init__(self, cfg: DictConfig, optimizer): + def __init__(self, cfg: InverseSquareRootLRScheduleConfig, optimizer): super().__init__(cfg, optimizer) if isinstance(cfg.lr, Collection) and len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with inverse_sqrt." " Consider --lr-scheduler=fixed instead." ) - warmup_end_lr = ( - cfg.lr[0] - if isinstance(cfg.lr, Collection) - else cfg.lr - ) + warmup_end_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr if cfg.warmup_init_lr < 0: - cfg.warmup_init_lr = ( - 0 if cfg.warmup_updates > 0 else warmup_end_lr - ) + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr - # linearly warmup for the first args.warmup_updates - self.lr_step = ( - warmup_end_lr - cfg.warmup_init_lr - ) / cfg.warmup_updates + # linearly warmup for the first cfg.warmup_updates + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates # then, decay prop. to the inverse square root of the update number self.decay_factor = warmup_end_lr * cfg.warmup_updates ** 0.5 diff --git a/fairseq/optim/lr_scheduler/manual_lr_scheduler.py b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py new file mode 100644 index 0000000000..0269a1e285 --- /dev/null +++ b/fairseq/optim/lr_scheduler/manual_lr_scheduler.py @@ -0,0 +1,110 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import LegacyFairseqLRScheduler, register_lr_scheduler +import logging +import ast + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +@register_lr_scheduler("manual") +class ManualSchedule(LegacyFairseqLRScheduler): + """Decay the LR on a manual schedule.""" + + def __init__(self, args, optimizer): + super().__init__(args, optimizer) + + self.epoch2lr = self.parse_manuallr_args(args.epoch2lr) + self.update2lr = self.parse_manuallr_args(args.update2lr) + logger.info("@@@ ManualSchedule epoch2lr={}".format(self.epoch2lr)) + logger.info("@@@ ManualSchedule update2lr={}".format(self.update2lr)) + + if 1 in self.epoch2lr: + self.lr = self.epoch2lr[1] + elif 1 in self.update2lr: + self.lr = self.update2lr[1] + else: + self.lr = args.lr[0] + self.optimizer.set_lr(self.lr) # Set the beginning of the epoch. + + def parse_manuallr_args(self, lr_args_str): + lr_dict = ast.literal_eval(lr_args_str.replace(' ', '')) + if not isinstance(lr_dict, dict): + raise ValueError("epoch2lr/update2lr must be abel to evaluated to a dict") + + lr_args = {} + logger.info("@@@ after parsing input dictionary lr_dict = {}".format(lr_dict)) + for key, val in lr_dict.items(): + if "," in key: + for k in key.split(","): + lr_args[int(k)] = float(val) + elif "-" in key: + s = int(key.split("-")[0]) + e = int(key.split("-")[1]) + for k in range(s, e + 1, 1): + lr_args[k] = float(val) + else: + lr_args[int(key)] = float(val) + + return lr_args + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + "--epoch2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each epoch manually", + ) + parser.add_argument( + "--update2lr", + type=str, + metavar="DICT", + default="{}", + help="a dictionary used to set lr for each update manually", + ) + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + manual_keys = [k for k in self.epoch2lr if k <= epoch] + if manual_keys: + manual_lr = self.epoch2lr[max(manual_keys)] + else: + logger.warning("@@@ epoch={} does not exist in manual lr input. epoch2lr={}...".format( + epoch, list(self.epoch2lr.items())[:min(10, len(self.epoch2lr.keys())-1)] + )) + manual_lr = self.optimizer.get_lr() + return manual_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + manual_keys = [k for k in self.update2lr if k <= num_updates] + if manual_keys: + manual_lr = self.update2lr[max(manual_keys)] + else: + logger.warning("epoch={} does not exist in manual lr input update2lr={}...".format( + num_updates, list(self.update2lr.items())[:min(10, len(self.update2lr.keys())-1)])) + manual_lr = self.optimizer.get_lr() + + self.optimizer.set_lr(manual_lr) + return self.optimizer.get_lr() diff --git a/fairseq/optim/lr_scheduler/pass_through.py b/fairseq/optim/lr_scheduler/pass_through.py new file mode 100644 index 0000000000..2f93db328c --- /dev/null +++ b/fairseq/optim/lr_scheduler/pass_through.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class PassThroughScheduleConfig(FairseqDataclass): + pass + + +@register_lr_scheduler("pass_through", dataclass=PassThroughScheduleConfig) +class PassThroughScheduleSchedule(FairseqLRScheduler): + """Delegate lr scheduling to the optimizer.""" + + def __init__(self, cfg: PassThroughScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + assert ( + hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None + ), "Pass-through schedule can only be used with optimizers with their own schedulers" + + def state_dict(self): + return self.optimizer.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.lr_scheduler.load_state_dict(state_dict) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + return self.optimizer.lr_scheduler.step_begin_epoch(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.lr_scheduler.step_update(num_updates) diff --git a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py index 63adc740a9..b8109a7c1e 100644 --- a/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py +++ b/fairseq/optim/lr_scheduler/polynomial_decay_schedule.py @@ -3,53 +3,61 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from . import LegacyFairseqLRScheduler, register_lr_scheduler +from dataclasses import dataclass, field +from typing import Optional, List +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("polynomial_decay") -class PolynomialDecaySchedule(LegacyFairseqLRScheduler): + +@dataclass +class PolynomialDecayLRScheduleConfig(FairseqDataclass): + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + force_anneal: Optional[int] = field( + default=None, + metadata={"help": "force annealing at specified epoch"}, + ) + end_learning_rate: float = field( + default=0.0, + metadata={"help": "learning rate to decay to"}, + ) + power: float = field( + default=1.0, + metadata={"help": "decay exponent"}, + ) + total_num_update: float = field( + default=II("optimization.max_update"), + metadata={"help": "total number of updates over which to decay learning rate"}, + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("polynomial_decay", dataclass=PolynomialDecayLRScheduleConfig) +class PolynomialDecayLRSchedule(FairseqLRScheduler): """Decay the LR on a fixed schedule.""" - def __init__(self, args, optimizer): - super().__init__(args, optimizer) + def __init__(self, cfg: PolynomialDecayLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) - # set defaults - args.warmup_updates = getattr(args, "warmup_updates", 0) or 0 + assert cfg.total_num_update > 0 - self.lr = args.lr[0] - if args.warmup_updates > 0: - self.warmup_factor = 1.0 / args.warmup_updates + self.lr = cfg.lr[0] + if cfg.warmup_updates > 0: + self.warmup_factor = 1.0 / cfg.warmup_updates else: self.warmup_factor = 1 - self.end_learning_rate = args.end_learning_rate - self.total_num_update = args.total_num_update - self.power = args.power + self.end_learning_rate = cfg.end_learning_rate + self.total_num_update = cfg.total_num_update + self.power = cfg.power self.optimizer.set_lr(self.warmup_factor * self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - parser.add_argument( - "--force-anneal", - "--fa", - type=int, - metavar="N", - help="force annealing at specified epoch", - ) - parser.add_argument( - "--warmup-updates", - default=0, - type=int, - metavar="N", - help="warmup the learning rate linearly for the first N updates", - ) - parser.add_argument("--end-learning-rate", default=0.0, type=float) - parser.add_argument("--power", default=1.0, type=float) - parser.add_argument("--total-num-update", default=1000000, type=int) - def get_next_lr(self, epoch): - lrs = self.args.lr - if self.args.force_anneal is None or epoch < self.args.force_anneal: + lrs = self.cfg.lr + if self.cfg.force_anneal is None or epoch < self.cfg.force_anneal: # use fixed LR schedule next_lr = lrs[min(epoch, len(lrs) - 1)] else: @@ -65,13 +73,13 @@ def step_begin_epoch(self, epoch): def step_update(self, num_updates): """Update the learning rate after each update.""" - if self.args.warmup_updates > 0 and num_updates <= self.args.warmup_updates: - self.warmup_factor = num_updates / float(self.args.warmup_updates) + if self.cfg.warmup_updates > 0 and num_updates <= self.cfg.warmup_updates: + self.warmup_factor = num_updates / float(self.cfg.warmup_updates) lr = self.warmup_factor * self.lr elif num_updates >= self.total_num_update: lr = self.end_learning_rate else: - warmup = self.args.warmup_updates + warmup = self.cfg.warmup_updates lr_range = self.lr - self.end_learning_rate pct_remaining = 1 - (num_updates - warmup) / ( self.total_num_update - warmup diff --git a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py index 82bb36efe9..6e29ba79b6 100644 --- a/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py +++ b/fairseq/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -3,13 +3,59 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import dataclass, field +from typing import List + import torch.optim.lr_scheduler +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + -from . import LegacyFairseqLRScheduler, register_lr_scheduler +@dataclass +class ReduceLROnPlateauLRScheduleConfig(FairseqDataclass): + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + lr_threshold: float = field( + default=1e-4, + metadata={ + "help": ( + "threshold for measuring the new optimum, to only focus on " + "significant changes" + ) + }, + ) + lr_patience: int = field( + default=0, + metadata={ + "help": ( + "number of epochs with no improvement after which learning rate will " + "be reduced" + ) + }, + ) + warmup_updates: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + warmup_init_lr: float = field( + default=-1, + metadata={ + "help": "initial learning rate during warmup phase; default is cfg.lr" + }, + ) + lr: List[float] = II("optimization.lr") + maximize_best_checkpoint_metric: bool = II( + "checkpoint.maximize_best_checkpoint_metric" + ) -@register_lr_scheduler("reduce_lr_on_plateau") -class ReduceLROnPlateau(LegacyFairseqLRScheduler): +@register_lr_scheduler( + "reduce_lr_on_plateau", dataclass=ReduceLROnPlateauLRScheduleConfig +) +class ReduceLROnPlateauLRSchedule(FairseqLRScheduler): """ Decay the LR by a factor every time the validation loss plateaus. Also comes with optional warmup phase, where we linearly increase @@ -21,61 +67,43 @@ class ReduceLROnPlateau(LegacyFairseqLRScheduler): During warmup:: lrs = torch.linspace( - args.warmup_init_lr, args.lr, args.warmup_updates + cfg.warmup_init_lr, cfg.lr, cfg.warmup_updates ) lr = lrs[update_num] """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: ReduceLROnPlateauLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." " Consider --lr-scheduler=fixed instead." ) self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( self.optimizer.optimizer, - patience=args.lr_patience, - factor=args.lr_shrink, - mode="max" if args.maximize_best_checkpoint_metric else "min", - threshold=args.lr_threshold, + patience=cfg.lr_patience, + factor=cfg.lr_shrink, + mode="max" if cfg.maximize_best_checkpoint_metric else "min", + threshold=cfg.lr_threshold, ) - warmup_end_lr = args.lr[0] - # if no warm up, sets initial lr to be args.lr[0] - if args.warmup_init_lr < 0: - args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + warmup_end_lr = cfg.lr[0] + # if no warm up, sets initial lr to be cfg.lr[0] + if cfg.warmup_init_lr < 0: + cfg.warmup_init_lr = 0 if cfg.warmup_updates > 0 else warmup_end_lr - # linearly warmup for the first args.warmup_updates - if args.warmup_updates > 0: - self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + # linearly warmup for the first cfg.warmup_updates + if cfg.warmup_updates > 0: + self.lr_step = (warmup_end_lr - cfg.warmup_init_lr) / cfg.warmup_updates # this flag is either set from arg when no warm up, or set by # step_update() when warmup finishes - self.warmup_end = True if args.warmup_updates <= 0 else False + self.warmup_end = True if cfg.warmup_updates <= 0 else False # initial learning rate # this self.lr is used only during init and/or warm up period - self.lr = args.warmup_init_lr + self.lr = cfg.warmup_init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing, lr_new = (lr * lr_shrink)') - parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', - help='threshold for measuring the new optimum, ' - 'to only focus on significant changes') - parser.add_argument('--lr-patience', default=0, type=int, - help='number of epochs with no improvement after which ' - 'learning rate will be reduced') - parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', - help='warmup the learning rate linearly for the first N updates') - parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', - help='initial learning rate during warmup phase; default is args.lr') - # fmt: on - def state_dict(self): """Return the LR scheduler state dict.""" return { @@ -104,9 +132,9 @@ def step_update(self, num_updates): """ Update the learning rate after each update.""" # if there is warmup - if self.args.warmup_updates > 0: - if num_updates <= self.args.warmup_updates: - self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + if self.cfg.warmup_updates > 0: + if num_updates <= self.cfg.warmup_updates: + self.lr = self.cfg.warmup_init_lr + num_updates * self.lr_step self.optimizer.set_lr(self.lr) else: if self.warmup_end is False: diff --git a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py index c573237f11..4d5547c39b 100644 --- a/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -4,12 +4,51 @@ # LICENSE file in the root directory of this source tree. import math - -from . import LegacyFairseqLRScheduler, register_lr_scheduler - - -@register_lr_scheduler("tri_stage") -class TriStageLRSchedule(LegacyFairseqLRScheduler): +from dataclasses import dataclass, field +from typing import Optional, List, Tuple +from omegaconf import II + +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler + + +@dataclass +class TriStageLRScheduleConfig(FairseqDataclass): + warmup_steps: int = field( + default=0, + metadata={"help": "warmup the learning rate linearly for the first N updates"}, + ) + hold_steps: int = field( + default=0, + metadata={"help": "steps in hold stage"}, + ) + decay_steps: int = field( + default=0, + metadata={"help": "steps in decay stages"}, + ) + phase_ratio: Optional[Tuple[float, float, float]] = field( + default=None, + metadata={ + "help": ( + "if set, automatically sets warmup/hold/decay steps to the ratio " + "specified here from max_updates. the ratios must add up to 1.0" + ) + }, + ) + init_lr_scale: float = field( + default=0.01, + metadata={"help": "initial learning rate scale during warmup phase"}, + ) + final_lr_scale: float = field( + default=0.01, + metadata={"help": "final learning rate scale"}, + ) + max_update: float = II("optimization.max_update") + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("tri_stage", dataclass=TriStageLRScheduleConfig) +class TriStageLRSchedule(FairseqLRScheduler): """Tristage learning rate schedulr Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf @@ -29,92 +68,63 @@ class TriStageLRSchedule(LegacyFairseqLRScheduler): During warmup:: - init_lr = args.init_lr_scale * args.lr - lrs = torch.linspace(init_lr, args.lr, args.warmup_steps) + init_lr = cfg.init_lr_scale * cfg.lr + lrs = torch.linspace(init_lr, cfg.lr, cfg.warmup_steps) lr = lrs[update_num] During hold:: - lr = args.lr + lr = cfg.lr During decay:: - decay_factor = - math.log(args.final_lr_scale) / args.decay_steps - lr = args.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) + decay_factor = - math.log(cfg.final_lr_scale) / cfg.decay_steps + lr = cfg.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) After that:: - lr = args.lr * args.final_lr_scale + lr = cfg.lr * cfg.final_lr_scale """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: TriStageLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with tri-stage lr." " Consider --lr-scheduler=fixed instead." ) # calculate LR at each point - self.peak_lr = args.lr[0] - self.init_lr = args.init_lr_scale * args.lr[0] - self.final_lr = args.final_lr_scale * args.lr[0] + self.peak_lr = cfg.lr[0] + self.init_lr = cfg.init_lr_scale * cfg.lr[0] + self.final_lr = cfg.final_lr_scale * cfg.lr[0] + + if cfg.phase_ratio is not None: + assert cfg.max_update > 0 + assert sum(cfg.phase_ratio) == 1, "phase ratios must add up to 1" + self.warmup_steps = int(cfg.max_update * cfg.phase_ratio[0]) + self.hold_steps = int(cfg.max_update * cfg.phase_ratio[1]) + self.decay_steps = int(cfg.max_update * cfg.phase_ratio[2]) + else: + self.warmup_steps = cfg.warmup_steps + self.hold_steps = cfg.hold_steps + self.decay_steps = cfg.decay_steps - # remember the steps at each stage - self.warmup_steps = args.warmup_steps - self.hold_steps = args.hold_steps - self.decay_steps = args.decay_steps + assert ( + self.warmup_steps + self.hold_steps + self.decay_steps > 0 + ), "please specify steps or phase_ratio" self.warmup_rate = ( (self.peak_lr - self.init_lr) / self.warmup_steps if self.warmup_steps != 0 else 0 ) - self.decay_factor = -math.log(args.final_lr_scale) / args.decay_steps + self.decay_factor = -math.log(cfg.final_lr_scale) / self.decay_steps # initial learning rate self.lr = self.init_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument( - '--warmup-steps', - default=4000, - type=int, - metavar='N', - help='warmup the learning rate linearly for the first N updates' - ) - parser.add_argument( - '--hold-steps', - default=20000, - type=int, - metavar='N', - help='steps in hold stage.' - ) - parser.add_argument( - '--decay-steps', - default=60000, - type=int, - metavar='N', - help='steps in decay stages' - ) - parser.add_argument( - '--init-lr-scale', - default=0.01, - type=float, - help=""" - initial learning rate scale during warmup phase; default is 0.01""") - parser.add_argument( - '--final-lr-scale', - default=0.01, - type=float, - help="final learning rate scale; default to 0.01" - ) - # fmt: on - def _decide_stage(self, update_step): """ return stage, and the corresponding steps within the current stage diff --git a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py index 0f3193f2b8..bfe2a0d381 100644 --- a/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py +++ b/fairseq/optim/lr_scheduler/triangular_lr_scheduler.py @@ -4,52 +4,61 @@ # LICENSE file in the root directory of this source tree. import math +from dataclasses import dataclass, field +from typing import List -from . import LegacyFairseqLRScheduler, register_lr_scheduler +from omegaconf import II +from fairseq.dataclass import FairseqDataclass +from fairseq.optim.lr_scheduler import FairseqLRScheduler, register_lr_scheduler -@register_lr_scheduler("triangular") -class TriangularSchedule(LegacyFairseqLRScheduler): + +@dataclass +class TriangularLRScheduleConfig(FairseqDataclass): + max_lr: float = field( + default="???", metadata={"help": "max learning rate, must be more than cfg.lr"} + ) + lr_period_updates: float = field( + default=5000, + metadata={"help": "initial number of updates per period (cycle length)"}, + ) + lr_shrink: float = field( + default=0.1, metadata={"help": "shrink factor for annealing"} + ) + shrink_min: bool = field( + default=False, metadata={"help": "if set, also shrinks min lr"} + ) + lr: List[float] = II("optimization.lr") + + +@register_lr_scheduler("triangular", dataclass=TriangularLRScheduleConfig) +class TriangularLRSchedule(FairseqLRScheduler): """Assign LR based on a triangular cyclical schedule. See https://arxiv.org/pdf/1506.01186.pdf for details. """ - def __init__(self, args, optimizer): - super().__init__(args, optimizer) - if len(args.lr) > 1: + def __init__(self, cfg: TriangularLRScheduleConfig, optimizer): + super().__init__(cfg, optimizer) + if len(cfg.lr) > 1: raise ValueError( "Cannot use a fixed learning rate schedule with triangular." " Consider --lr-scheduler=fixed instead." ) - lr = args.lr[0] + lr = cfg.lr[0] - assert args.max_lr > lr, "max_lr must be more than lr" + assert cfg.max_lr > lr, "max_lr must be more than lr" self.min_lr = lr - self.max_lr = args.max_lr - self.stepsize = args.lr_period_updates // 2 - self.lr_shrink = args.lr_shrink - self.shrink_min = args.shrink_min + self.max_lr = cfg.max_lr + self.stepsize = cfg.lr_period_updates // 2 + self.lr_shrink = cfg.lr_shrink + self.shrink_min = cfg.shrink_min # initial learning rate self.lr = self.min_lr self.optimizer.set_lr(self.lr) - @staticmethod - def add_args(parser): - """Add arguments to the parser for this LR scheduler.""" - # fmt: off - parser.add_argument('--max-lr', required=True, type=float, metavar='LR', - help='max learning rate, must be more than args.lr') - parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', - help='initial number of updates per period (cycle length)') - parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', - help='shrink factor for annealing') - parser.add_argument('--shrink-min', action='store_true', - help='if set, also shrinks min lr') - # fmt: on - def step(self, epoch, val_loss=None): """Update the learning rate at the end of the given epoch.""" super().step(epoch, val_loss) diff --git a/fairseq/optim/nag.py b/fairseq/optim/nag.py index 3982a8271d..4f652fe6d3 100644 --- a/fairseq/optim/nag.py +++ b/fairseq/optim/nag.py @@ -3,7 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from collections import Collection +from collections.abc import Collection from dataclasses import dataclass, field from typing import List @@ -75,7 +75,7 @@ def step(self, closure=None): momentum = group["momentum"] lr = group["lr"] lr_old = group.get("lr_old", lr) - lr_correct = lr / lr_old + lr_correct = lr / lr_old if lr_old > 0 else lr for p in group["params"]: if p.grad is None: diff --git a/fairseq/registry.py b/fairseq/registry.py index 7a3dd1d1bf..3fbaeac301 100644 --- a/fairseq/registry.py +++ b/fairseq/registry.py @@ -45,7 +45,7 @@ def build_x(cfg: Union[DictConfig, str, Namespace], *extra_args, **extra_kwargs) else: choice = getattr(cfg, registry_name, None) if choice in DATACLASS_REGISTRY: - cfg = populate_dataclass(cfg, DATACLASS_REGISTRY[choice]()) + cfg = populate_dataclass(DATACLASS_REGISTRY[choice](), cfg) if choice is None: if required: diff --git a/fairseq/sequence_generator.py b/fairseq/sequence_generator.py index 9c5423e2b1..bd46f9e5b9 100644 --- a/fairseq/sequence_generator.py +++ b/fairseq/sequence_generator.py @@ -11,7 +11,6 @@ from fairseq import search, utils from fairseq.data import data_utils from fairseq.models import FairseqIncrementalDecoder -from fairseq.models.fairseq_encoder import EncoderOut from torch import Tensor @@ -279,8 +278,8 @@ def _generate( cand_size = 2 * beam_size # 2 x beam size in case half are EOS # offset arrays for converting between different indexing schemes - bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens) - cand_offsets = torch.arange(0, cand_size).type_as(tokens) + bbsz_offsets = (torch.arange(0, bsz) * beam_size).unsqueeze(1).type_as(tokens).to(src_tokens.device) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device) reorder_state: Optional[Tensor] = None batch_idxs: Optional[Tensor] = None @@ -420,7 +419,7 @@ def _generate( break if self.search.stop_on_max_len and step >= max_len: break - assert step < max_len + assert step < max_len, f"{step} < {max_len}" # Remove finalized sentences (ones for which {beam_size} # finished hypotheses have been generated) from the batch. @@ -806,13 +805,13 @@ def forward_encoder(self, net_input: Dict[str, Tensor]): def forward_decoder( self, tokens, - encoder_outs: List[EncoderOut], + encoder_outs: List[Dict[str, List[Tensor]]], incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], temperature: float = 1.0, ): log_probs = [] avg_attn: Optional[Tensor] = None - encoder_out: Optional[EncoderOut] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None for i, model in enumerate(self.models): if self.has_encoder(): encoder_out = encoder_outs[i] @@ -868,7 +867,7 @@ def forward_decoder( return avg_probs, avg_attn @torch.jit.export - def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_order): + def reorder_encoder_out(self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order): """ Reorder encoder output according to *new_order*. @@ -879,7 +878,7 @@ def reorder_encoder_out(self, encoder_outs: Optional[List[EncoderOut]], new_orde Returns: *encoder_out* rearranged according to *new_order* """ - new_outs: List[EncoderOut] = [] + new_outs: List[Dict[str, List[Tensor]]] = [] if not self.has_encoder(): return new_outs for i, model in enumerate(self.models): @@ -904,7 +903,7 @@ def reorder_incremental_state( class SequenceGeneratorWithAlignment(SequenceGenerator): - def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): + def __init__(self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs): """Generates translations of a given source sentence. Produces alignments following "Jointly Learning to Align and @@ -918,6 +917,11 @@ def __init__(self, models, tgt_dict, left_pad_target=False, **kwargs): super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) self.left_pad_target = left_pad_target + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + @torch.no_grad() def generate(self, models, sample, **kwargs): finalized = super()._generate(sample, **kwargs) @@ -946,7 +950,7 @@ def generate(self, models, sample, **kwargs): # Process the attn matrix to extract hard alignments. for i in range(bsz * beam_size): - alignment = utils.extract_hard_alignment( + alignment = self.extract_alignment( attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos ) finalized[i // beam_size][i % beam_size]["alignment"] = alignment diff --git a/fairseq/tasks/__init__.py b/fairseq/tasks/__init__.py index 415f15e708..0e55d093b1 100644 --- a/fairseq/tasks/__init__.py +++ b/fairseq/tasks/__init__.py @@ -9,9 +9,8 @@ import os from fairseq.dataclass import FairseqDataclass -from fairseq.dataclass.utils import merge_with_parent +from fairseq.dataclass.utils import merge_with_parent, populate_dataclass from hydra.core.config_store import ConfigStore -from omegaconf import DictConfig from .fairseq_task import FairseqTask, LegacyFairseqTask # noqa @@ -22,13 +21,16 @@ TASK_CLASS_NAMES = set() -def setup_task(cfg: DictConfig, **kwargs): +def setup_task(cfg: FairseqDataclass, **kwargs): task = None task_name = getattr(cfg, "task", None) if isinstance(task_name, str): # legacy tasks task = TASK_REGISTRY[task_name] + if task_name in TASK_DATACLASS_REGISTRY: + dc = TASK_DATACLASS_REGISTRY[task_name] + cfg = populate_dataclass(dc(), cfg) else: task_name = getattr(cfg, "_name", None) diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index d1b6bf1c14..6ea40a813f 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -5,11 +5,11 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. -import editdistance import os import sys import torch +from argparse import Namespace from dataclasses import dataclass, field from typing import Optional, Any from omegaconf import MISSING @@ -71,7 +71,7 @@ class AudioPretrainingConfig(FairseqDataclass): metadata={"help": "beam search config for evaluating wer during training"}, ) eval_wer_tokenizer: Any = field( - default="space", + default=None, metadata={"help": "tokenizer config for evaluating wer during training"}, ) eval_wer_post_process: str = field( @@ -106,6 +106,7 @@ def __init__( self._source_dictionary = source_dictionary if cfg.eval_wer: assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" + self.blank_symbol = "" @classmethod def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): @@ -123,25 +124,28 @@ def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs): return cls(cfg, target_dictionary=target_dictionary) - def load_dataset(self, split, **kwargs): - """Load a given dataset split. + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg - Args: - split (str): name of the split (e.g., train, valid, test) - """ - manifest = os.path.join(self.cfg.data, "{}.tsv".format(split)) + # upgrade old task + if isinstance(task_cfg, Namespace): + if not hasattr(task_cfg, "autoregressive"): + task_cfg.autoregressive = not task_cfg.criterion == 'ctc' + + manifest = os.path.join(data_path, "{}.tsv".format(split)) self.datasets[split] = FileAudioDataset( manifest, - sample_rate=self.cfg.sample_rate, + sample_rate=task_cfg.sample_rate, max_sample_size=self.cfg.max_sample_size, min_sample_size=self.cfg.max_sample_size, min_length=self.cfg.min_sample_size, - pad=self.cfg.labels is not None or self.cfg.enable_padding, - normalize=self.cfg.normalize, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, ) - if self.cfg.labels: - label_path = os.path.join(self.cfg.data, f"{split}.{self.cfg.labels}") + if task_cfg.labels: + label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}") labels = [] with open(label_path, "r") as f: for line in f: @@ -156,7 +160,7 @@ def load_dataset(self, split, **kwargs): eos=self.target_dictionary.eos(), batch_targets=True, process_label=process_label, - add_to_input=self.cfg.autoregressive, + add_to_input=task_cfg.autoregressive, ) @property @@ -185,7 +189,6 @@ def filter_indices_by_size( def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = super().valid_step(sample, model, criterion) - if self.cfg.eval_wer and self.cfg.autoregressive: metrics = self._inference_with_wer(self.sequence_generator, sample, model) logging_output["_num_char_errors"] = metrics["num_char_errors"] @@ -204,15 +207,18 @@ def build_model(self, model_cfg: FairseqDataclass): ) if self.cfg.eval_wer_tokenizer: self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer) + else: + self.tokenizer = None return model def _inference_with_wer(self, generator, sample, model): - def decode(toks, escape_unk=True): + import editdistance + + def decode(toks): s = self.target_dictionary.string( toks.int().cpu(), self.cfg.eval_wer_post_process, - escape_unk=escape_unk, - extra_symbols_to_ignore={generator.eos}, + escape_unk=True, ) if self.tokenizer: s = self.tokenizer.decode(s) @@ -225,14 +231,11 @@ def decode(toks, escape_unk=True): hyp = decode(gen_out[i][0]["tokens"]) ref = decode( utils.strip_pad(sample["target"][i], self.target_dictionary.pad()), - escape_unk=True, ) - hyp = post_process(hyp, self.cfg.eval_wer_post_process).strip("_") - ref = post_process(ref, self.cfg.eval_wer_post_process).strip("_") num_char_errors += editdistance.eval(hyp, ref) num_chars += len(ref) - hyp_words = hyp.split("_") - ref_words = ref.split("_") + hyp_words = hyp.split() + ref_words = ref.split() num_word_errors += editdistance.eval(hyp_words, ref_words) num_words += len(ref_words) diff --git a/fairseq/tasks/fairseq_task.py b/fairseq/tasks/fairseq_task.py index c47f9c4200..24116bfd52 100644 --- a/fairseq/tasks/fairseq_task.py +++ b/fairseq/tasks/fairseq_task.py @@ -7,6 +7,7 @@ import os import warnings from argparse import Namespace +from typing import List import torch from fairseq import metrics, search, tokenizer, utils @@ -91,11 +92,20 @@ def setup_task(cls, cfg: DictConfig, **kwargs): def has_sharded_data(self, split): return os.pathsep in getattr(self.cfg, "data", "") - def load_dataset(self, split, combine=False, **kwargs): + def load_dataset( + self, + split: str, + combine: bool = False, + task_cfg: FairseqDataclass = None, + **kwargs + ): """Load a given dataset split. Args: split (str): name of the split (e.g., train, valid, test) + combine (bool): combines a split segmented into pieces into one dataset + task_cfg (FairseqDataclass): optional task configuration stored in the checkpoint that can be used + to load datasets """ raise NotImplementedError @@ -270,8 +280,6 @@ def build_model(self, cfg: FairseqDataclass): from fairseq import models, quantization_utils model = models.build_model(cfg, self) - if getattr(cfg, "tpu", False): - model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, cfg) return model @@ -366,12 +374,14 @@ def build_generator( else: search_strategy = search.BeamSearch(self.target_dictionary) + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} if seq_gen_cls is None: if getattr(args, "print_alignment", False): seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs['print_alignment'] = args.print_alignment else: seq_gen_cls = SequenceGenerator - extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + return seq_gen_cls( models, self.target_dictionary, @@ -428,6 +438,14 @@ def valid_step(self, sample, model, criterion): loss, sample_size, logging_output = criterion(model, sample) return loss, sample_size, logging_output + def optimizer_step(self, optimizer, model, update_num): + optimizer.step() + + def build_dataset_for_inference( + self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs + ) -> torch.utils.data.Dataset: + raise NotImplementedError + def inference_step( self, generator, models, sample, prefix_tokens=None, constraints=None ): @@ -547,8 +565,6 @@ def build_model(self, args: Namespace): from fairseq import models, quantization_utils model = models.build_model(args, self) - if getattr(args, "tpu", False): - model.prepare_for_tpu_() model = quantization_utils.quantize_model_scalar(model, args) return model diff --git a/fairseq/tasks/language_modeling.py b/fairseq/tasks/language_modeling.py index e0bf1f9b2b..4a44d967b3 100644 --- a/fairseq/tasks/language_modeling.py +++ b/fairseq/tasks/language_modeling.py @@ -15,6 +15,7 @@ AppendTokenDataset, Dictionary, IdDataset, + LMContextWindowDataset, MonolingualDataset, NestedDictionaryDataset, NumelDataset, @@ -312,6 +313,39 @@ def inference_step( models, sample, prefix_tokens=prefix_tokens, bos_token=bos_token ) + def eval_lm_dataloader( + self, + dataset, + max_tokens: Optional[int] = 36000, + batch_size: Optional[int] = None, + max_positions: Optional[int] = None, + num_shards: int = 1, + shard_id: int = 0, + num_workers: int = 1, + data_buffer_size: int = 10, + # ensures that every evaluated token has access to a context of at least + # this size, if possible + context_window: int = 0, + ): + if context_window > 0: + dataset = LMContextWindowDataset( + dataset=dataset, + tokens_per_sample=self.args.tokens_per_sample, + context_window=context_window, + pad_idx=self.source_dictionary.pad(), + ) + return self.get_batch_iterator( + dataset=dataset, + max_tokens=max_tokens, + max_sentences=batch_size, + max_positions=max_positions, + ignore_invalid_inputs=True, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + data_buffer_size=data_buffer_size, + ).next_epoch_itr(shuffle=False) + @property def source_dictionary(self): """Return the :class:`~fairseq.data.Dictionary` for the language diff --git a/fairseq/tasks/masked_lm.py b/fairseq/tasks/masked_lm.py index 56086f5e81..70208bc4d5 100644 --- a/fairseq/tasks/masked_lm.py +++ b/fairseq/tasks/masked_lm.py @@ -88,6 +88,15 @@ def add_args(parser): action="store_true", help="mask whole words; you may also want to set --bpe", ) + parser.add_argument( + "--mask-multiple-length", + default=1, + type=int, + help="repeat the mask indices multiple times", + ) + parser.add_argument( + "--mask-stdev", default=0.0, type=float, help="stdev of the mask length" + ) parser.add_argument( "--shorten-method", default="none", @@ -180,6 +189,8 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): random_token_prob=self.args.random_token_prob, freq_weighted_replacement=self.args.freq_weighted_replacement, mask_whole_words=mask_whole_words, + mask_multiple_length=self.args.mask_multiple_length, + mask_stdev=self.args.mask_stdev, ) with data_utils.numpy_seed(self.args.seed + epoch): diff --git a/fairseq/tasks/scribblelens.py b/fairseq/tasks/scribblelens.py index 983b859b2b..8a9803f3e6 100644 --- a/fairseq/tasks/scribblelens.py +++ b/fairseq/tasks/scribblelens.py @@ -7,9 +7,23 @@ import os import sys +import torch -from fairseq.data import FileHandwritingDataset, Dictionary, AddTargetDataset, HandwritingDictionary -from . import LegacyFairseqTask, register_task +from argparse import Namespace +from dataclasses import dataclass, field +from typing import Optional, Any +from omegaconf import MISSING + +from fairseq.data import (AddTargetDataset, Dictionary, FileAudioDataset, + FileHandwritingDataset, HandwritingDictionary, + encoders) +from fairseq.data.data_utils import post_process +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.configs import GenerationConfig + +from . import FairseqTask, LegacyFairseqTask, register_task +from .. import utils +from ..logging import metrics class LabelEncoder(object): @@ -22,136 +36,88 @@ def __call__(self, label): ) -@register_task("scribblelens") -class ScribblelensTask(LegacyFairseqTask): - """ - - """ - - @staticmethod - def add_args(parser): - """Add task-specific arguments to the parser.""" - parser.add_argument("data", help="path to data directory") - parser.add_argument( - "--normalize", - action="store_true", - help="if set, normalizes input to have 0 mean and unit variance", - ) - parser.add_argument( - "--max-sample-size", - default=None, - type=int, - help="max sample size to crop to for batching. default = min sample length", - ) - parser.add_argument( - "--min-sample-size", - default=None, - type=int, - help="min sample size to crop to for batching. default = same as --max-sample-size", - ) - - parser.add_argument( - "--pad-to-multiples-of", - default=None, - type=int, - help="enforce that lengths of inputs are multiples of this", - ) - - parser.add_argument( - "--enable-padding", - action="store_true", - help="pad shorter samples instead of cropping", # actually needed to be set to true - ) - - parser.add_argument( - "--labels", - # type=bool, - # default=None, - type=str, - #action="store_true", - help="if to return also labels from dataset" #"extension of the label file to load, if any", - ) - - def __init__(self, args, source_dictionary=None): - super().__init__(args) - self._target_dictionary = None +@dataclass +class ScribblelensConfig(FairseqDataclass): + data: str = field(default=MISSING, metadata={"help": "path to data directory"}) + labels: bool = field( + default=False, + metadata={"help": "if to return also labels from dataset"} + ) + vocab_path: Optional[str] = field( + default=None, + metadata={"help": "path to data directory"} + ) + normalize: bool = field( + default=False, + metadata={"help": "if set, normalizes input to have 0 mean and unit variance"}, + ) + enable_padding: bool = field( + default=False, metadata={"help": "pad shorter samples instead of cropping"} + ) + pad_to_multiples_of: Optional[int] = field( + default=None, + metadata={"help": "enforce that lengths of inputs are multiples of this"} + ) + max_sample_size: Optional[int] = field( + default=None, metadata={"help": "max sample size to crop to for batching"} + ) + min_sample_size: Optional[int] = field( + default=None, metadata={"help": "min sample size to crop to for batching"} + ) + + + +@register_task("scribblelens", dataclass=ScribblelensConfig) +class ScribblelensTask(FairseqTask): + """""" + + cfg: ScribblelensConfig + + def __init__( + self, + cfg: ScribblelensConfig, + source_dictionary=None, + target_dictionary=None, + ): + super().__init__(cfg) + self._target_dictionary = target_dictionary self._source_dictionary = source_dictionary - self.is_ctc = args.criterion == "ctc" + # if cfg.eval_wer: + # assert cfg.labels is not None, "eval_wer can only be set during fine-tuning" + self.blank_symbol = "*" @classmethod - def setup_task(cls, args, **kwargs): + def setup_task(cls, cfg:ScribblelensConfig, **kwargs): """Setup the task (e.g., load dictionaries). Args: - args (argparse.Namespace): parsed command-line arguments - """ - return cls(args) - - def load_dataset(self, split, **kwargs): - """Load a given dataset split. - - Args: - split (str): name of the split (e.g., train, valid, test) + cfg (ScribblelensConfig): configuration of this task """ - if not self.args.labels: - self.datasets[split] = FileHandwritingDataset( - self.args.data, - split=split, - max_sample_size=self.args.max_sample_size, - min_sample_size=self.args.max_sample_size, - pad_to_multiples_of=self.args.pad_to_multiples_of, - min_length=self.args.min_sample_size, - pad=self.args.labels is not None or self.args.enable_padding, - - normalize=self.args.normalize, - ) - + if cfg.labels: + target_dictionary = HandwritingDictionary(cfg.vocab_path) else: + target_dictionary = None + + return cls(cfg, target_dictionary=target_dictionary) + + def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs): + data_path = self.cfg.data + task_cfg = task_cfg or self.cfg + vocab_path = task_cfg.vocab_path if task_cfg.vocab_path is not None else task_cfg.data + '/tasman.alphabet.plus.space.mode5.json' + + self.datasets[split] = FileHandwritingDataset( + task_cfg.data, + vocab_path=vocab_path, + split=split, + max_sample_size=task_cfg.max_sample_size, + min_sample_size=task_cfg.max_sample_size, + pad_to_multiples_of=task_cfg.pad_to_multiples_of, + min_length=task_cfg.min_sample_size, + pad=task_cfg.labels is not None or task_cfg.enable_padding, + normalize=task_cfg.normalize, + labels=task_cfg.labels + ) - # TODO change this stuff! - - #assert False ## TODO(JCh): we must load labels from scribblelens. - # https://github.com/pytorch/fairseq/blob/master/examples/wav2vec/README.md#fine-tune-a-pre-trained-model-with-ctc - # fairseq/examples/wav2vec/libri_labels.py - - dict_path = FileHandwritingDataset.vocabularyPath(self.args.data) #os.path.join(self.args.data, f"dict.{self.args.labels}.txt") - self._target_dictionary = HandwritingDictionary(dict_path) #Dictionary.load(dict_path) - - # this dictionary ^ seems to be a file with perhaps just words? or only one occurence? or sth? - # seems what it does behind the hood is split the transcribed line into words and encode each word with some id, seems it assigns new ids from 0/1 for every new word it sees - # perhaps for letters can just be letter - 'a' or sth - # what if stuff will learn classification in a different order? need to add some additional layer or what? well, yeah, there needs to be some to predict letters from representations - - # label_path = os.path.join(self.args.data, f"{split}.{self.args.labels}") # generated an example how this looks like - # labels = [] - # with open(label_path, "r") as f: - # for line in f: - # labels.append(line) - - # process_label = LabelEncoder(self.target_dictionary) // mayyybe TODO sth with that - - self.datasets[split] = FileHandwritingDataset( - self.args.data, - split=split, - max_sample_size=self.args.max_sample_size, - min_sample_size=self.args.max_sample_size, - pad_to_multiples_of=self.args.pad_to_multiples_of, - min_length=self.args.min_sample_size, - pad=self.args.labels is not None or self.args.enable_padding, - - normalize=self.args.normalize, - labels=True, - ) - - # AddTargetDataset( - # self.datasets[split], - # labels, - # pad=self.target_dictionary.pad(), - # eos=self.target_dictionary.eos(), - # batch_targets=True, - # process_label=process_label, - # add_to_input=not self.is_ctc, - # ) @property def source_dictionary(self): @@ -168,11 +134,11 @@ def max_positions(self): return (sys.maxsize, sys.maxsize) def filter_indices_by_size( - self, - indices, - dataset, - max_positions=None, - ignore_invalid_inputs=False, + self, + indices, + dataset, + max_positions=None, + ignore_invalid_inputs=False, ): # we do not need to filter by size in this task as dataloaders take care of this return indices diff --git a/fairseq/tasks/translation_multi_simple_epoch.py b/fairseq/tasks/translation_multi_simple_epoch.py index d871502a2c..34af9bf4a3 100644 --- a/fairseq/tasks/translation_multi_simple_epoch.py +++ b/fairseq/tasks/translation_multi_simple_epoch.py @@ -138,12 +138,16 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): """ if split in self.datasets: dataset = self.datasets[split] - if self.has_sharded_data(split) and dataset.load_next_shard: - shard_epoch = dataset.shard_epoch - else: - # no need to load next shard so skip loading - # also this avoid always loading from beginning of the data - return + if self.has_sharded_data(split): + if self.args.virtual_epoch_size is not None: + if dataset.load_next_shard: + shard_epoch = dataset.shard_epoch + else: + # no need to load next shard so skip loading + # also this avoid always loading from beginning of the data + return + else: + shard_epoch = epoch else: # estimate the shard epoch from virtual data size and virtual epoch size shard_epoch = self.data_manager.estimate_global_pass_epoch(epoch) @@ -153,7 +157,7 @@ def load_dataset(self, split, epoch=1, combine=False, **kwargs): del self.datasets[split] logger.info("old dataset deleted manually") logger.info(f"mem usage: {data_utils.get_mem_usage()}") - self.datasets[split] = self.data_manager.load_sampled_multi_epoch_dataset( + self.datasets[split] = self.data_manager.load_dataset( split, self.training, epoch=epoch, diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 19ca213d55..8f42743ac3 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -178,10 +178,7 @@ def criterion(self): @property def model(self): if self._wrapped_model is None: - if ( - self.data_parallel_world_size > 1 - and not self.cfg.optimization.use_bmuf - ): + if self.data_parallel_world_size > 1 and not self.cfg.optimization.use_bmuf: self._wrapped_model = models.DistributedFairseqModel( self.cfg.distributed_training, self._model, @@ -279,6 +276,7 @@ def save_checkpoint(self, filename, extra_state): self._optim_history, extra_state, ) + logger.info(f"Finished saving checkpoint to {filename}") def load_checkpoint( self, @@ -295,14 +293,17 @@ def load_checkpoint( """ extra_state, self._optim_history, last_optim_state = None, [], None + logger.info(f"Preparing to load checkpoint {filename}") bexists = PathManager.isfile(filename) if bexists: - if ( - self.data_parallel_rank == 0 + load_on_all_ranks = ( + self.cfg.checkpoint.load_checkpoint_on_all_dp_ranks # TPUs don't support broadcast yet, so load checkpoints # on every worker for now or self.tpu - ): + ) + + if load_on_all_ranks or self.data_parallel_rank == 0: state = checkpoint_utils.load_checkpoint_to_cpu(filename) last_optim_state = state.get("last_optimizer_state", None) @@ -310,7 +311,8 @@ def load_checkpoint( # state. Later we will broadcast sharded states to each rank # to avoid memory from exploding. if ( - self.cfg.distributed_training.zero_sharding == "os" + not load_on_all_ranks + and self.cfg.distributed_training.zero_sharding == "os" and "last_optimizer_state" in state and self.data_parallel_world_size > 1 ): @@ -319,11 +321,7 @@ def load_checkpoint( last_optim_state = None state = None - if ( - self.data_parallel_world_size > 1 - # disable on TPUs until they support broadcast - and not self.tpu - ): + if self.data_parallel_world_size > 1 and not load_on_all_ranks: state = distributed_utils.broadcast_object( state, src_rank=0, @@ -366,7 +364,7 @@ def load_checkpoint( if not reset_lr_scheduler: self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) - if self.data_parallel_world_size > 1: + if not load_on_all_ranks and self.data_parallel_world_size > 1: last_optim_state = self.optimizer.broadcast_global_state_dict( last_optim_state ) @@ -376,11 +374,6 @@ def load_checkpoint( if extra_state is not None: epoch = extra_state["train_iterator"]["epoch"] - logger.info( - "loaded checkpoint {} (epoch {} @ {} updates)".format( - filename, epoch, self.get_num_updates() - ) - ) if "previous_training_time" in extra_state: self._previous_training_time = extra_state["previous_training_time"] @@ -395,8 +388,15 @@ def load_checkpoint( for meter in metrics.get_meters("default"): if isinstance(meter, meters.TimeMeter): meter.reset() + + logger.info( + "Loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + else: - logger.info("no existing checkpoint found {}".format(filename)) + logger.info("No existing checkpoint found {}".format(filename)) return extra_state @@ -506,16 +506,7 @@ def train_step(self, samples, raise_oom=False): # forward and backward pass logging_outputs, sample_size, ooms = [], 0, 0 for i, sample in enumerate(samples): - sample = self._prepare_sample(sample) - if sample is None: - # when sample is None, run forward/backward on a dummy batch - # and ignore the resulting gradients - sample = self._prepare_sample(self._dummy_batch) - is_dummy_batch = True - else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample - is_dummy_batch = False + sample, is_dummy_batch = self._prepare_sample(sample) def maybe_no_sync(): """ @@ -632,29 +623,38 @@ def maybe_no_sync(): grad_norm = self.clip_grad_norm(self.cfg.optimization.clip_norm) # check that grad norms are consistent across workers - if ( - not self.cfg.optimization.use_bmuf - and self.cfg.distributed_training.distributed_wrapper != "SlowMo" - and not self.tpu - ): - self._check_grad_norms(grad_norm) + # on tpu check tensor is slow + if not self.tpu: + if ( + not self.cfg.optimization.use_bmuf + and self.cfg.distributed_training.distributed_wrapper != "SlowMo" + ): + self._check_grad_norms(grad_norm) + if not torch.isfinite(grad_norm).all(): + # check local gradnorm single GPU case, trigger NanDetector + raise FloatingPointError("gradients are Nan/Inf") with torch.autograd.profiler.record_function("optimizer"): # take an optimization step - self.optimizer.step() + self.task.optimizer_step( + self.optimizer, model=self.model, update_num=self.get_num_updates() + ) except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails + self.zero_grad() with NanDetector(self.get_model()): - self.task.train_step( - sample, - self.model, - self.criterion, - self.optimizer, - self.get_num_updates(), - ignore_grad=False, - ) + for _, sample in enumerate(samples): + sample, _ = self._prepare_sample(sample) + self.task.train_step( + sample, + self.model, + self.criterion, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + ) raise except OverflowError as e: overflow = True @@ -769,14 +769,7 @@ def valid_step(self, sample, raise_oom=False): self.model.eval() self.criterion.eval() - sample = self._prepare_sample(sample) - if sample is None: - sample = self._prepare_sample(self._dummy_batch) - is_dummy_batch = True - else: - if self._dummy_batch == "DUMMY": - self._dummy_batch = sample - is_dummy_batch = False + sample, is_dummy_batch = self._prepare_sample(sample) try: _loss, sample_size, logging_output = self.task.valid_step( @@ -835,7 +828,12 @@ def lr_step(self, epoch, val_loss=None): def lr_step_update(self): """Update the learning rate after each update.""" new_lr = self.lr_scheduler.step_update(self.get_num_updates()) - metrics.log_scalar("lr", new_lr, weight=0, priority=300) + if isinstance(new_lr, dict): + for k, v in new_lr.items(): + metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) + new_lr = new_lr.get("default", next(iter(new_lr.values()))) + else: + metrics.log_scalar("lr", new_lr, weight=0, priority=300) return new_lr def get_lr(self): @@ -917,7 +915,7 @@ def _local_cumulative_training_time(self): """Aggregate training time in seconds.""" return time.time() - self._start_time + self._previous_training_time - def _prepare_sample(self, sample): + def _prepare_sample(self, sample, is_dummy=False): if sample == "DUMMY": raise Exception( "Trying to use an uninitialized 'dummy' batch. This usually indicates " @@ -926,7 +924,11 @@ def _prepare_sample(self, sample): ) if sample is None or len(sample) == 0: - return None + assert ( + self._dummy_batch is not None and len(self._dummy_batch) > 0 + ), "Invalid dummy batch: {}".format(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) + return sample, True if self.cuda: if self.pipeline_model_parallel: @@ -936,6 +938,9 @@ def _prepare_sample(self, sample): ) else: sample = utils.move_to_cuda(sample) + elif self.tpu and is_dummy: + # the dummy batch may not be on the appropriate device + sample = utils.move_to_cuda(sample, device=self.device) def apply_half(t): if t.dtype is torch.float32: @@ -953,7 +958,10 @@ def apply_bfloat16(t): if self.cfg.common.bf16: sample = utils.apply_to_sample(apply_bfloat16, sample) - return sample + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + + return sample, False def _set_seed(self): # Set seed based on args.seed and the update number so that we get @@ -1078,7 +1086,7 @@ def _check_grad_norms(self, grad_norm): def is_consistent(tensor): max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) return ( - not torch.isfinite(tensor).any() + torch.isfinite(tensor).all() or (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() ) @@ -1090,7 +1098,8 @@ def is_consistent(tensor): error_detail = "grad_norm across the workers:\n{}\n".format( pretty_detail ) - raise RuntimeError( + # use FloatingPointError to trigger NanDetector + raise FloatingPointError( "Fatal error: gradients are inconsistent between workers. " "Try --ddp-backend=no_c10d. " "Or are you mixing up different generation of GPUs in training?" diff --git a/fairseq/utils.py b/fairseq/utils.py index 8e9119124d..a20c83384c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -106,7 +106,7 @@ def move_to_cuda(sample, device=None): def _move_to_cuda(tensor): # non_blocking is ignored if tensor is not pinned, so we can always set # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) - return tensor.cuda(device=device, non_blocking=True) + return tensor.to(device=device, non_blocking=True) return apply_to_sample(_move_to_cuda, sample) @@ -437,7 +437,7 @@ def import_user_module(args): module_path = getattr(args, "user_dir", None) if module_path is not None: module_path = os.path.abspath(args.user_dir) - if not os.path.exists(module_path): + if not os.path.exists(module_path) and not os.path.isfile(os.path.dirname(module_path)): fairseq_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) if os.path.exists(fairseq_rel_path): module_path = fairseq_rel_path @@ -631,6 +631,23 @@ def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos): return alignment +def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos): + tgt_valid = ( + ((tgt_sent != pad)).nonzero(as_tuple=False) + ) + src_valid = ( + ((src_sent != pad)).nonzero(as_tuple=False).squeeze(dim=-1) + ) + alignment = [] + if len(tgt_valid) != 0 and len(src_valid) != 0: + attn_valid = attn[tgt_valid, src_valid] + alignment = [ + ["{:.6f}".format(p) for p in src_probs.tolist()] + for src_probs in attn_valid + ] + return alignment + + def new_arange(x, *size): """ Return a Tensor of `size` filled with a range function on the device of x. diff --git a/fairseq_cli/eval_lm.py b/fairseq_cli/eval_lm.py index e8fd98c325..f27e0258d0 100644 --- a/fairseq_cli/eval_lm.py +++ b/fairseq_cli/eval_lm.py @@ -11,14 +11,16 @@ import logging import math import os +import sys from argparse import Namespace +from typing import Iterable, List, Optional import torch +import fairseq from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils -from fairseq.data import LMContextWindowDataset from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import progress_bar -from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.logging.meters import StopwatchMeter from fairseq.sequence_scorer import SequenceScorer from omegaconf import DictConfig @@ -27,144 +29,74 @@ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, ) logger = logging.getLogger("fairseq_cli.eval_lm") -class WordStat(object): - def __init__(self, word, is_bpe): - self.word = word - self.is_bpe = is_bpe - self.log_prob = 0 - self.next_word_prob = 0 - self.count = 0 - self.missing_next_words = 0 - - def add(self, log_prob, next_word_prob): - """increments counters for the sum of log probs of current word and next - word (given context ending at current word). Since the next word might be at the end of the example, - or it might be not counted because it is not an ending subword unit, - also keeps track of how many of those we have seen""" - if next_word_prob is not None: - self.next_word_prob += next_word_prob - else: - self.missing_next_words += 1 - self.log_prob += log_prob - self.count += 1 - - def __str__(self): - return "{}\t{}\t{}\t{}\t{}\t{}".format( - self.word, - self.count, - self.log_prob, - self.is_bpe, - self.next_word_prob, - self.count - self.missing_next_words, - ) - - -def main(cfg: DictConfig, **unused_kwargs): - if isinstance(cfg, Namespace): - cfg = convert_namespace_to_omegaconf(cfg) - - utils.import_user_module(cfg.common) - - use_fp16 = cfg.common.fp16 - use_cuda = torch.cuda.is_available() and not cfg.common.cpu - - if use_cuda: - torch.cuda.set_device(cfg.distributed_training.device_id) - - logger.info(cfg) - - # Load ensemble - logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - - # reduce tokens per sample by the required context window size - cfg.task.tokens_per_sample -= cfg.eval_lm.context_window - - # Initialize the task using the current *cfg* - task = tasks.setup_task(cfg.task) - - # Initialize the model (but not the task) using the checkpoint's *cfg* - models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( - [cfg.common_eval.path], - arg_overrides=eval(cfg.common_eval.model_overrides), - suffix=cfg.checkpoint.checkpoint_suffix, - strict=(cfg.checkpoint.checkpoint_shard_count == 1), - num_shards=cfg.checkpoint.checkpoint_shard_count, - task=task, - ) - - # Load dataset splits - gen_subset = cfg.dataset.gen_subset - task.load_dataset(gen_subset) - dataset = task.dataset(gen_subset) - if cfg.eval_lm.context_window > 0: - dataset = LMContextWindowDataset( - dataset=dataset, - tokens_per_sample=cfg.task.tokens_per_sample, - context_window=cfg.eval_lm.context_window, - pad_idx=task.source_dictionary.pad(), - ) - logger.info("{} {} {} examples".format(cfg.task.data, gen_subset, len(dataset))) - - # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer) - for model in models: - if use_fp16: - model.half() - if use_cuda and not cfg.distributed_training.pipeline_model_parallel: - model.cuda() - model.prepare_for_inference_(cfg) - - assert len(models) > 0 - - logger.info( - "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) - ) - - itr = task.get_batch_iterator( - dataset=dataset, - max_tokens=cfg.dataset.max_tokens or 36000, - max_sentences=cfg.dataset.batch_size, - max_positions=utils.resolve_max_positions( - *[model.max_positions() for model in models] - ), - ignore_invalid_inputs=True, - num_shards=max( - cfg.dataset.num_shards, - cfg.distributed_training.distributed_world_size, - ), - shard_id=max( - cfg.dataset.shard_id, - cfg.distributed_training.distributed_rank, - ), - num_workers=cfg.dataset.num_workers, - data_buffer_size=cfg.dataset.data_buffer_size, - ).next_epoch_itr(shuffle=False) - progress = progress_bar.progress_bar( - itr, - log_format=cfg.common.log_format, - log_interval=cfg.common.log_interval, - default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), - ) +def eval_lm( + models: List[fairseq.models.FairseqModel], + source_dictionary: fairseq.data.Dictionary, + batch_iterator: Iterable, + post_process: Optional[str] = None, + output_word_probs: bool = False, + output_word_stats: bool = False, + target_dictionary: Optional[fairseq.data.Dictionary] = None, + softmax_batch: int = False, + remove_bos_token: bool = False, + device: Optional[torch.device] = None, +): + """ + Args: + models (List[~fairseq.models.FairseqModel]): list of models to + evaluate. Models are essentially `nn.Module` instances, but + must be compatible with fairseq's `SequenceScorer`. + source_dictionary (~fairseq.data.Dictionary): dictionary for + applying any relevant post processing or outputing word + probs/stats. + batch_iterator (Iterable): yield batches of data + post_process (Optional[str]): post-process text by removing BPE, + letter segmentation, etc. Valid options can be found in + fairseq.data.utils.post_process, although not all options + are implemented here. + output_word_probs (Optional[bool]): output words and their + predicted log probabilities + output_word_stats (Optional[bool]): output word statistics such + as word count and average probability + target_dictionary (Optional[~fairseq.data.Dictionary]): output + dictionary (defaults to *source_dictionary*) + softmax_batch (Optional[bool]): if BxT is more than this, will + batch the softmax over vocab to this amount of tokens, in + order to fit into GPU memory + remove_bos_token (Optional[bool]): if True, confirm that the + first token is the beginning-of-sentence symbol (according + to the relevant dictionary) and remove it from the output + device (Optional[torch.device]): device to use for evaluation + (defaults to device of first model parameter) + """ + if target_dictionary is None: + target_dictionary = source_dictionary + if device is None: + device = next(models[0].parameters()).device gen_timer = StopwatchMeter() - scorer = SequenceScorer(task.target_dictionary, cfg.eval_lm.softmax_batch) + scorer = SequenceScorer(target_dictionary, softmax_batch) score_sum = 0.0 count = 0 - if cfg.common_eval.post_process is not None: - if cfg.common_eval.post_process == "sentencepiece": - raise NotImplementedError - else: - bpe_cont = cfg.common_eval.post_process.rstrip() + if post_process is not None: + if post_process in {"subword_nmt", "@@ "}: + bpe_cont = post_process.rstrip() bpe_toks = { i - for i in range(len(task.source_dictionary)) - if task.source_dictionary[i].endswith(bpe_cont) + for i in range(len(source_dictionary)) + if source_dictionary[i].endswith(bpe_cont) } + else: + raise NotImplementedError( + "--post-process={post_process} is not implemented" + ) bpe_len = len(bpe_cont) else: bpe_toks = None @@ -172,13 +104,11 @@ def main(cfg: DictConfig, **unused_kwargs): word_stats = dict() - wps_meter = TimeMeter() - - for sample in progress: + for sample in batch_iterator: if "net_input" not in sample: continue - sample = utils.move_to_cuda(sample) if use_cuda else sample + sample = utils.move_to_cuda(sample, device=device) gen_timer.start() hypos = scorer.generate(models, sample) @@ -192,8 +122,8 @@ def main(cfg: DictConfig, **unused_kwargs): tgt_len = tokens.numel() pos_scores = hypo["positional_scores"].float() - if getattr(cfg.task, "add_bos_token", False): - assert hypo["tokens"][0].item() == task.target_dictionary.bos() + if remove_bos_token: + assert hypo["tokens"][0].item() == target_dictionary.bos() tokens = tokens[1:] pos_scores = pos_scores[1:] @@ -209,19 +139,19 @@ def main(cfg: DictConfig, **unused_kwargs): if inf_scores.any(): logger.info( "skipping tokens with inf scores:", - task.target_dictionary.string(tokens[inf_scores.nonzero()]), + target_dictionary.string(tokens[inf_scores.nonzero()]), ) pos_scores = pos_scores[(~inf_scores).nonzero()] score_sum += pos_scores.sum().cpu() count += pos_scores.numel() - skipped_toks - if cfg.eval_lm.output_word_probs or cfg.eval_lm.output_word_stats: + if output_word_probs or output_word_stats: w = "" word_prob = [] is_bpe = False for i in range(len(tokens)): w_ind = tokens[i].item() - w += task.source_dictionary[w_ind] + w += source_dictionary[w_ind] if bpe_toks is not None and w_ind in bpe_toks: w = w[:-bpe_len] is_bpe = True @@ -241,7 +171,7 @@ def main(cfg: DictConfig, **unused_kwargs): ) is_bpe = False w = "" - if cfg.eval_lm.output_word_probs: + if output_word_probs: logger.info( str(int(sample_id)) + " " @@ -252,24 +182,154 @@ def main(cfg: DictConfig, **unused_kwargs): ) ) - wps_meter.update(sample["ntokens"]) - progress.log({"wps": round(wps_meter.avg)}) - - avg_nll_loss = -score_sum / count / math.log(2) # convert to base 2 + avg_nll_loss = -score_sum / count / math.log(2) if count > 0 else 0 # convert to base 2 logger.info( "Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)".format( - gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg + gen_timer.n, gen_timer.sum, 1.0 / gen_timer.avg if gen_timer.avg > 0 else 0 + ) + ) + + if output_word_stats: + for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): + logger.info(ws) + + return { + "loss": avg_nll_loss, + "perplexity": 2 ** avg_nll_loss, + } + + +class WordStat(object): + def __init__(self, word, is_bpe): + self.word = word + self.is_bpe = is_bpe + self.log_prob = 0 + self.next_word_prob = 0 + self.count = 0 + self.missing_next_words = 0 + + def add(self, log_prob, next_word_prob): + """increments counters for the sum of log probs of current word and next + word (given context ending at current word). Since the next word might be at the end of the example, + or it might be not counted because it is not an ending subword unit, + also keeps track of how many of those we have seen""" + if next_word_prob is not None: + self.next_word_prob += next_word_prob + else: + self.missing_next_words += 1 + self.log_prob += log_prob + self.count += 1 + + def __str__(self): + return "{}\t{}\t{}\t{}\t{}\t{}".format( + self.word, + self.count, + self.log_prob, + self.is_bpe, + self.next_word_prob, + self.count - self.missing_next_words, ) + + +def main(cfg: DictConfig, **unused_kwargs): + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + utils.import_user_module(cfg.common) + + logger.info(cfg) + + if cfg.eval_lm.context_window > 0: + # reduce tokens per sample by the required context window size + cfg.task.tokens_per_sample -= cfg.eval_lm.context_window + + # Initialize the task using the current *cfg* + task = tasks.setup_task(cfg.task) + + # Load ensemble + logger.info("loading model(s) from {}".format(cfg.common_eval.path)) + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [cfg.common_eval.path], + arg_overrides=eval(cfg.common_eval.model_overrides), + suffix=cfg.checkpoint.checkpoint_suffix, + strict=(cfg.checkpoint.checkpoint_shard_count == 1), + num_shards=cfg.checkpoint.checkpoint_shard_count, + task=task, ) + + use_fp16 = cfg.common.fp16 + use_cuda = torch.cuda.is_available() and not cfg.common.cpu + if use_cuda: + torch.cuda.set_device(cfg.distributed_training.device_id) + + # Optimize ensemble for generation and set the source and dest dicts on the model + # (required by scorer) + for model in models: + if use_fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + assert len(models) > 0 + + logger.info( + "num. model params: {}".format(sum(p.numel() for p in models[0].parameters())) + ) + + # Load dataset splits + task.load_dataset(cfg.dataset.gen_subset) + dataset = task.dataset(cfg.dataset.gen_subset) + logger.info( + "{} {} {} examples".format(cfg.task.data, cfg.dataset.gen_subset, len(dataset)) + ) + + itr = task.eval_lm_dataloader( + dataset=dataset, + max_tokens=cfg.dataset.max_tokens or 36000, + batch_size=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + *[model.max_positions() for model in models] + ), + num_shards=max( + cfg.dataset.num_shards, + cfg.distributed_training.distributed_world_size, + ), + shard_id=max( + cfg.dataset.shard_id, + cfg.distributed_training.distributed_rank, + ), + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + context_window=cfg.eval_lm.context_window, + ) + + itr = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + results = eval_lm( + models=models, + source_dictionary=task.source_dictionary, + batch_iterator=itr, + post_process=cfg.common_eval.post_process, + output_word_probs=cfg.eval_lm.output_word_probs, + output_word_stats=cfg.eval_lm.output_word_stats, + target_dictionary=task.target_dictionary, + softmax_batch=cfg.eval_lm.softmax_batch, + remove_bos_token=getattr(cfg.task, "add_bos_token", False), + ) + logger.info( "Loss (base 2): {:.4f}, Perplexity: {:.2f}".format( - avg_nll_loss, 2 ** avg_nll_loss + results["loss"], results["perplexity"] ) ) - if cfg.eval_lm.output_word_stats: - for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True): - logger.info(ws) + return results def cli_main(): diff --git a/fairseq_cli/generate.py b/fairseq_cli/generate.py index 021f819ed7..4aeb4a56fa 100644 --- a/fairseq_cli/generate.py +++ b/fairseq_cli/generate.py @@ -81,7 +81,7 @@ def _main(cfg: DictConfig, output_file): # Load dataset splits task = tasks.setup_task(cfg.task) - task.load_dataset(cfg.dataset.gen_subset) + # Set dictionaries try: @@ -94,7 +94,7 @@ def _main(cfg: DictConfig, output_file): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - models, _model_args = checkpoint_utils.load_model_ensemble( + models, saved_cfg = checkpoint_utils.load_model_ensemble( utils.split_paths(cfg.common_eval.path), arg_overrides=overrides, task=task, @@ -103,6 +103,9 @@ def _main(cfg: DictConfig, output_file): num_shards=cfg.checkpoint.checkpoint_shard_count, ) + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + if cfg.generation.lm_path is not None: overrides["data"] = cfg.task.data @@ -296,7 +299,7 @@ def decode_fn(x): file=output_file, ) - if cfg.generation.print_alignment: + if cfg.generation.print_alignment == "hard": print( "A-{}\t{}".format( sample_id, @@ -309,6 +312,19 @@ def decode_fn(x): ), file=output_file, ) + if cfg.generation.print_alignment == "soft": + print( + "A-{}\t{}".format( + sample_id, + " ".join( + [ + ",".join(src_probs) + for src_probs in alignment + ] + ), + ), + file=output_file, + ) if cfg.generation.print_step: print( diff --git a/fairseq_cli/hydra_train.py b/fairseq_cli/hydra_train.py index ffd3c5cd07..b092ce14ee 100644 --- a/fairseq_cli/hydra_train.py +++ b/fairseq_cli/hydra_train.py @@ -4,29 +4,32 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -import hydra -from omegaconf import OmegaConf +import logging import os +import sys from fairseq.dataclass.initialize import hydra_init from fairseq_cli.train import main as pre_main from fairseq import distributed_utils from fairseq.dataclass.configs import FairseqConfig -import logging +import hydra import torch +from omegaconf import OmegaConf -logger = logging.getLogger(__name__) +logger = logging.getLogger("fairseq_cli.hydra_train") @hydra.main(config_path=os.path.join("..", "fairseq", "config"), config_name="config") def hydra_main(cfg: FairseqConfig) -> None: - cfg = OmegaConf.create(OmegaConf.to_container(cfg, resolve=True, enum_to_str=True)) OmegaConf.set_struct(cfg, True) + if cfg.common.reset_logging: + reset_logging() # Hydra hijacks logging, fix that + if cfg.common.profile: with torch.cuda.profiler.profile(): with torch.autograd.profiler.emit_nvtx(): @@ -35,7 +38,22 @@ def hydra_main(cfg: FairseqConfig) -> None: distributed_utils.call_main(cfg, pre_main) -if __name__ == "__main__": +def reset_logging(): + root = logging.getLogger() + for handler in root.handlers: + root.removeHandler(handler) + root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper()) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + logging.Formatter( + fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + ) + ) + root.addHandler(handler) + + +def cli_main(): try: from hydra._internal.utils import get_args @@ -46,3 +64,7 @@ def hydra_main(cfg: FairseqConfig) -> None: hydra_init(cfg_name) hydra_main() + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq_cli/interactive.py b/fairseq_cli/interactive.py index 530830d6b0..4785855985 100644 --- a/fairseq_cli/interactive.py +++ b/fairseq_cli/interactive.py @@ -173,7 +173,7 @@ def main(cfg: FairseqConfig): model.prepare_for_inference_(cfg) # Initialize generator - generator = task.build_generator(models, cfg.task) + generator = task.build_generator(models, cfg.generation) # Handle tokenization and BPE tokenizer = encoders.build_tokenizer(cfg.tokenizer) diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index e1af605348..165ed86b58 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -16,7 +16,6 @@ import numpy as np import torch - from fairseq import ( checkpoint_utils, distributed_utils, @@ -29,8 +28,8 @@ from fairseq.dataclass.utils import convert_namespace_to_omegaconf from fairseq.logging import meters, metrics, progress_bar from fairseq.model_parallel.megatron_trainer import MegatronTrainer -from omegaconf import DictConfig from fairseq.trainer import Trainer +from omegaconf import DictConfig logging.basicConfig( @@ -76,7 +75,7 @@ def main(cfg: DictConfig) -> None: logger.info(model) logger.info("task: {}".format(task.__class__.__name__)) logger.info("model: {}".format(model.__class__.__name__)) - logger.info("criterion: {})".format(criterion.__class__.__name__)) + logger.info("criterion: {}".format(criterion.__class__.__name__)) logger.info( "num. model params: {} (num. trained: {})".format( sum(p.numel() for p in model.parameters()), @@ -125,7 +124,15 @@ def main(cfg: DictConfig) -> None: lr = trainer.get_lr() train_meter = meters.StopwatchMeter() train_meter.start() - while lr > cfg.optimization.min_lr and epoch_itr.next_epoch_idx <= max_epoch: + while epoch_itr.next_epoch_idx <= max_epoch: + if lr <= cfg.optimization.stop_min_lr: + logger.info( + f"stopping training because current learning rate ({lr}) is smaller " + "than or equal to minimum learning rate " + f"(--stop-min-lr={cfg.optimization.stop_min_lr})" + ) + break + # train for one epoch valid_losses, should_stop = train(cfg, trainer, task, epoch_itr) if should_stop: @@ -189,7 +196,7 @@ def train( else cfg.optimization.update_freq[-1] ) itr = iterators.GroupedIterator(itr, update_freq) - if getattr(cfg.common, "tpu", False): + if cfg.common.tpu: itr = utils.tpu_data_loader(itr) progress = progress_bar.progress_bar( itr, @@ -203,7 +210,17 @@ def train( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) + ), + azureml_logging=( + cfg.common.azureml_logging + if distributed_utils.is_master(cfg.distributed_training) + else False ), ) @@ -257,9 +274,32 @@ def validate_and_save( ) -> Tuple[List[Optional[float]], bool]: num_updates = trainer.get_num_updates() max_update = cfg.optimization.max_update or math.inf + + # Stopping conditions (and an additional one based on validation loss later + # on) + should_stop = False + if num_updates >= max_update: + should_stop = True + logger.info( + f"Stopping training due to " + f"num_updates: {num_updates} >= max_update: {max_update}" + ) + + training_time_hours = trainer.cumulative_training_time() / (60 * 60) + if ( + cfg.optimization.stop_time_hours > 0 + and training_time_hours > cfg.optimization.stop_time_hours + ): + should_stop = True + logger.info( + f"Stopping training due to " + f"cumulative_training_time: {training_time_hours} > " + f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)" + ) + do_save = ( (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.checkpoint.save_interval_updates > 0 and num_updates > 0 @@ -270,7 +310,7 @@ def validate_and_save( do_validate = ( (not end_of_epoch and do_save) # validate during mid-epoch saves or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0) - or num_updates >= max_update + or should_stop or ( cfg.dataset.validate_interval_updates > 0 and num_updates > 0 @@ -283,20 +323,10 @@ def validate_and_save( if do_validate: valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets) - # Stopping conditions - should_stop = ( - should_stop_early(cfg, valid_losses[0]) - or num_updates >= max_update - or ( - cfg.optimization.stop_time_hours > 0 - and trainer.cumulative_training_time() / (60 * 60) - > cfg.optimization.stop_time_hours - ) - ) + should_stop |= should_stop_early(cfg, valid_losses[0]) # Save checkpoint if do_save or should_stop: - logger.info("begin save checkpoint") checkpoint_utils.save_checkpoint( cfg.checkpoint, trainer, epoch_itr, valid_losses[0] ) @@ -344,7 +374,12 @@ def validate( ), default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), wandb_project=( - cfg.common.wandb_project if distributed_utils.is_master(cfg.distributed_training) else None + cfg.common.wandb_project + if distributed_utils.is_master(cfg.distributed_training) + else None + ), + wandb_run_name=os.environ.get( + "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir) ), ) diff --git a/fairseq_cli/validate.py b/fairseq_cli/validate.py index 36e8bd16ca..c69bb94142 100644 --- a/fairseq_cli/validate.py +++ b/fairseq_cli/validate.py @@ -51,7 +51,7 @@ def main(cfg: DictConfig, override_args=None): # Load ensemble logger.info("loading model(s) from {}".format(cfg.common_eval.path)) - models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( [cfg.common_eval.path], arg_overrides=overrides, suffix=cfg.checkpoint.checkpoint_suffix, @@ -66,15 +66,15 @@ def main(cfg: DictConfig, override_args=None): model.cuda() # Print args - logger.info(model_args) + logger.info(saved_cfg) # Build criterion - criterion = task.build_criterion(model_args.criterion) + criterion = task.build_criterion(saved_cfg.criterion) criterion.eval() for subset in cfg.dataset.valid_subset.split(","): try: - task.load_dataset(subset, combine=False, epoch=1) + task.load_dataset(subset, combine=False, epoch=1, task_cfg=saved_cfg.task) dataset = task.dataset(subset) except KeyError: raise Exception("Cannot find dataset: " + subset) diff --git a/hydra-train.py b/hydra-train.py new file mode 100644 index 0000000000..1e064a02fd --- /dev/null +++ b/hydra-train.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Hydra entry point with debugging. +""" + +from fairseq_cli.hydra_train import cli_main + + +if __name__ == '__main__': + # import ptvsd + # ptvsd.enable_attach(('0.0.0.0', 7309)) + # print("Attach debugger now") + # ptvsd.wait_for_attach() + cli_main() diff --git a/setup.py b/setup.py index 572d2b50de..1954298034 100644 --- a/setup.py +++ b/setup.py @@ -22,14 +22,18 @@ def write_version_py(): # append latest commit hash to version string try: - sha = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("ascii").strip() + sha = ( + subprocess.check_output(["git", "rev-parse", "HEAD"]) + .decode("ascii") + .strip() + ) version += "+" + sha[:7] except Exception: pass # write version info to fairseq/version.py with open(os.path.join("fairseq", "version.py"), "w") as f: - f.write("__version__ = \"{}\"\n".format(version)) + f.write('__version__ = "{}"\n'.format(version)) return version @@ -132,7 +136,7 @@ def include_dirs(self, dirs): # use CPU build of PyTorch dependency_links = [ - "https://download.pytorch.org/whl/cpu/torch-1.3.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" + "https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp36-cp36m-linux_x86_64.whl" ] else: dependency_links = [] @@ -149,6 +153,11 @@ def include_dirs(self, dirs): ) +extra_packages = [] +if os.path.exists(os.path.join("fairseq", "model_parallel", "megatron", "mpu")): + extra_packages.append("fairseq.model_parallel.megatron.mpu") + + def do_setup(package_data): setup( name="fairseq", @@ -159,22 +168,26 @@ def do_setup(package_data): "Intended Audience :: Science/Research", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], long_description=readme, long_description_content_type="text/markdown", setup_requires=[ "cython", - "numpy", + 'numpy<1.20.0; python_version<"3.7"', + 'numpy; python_version>="3.7"', "setuptools>=18.0", ], install_requires=[ "cffi", "cython", - "dataclasses", - "editdistance", - "hydra-core", - "numpy", + 'dataclasses; python_version<"3.7"', + "hydra-core<1.1", + "omegaconf<2.1", + 'numpy<1.20.0; python_version<"3.7"', + 'numpy; python_version>="3.7"', "regex", "sacrebleu>=1.4.12", "torch", @@ -190,7 +203,8 @@ def do_setup(package_data): "tests", "tests.*", ] - ), + ) + + extra_packages, package_data=package_data, ext_modules=extensions, test_suite="tests", @@ -198,6 +212,7 @@ def do_setup(package_data): "console_scripts": [ "fairseq-eval-lm = fairseq_cli.eval_lm:cli_main", "fairseq-generate = fairseq_cli.generate:cli_main", + "fairseq-hydra-train = fairseq_cli.hydra_train:cli_main", "fairseq-interactive = fairseq_cli.interactive:cli_main", "fairseq-preprocess = fairseq_cli.preprocess:cli_main", "fairseq-score = fairseq_cli.score:cli_main", @@ -223,12 +238,16 @@ def get_files(path, relative_to="fairseq"): try: # symlink examples into fairseq package so package_data accepts them - if "build_ext" not in sys.argv[1:]: - os.symlink(os.path.join("..", "examples"), "fairseq/examples") + fairseq_examples = os.path.join("fairseq", "examples") + if "build_ext" not in sys.argv[1:] and not os.path.exists(fairseq_examples): + os.symlink(os.path.join("..", "examples"), fairseq_examples) + package_data = { - "fairseq": get_files("fairseq/examples"), + "fairseq": ( + get_files(fairseq_examples) + get_files(os.path.join("fairseq", "config")) + ) } do_setup(package_data) finally: - if "build_ext" not in sys.argv[1:]: - os.unlink("fairseq/examples") + if "build_ext" not in sys.argv[1:] and os.path.exists(fairseq_examples): + os.unlink(fairseq_examples) diff --git a/tests/distributed/__init__.py b/tests/distributed/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/distributed/test_distributed_utils.py b/tests/distributed/test_distributed_utils.py new file mode 100644 index 0000000000..161ee85eaa --- /dev/null +++ b/tests/distributed/test_distributed_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import sys +import unittest + +import torch + +from fairseq import distributed_utils as dist_utils + +from .utils import objects_are_equal, spawn_and_init + + +class TestDistributedUtils(unittest.TestCase): + def setUp(self): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA not available, skipping test") + if sys.platform == "win32": + raise unittest.SkipTest("NCCL doesn't support Windows, skipping test") + if torch.cuda.device_count() < 2: + raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping") + + def test_broadcast_object_python(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + "hello world", + ), + world_size=2, + ) + + def test_broadcast_object_tensor(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + torch.rand(5), + ), + world_size=2, + ) + + def test_broadcast_object_complex(self): + spawn_and_init( + functools.partial( + TestDistributedUtils._test_broadcast_object, + { + "a": "1", + "b": [2, torch.rand(2, 3), 3], + "c": (torch.rand(2, 3), 4), + "d": {5, torch.rand(5)}, + "e": torch.rand(5), + "f": torch.rand(5).int().cuda(), + }, + ), + world_size=2, + ) + + @staticmethod + def _test_broadcast_object(ref_obj, rank, group): + obj = dist_utils.broadcast_object( + ref_obj if rank == 0 else None, src_rank=0, group=group + ) + assert objects_are_equal(ref_obj, obj) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/distributed/utils.py b/tests/distributed/utils.py new file mode 100644 index 0000000000..d2b3ddb1ff --- /dev/null +++ b/tests/distributed/utils.py @@ -0,0 +1,61 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import tempfile + +import torch + + +def spawn_and_init(fn, world_size, args=None): + if args is None: + args = () + with tempfile.NamedTemporaryFile(delete=False) as tmp_file: + torch.multiprocessing.spawn( + fn=functools.partial(init_and_run, fn, args), + args=(world_size, tmp_file.name,), + nprocs=world_size, + ) + + +def distributed_init(rank, world_size, tmp_file): + torch.distributed.init_process_group( + backend="nccl", + init_method="file://{}".format(tmp_file), + world_size=world_size, + rank=rank, + ) + torch.cuda.set_device(rank) + + +def init_and_run(fn, args, rank, world_size, tmp_file): + distributed_init(rank, world_size, tmp_file) + group = torch.distributed.new_group() + fn(rank, group, *args) + + +def objects_are_equal(a, b) -> bool: + if type(a) is not type(b): + return False + if isinstance(a, dict): + if set(a.keys()) != set(b.keys()): + return False + for k in a.keys(): + if not objects_are_equal(a[k], b[k]): + return False + return True + elif isinstance(a, (list, tuple, set)): + if len(a) != len(b): + return False + return all(objects_are_equal(x, y) for x, y in zip(a, b)) + elif torch.is_tensor(a): + return ( + a.size() == b.size() + and a.dtype == b.dtype + and a.device == b.device + and torch.all(a == b) + ) + else: + return a == b diff --git a/tests/gpu/test_binaries_gpu.py b/tests/gpu/test_binaries_gpu.py index 2ac60a0934..5690e73752 100644 --- a/tests/gpu/test_binaries_gpu.py +++ b/tests/gpu/test_binaries_gpu.py @@ -93,17 +93,25 @@ def test_levenshtein_transformer(self): ], task="translation_lev", ) + gen_config = [ + "--task", + "translation_lev", + "--iter-decode-max-iter", + "9", + "--iter-decode-eos-penalty", + "0", + "--print-step", + ] + # non-ensemble generation + generate_main(data_dir, gen_config) + # ensemble generation generate_main( data_dir, - [ - "--task", - "translation_lev", - "--iter-decode-max-iter", - "9", - "--iter-decode-eos-penalty", - "0", - "--print-step", - ], + gen_config, + path=os.pathsep.join([ + os.path.join(data_dir, "checkpoint_last.pt"), + os.path.join(data_dir, "checkpoint_last.pt"), + ]), ) diff --git a/tests/test_binaries.py b/tests/test_binaries.py index dae38dda0c..4e605bd0b1 100644 --- a/tests/test_binaries.py +++ b/tests/test_binaries.py @@ -5,13 +5,14 @@ import contextlib import logging +import json import os import random import sys import tempfile import unittest from io import StringIO - +from typing import List, Dict import torch from fairseq import options from fairseq_cli import eval_lm, train, validate @@ -25,6 +26,14 @@ ) +try: + import transformers # noqa + + has_hf_transformers = True +except ImportError: + has_hf_transformers = False + + class TestTranslation(unittest.TestCase): def setUp(self): logging.disable(logging.CRITICAL) @@ -295,7 +304,9 @@ def test_multilingual_transformer(self): + dec_ltok_flag, ) - @unittest.skipIf(sys.platform.lower() == "darwin", "skip latent depth test on MacOS") + @unittest.skipIf( + sys.platform.lower() == "darwin", "skip latent depth test on MacOS" + ) def test_multilingual_translation_latent_depth(self): # test with latent depth in encoder, decoder, or both encoder_latent_layer = [[], ["--encoder-latent-layer"]] @@ -425,7 +436,7 @@ def test_translation_multi_simple_epoch(self): + dec_ltok_flag, ) - def test_translation_multi_simple_epoch_dicts(self): + def test_translation_multi_simple_epoch_no_vepoch(self): # test with all combinations of encoder/decoder lang tokens with contextlib.redirect_stdout(StringIO()): enc_ltok_flag = ["--encoder-langtok", "src"] @@ -434,9 +445,57 @@ def test_translation_multi_simple_epoch_dicts(self): "test_translation_multi_simple_epoch_dict" ) as data_dir: create_dummy_data(data_dir) - preprocess_translation_data( - data_dir, extra_flags=[] + preprocess_translation_data(data_dir, extra_flags=[]) + train_translation_model( + data_dir, + arch="transformer", + task="translation_multi_simple_epoch", + extra_flags=[ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--sampling-method", + "temperature", + "--sampling-temperature", + "1.5", + ] + + enc_ltok_flag + + dec_ltok_flag, + lang_flags=["--lang-pairs", "in-out"], + run_validation=True, + extra_valid_flags=enc_ltok_flag + dec_ltok_flag, + ) + generate_main( + data_dir, + extra_flags=[ + "--task", + "translation_multi_simple_epoch", + "--lang-pairs", + "in-out", + "--source-lang", + "in", + "--target-lang", + "out", + ] + + enc_ltok_flag + + dec_ltok_flag, ) + + def test_translation_multi_simple_epoch_dicts(self): + # test with all combinations of encoder/decoder lang tokens + with contextlib.redirect_stdout(StringIO()): + enc_ltok_flag = ["--encoder-langtok", "src"] + dec_ltok_flag = ["--decoder-langtok"] + with tempfile.TemporaryDirectory( + "test_translation_multi_simple_epoch_dict" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir, extra_flags=[]) train_translation_model( data_dir, arch="transformer", @@ -536,11 +595,17 @@ def test_transformer_pointer_generator(self): "0", ], run_validation=True, - extra_valid_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], + extra_valid_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], ) generate_main( data_dir, - extra_flags=["--user-dir", "examples/pointer_generator/pointer_generator_src"], + extra_flags=[ + "--user-dir", + "examples/pointer_generator/pointer_generator_src", + ], ) def test_lightconv(self): @@ -842,6 +907,38 @@ def test_alignment_full_context(self): ) generate_main(data_dir) + def test_transformer_layerdrop(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_layerdrop") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--encoder-layerdrop", + "0.01", + "--decoder-layerdrop", + "0.01", + ], + ) + generate_main(data_dir) + generate_main( + data_dir, + [ + "--model-overrides", + "{'encoder_layers_to_keep':'0,2','decoder_layers_to_keep':'1'}", + ], + ) + class TestStories(unittest.TestCase): def setUp(self): @@ -951,6 +1048,39 @@ def test_transformer_lm(self): run_validation=True, ) eval_lm_main(data_dir) + eval_lm_main(data_dir, extra_flags=["--context-window", "25"]) + generate_main( + data_dir, + [ + "--task", + "language_modeling", + "--sample-break-mode", + "eos", + "--tokens-per-sample", + "500", + ], + ) + + def test_transformer_lm_with_adaptive_softmax(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory( + "test_transformer_lm_with_adaptive_softmax" + ) as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + train_language_model( + data_dir, + "transformer_lm", + [ + "--add-bos-token", + "--criterion", + "adaptive_loss", + "--adaptive-softmax-cutoff", + "5,10,15", + ], + run_validation=True, + ) + eval_lm_main(data_dir) generate_main( data_dir, [ @@ -1035,6 +1165,36 @@ def test_lstm_lm_residuals(self): ], ) + @unittest.skipIf(not has_hf_transformers, "skip test if transformers is missing") + def test_transformer_xl_bptt_lm(self): + with contextlib.redirect_stdout(StringIO()): + with tempfile.TemporaryDirectory("test_transformer_xl_bptt_lm") as data_dir: + create_dummy_data(data_dir) + preprocess_lm_data(data_dir) + task_flags = [ + "--user-dir", + "examples/truncated_bptt", + "--task", + "truncated_bptt_lm", + "--batch-size", + "2", + "--tokens-per-sample", + "50", + ] + train_language_model( + data_dir=data_dir, + arch="transformer_xl", + extra_flags=task_flags + + [ + "--n-layer", + "2", + ], + task="truncated_bptt_lm", + run_validation=True, + extra_valid_flags=task_flags, + ) + eval_lm_main(data_dir, extra_flags=task_flags) + class TestMaskedLanguageModel(unittest.TestCase): def setUp(self): @@ -1280,7 +1440,7 @@ def train_legacy_masked_language_model(data_dir, arch, extra_args=()): "0.5", "--lr", "0.0001", - "--min-lr", + "--stop-min-lr", "1e-09", # dropout, attention args "--dropout", @@ -1363,6 +1523,65 @@ def test_optimizers(self): generate_main(data_dir) +def read_last_log_entry( + logs: List[logging.LogRecord], logger_name: str +) -> Dict[str, float]: + for x in reversed(logs): + if x.name == logger_name: + return json.loads(x.message) + raise ValueError(f"No entries from {logger_name} found in captured logs") + + +class TestActivationCheckpointing(unittest.TestCase): + def test_activation_checkpointing_does_not_change_metrics(self): + """--checkpoint-activations should not change loss""" + base_flags = [ + "--encoder-layers", + "2", + "--decoder-layers", + "2", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--restore-file", + "x.pt", + "--log-format", + "json", + "--log-interval", + "1", + "--max-update", + "2", + ] + + def _train(extra_flags): + with self.assertLogs() as logs: + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + base_flags + extra_flags, + run_validation=True, + extra_valid_flags=["--log-format", "json"], + ) + return logs.records + + with tempfile.TemporaryDirectory("test_transformer_with_act_cpt") as data_dir: + + create_dummy_data(data_dir, num_examples=20) + preprocess_translation_data(data_dir) + ckpt_logs = _train(["--checkpoint-activations"]) + baseline_logs = _train([]) + assert len(baseline_logs) == len(ckpt_logs) + + baseline_train_stats = read_last_log_entry(baseline_logs, "train") + ckpt_train_stats = read_last_log_entry(ckpt_logs, "train") + assert baseline_train_stats["train_loss"] == ckpt_train_stats["train_loss"] + + baseline_valid_stats = read_last_log_entry(baseline_logs, "valid") + ckpt_valid_stats = read_last_log_entry(ckpt_logs, "valid") + assert baseline_valid_stats["valid_loss"] == ckpt_valid_stats["valid_loss"] + + def create_dummy_roberta_head_data( data_dir, num_examples=100, maxlen=10, num_classes=2, regression=False ): @@ -1478,13 +1697,20 @@ def train_roberta_head(data_dir, arch, num_classes=2, extra_flags=None): train.main(train_args) -def train_language_model(data_dir, arch, extra_flags=None, run_validation=False): +def train_language_model( + data_dir, + arch, + extra_flags=None, + run_validation=False, + extra_valid_flags=None, + task="language_modeling", +): train_parser = options.get_training_parser() train_args = options.parse_args_and_arch( train_parser, [ "--task", - "language_modeling", + task, data_dir, "--arch", arch, @@ -1492,10 +1718,6 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) "adam", "--lr", "0.0001", - "--criterion", - "adaptive_loss", - "--adaptive-softmax-cutoff", - "5,10,15", "--max-tokens", "500", "--tokens-per-sample", @@ -1523,7 +1745,7 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) validate_parser, [ "--task", - "language_modeling", + task, data_dir, "--path", os.path.join(data_dir, "checkpoint_last.pt"), @@ -1534,12 +1756,13 @@ def train_language_model(data_dir, arch, extra_flags=None, run_validation=False) "--no-progress-bar", "--num-workers", "0", - ], + ] + + (extra_valid_flags or []), ) validate.main(validate_args) -def eval_lm_main(data_dir): +def eval_lm_main(data_dir, extra_flags=None): eval_lm_parser = options.get_eval_lm_parser() eval_lm_args = options.parse_args_and_arch( eval_lm_parser, @@ -1550,75 +1773,10 @@ def eval_lm_main(data_dir): "--no-progress-bar", "--num-workers", "0", - ], - ) - eval_lm.main(eval_lm_args) - - -def train_masked_language_model(data_dir, arch, extra_args=()): - train_parser = options.get_training_parser() - # TODO: langs should be in and out right? - train_args = options.parse_args_and_arch( - train_parser, - [ - "--task", - "cross_lingual_lm", - data_dir, - "--arch", - arch, - # Optimizer args - "--optimizer", - "adam", - "--lr-scheduler", - "reduce_lr_on_plateau", - "--lr-shrink", - "0.5", - "--lr", - "0.0001", - "--min-lr", - "1e-09", - # dropout, attention args - "--dropout", - "0.1", - "--attention-dropout", - "0.1", - # MLM args - "--criterion", - "masked_lm_loss", - "--masked-lm-only", - "--monolingual-langs", - "in,out", - "--num-segment", - "5", - # Transformer args: use a small transformer model for fast training - "--encoder-layers", - "1", - "--encoder-embed-dim", - "32", - "--encoder-attention-heads", - "1", - "--encoder-ffn-embed-dim", - "32", - # Other training args - "--max-tokens", - "500", - "--tokens-per-sample", - "500", - "--save-dir", - data_dir, - "--max-epoch", - "1", - "--no-progress-bar", - "--distributed-world-size", - "1", - "--dataset-impl", - "raw", - "--num-workers", - "0", ] - + list(extra_args), + + (extra_flags or []), ) - train.main(train_args) + eval_lm.main(eval_lm_args) if __name__ == "__main__": diff --git a/tests/test_checkpoint_utils.py b/tests/test_checkpoint_utils.py new file mode 100644 index 0000000000..e3c685deec --- /dev/null +++ b/tests/test_checkpoint_utils.py @@ -0,0 +1,89 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import logging +import os +import tempfile +import unittest +from io import StringIO + +from fairseq import checkpoint_utils + +from tests.utils import ( + create_dummy_data, + preprocess_translation_data, + train_translation_model, +) + + +class TestCheckpointUtils(unittest.TestCase): + def setUp(self): + logging.disable(logging.CRITICAL) + + def tearDown(self): + logging.disable(logging.NOTSET) + + @contextlib.contextmanager + def _train_transformer(self, seed, extra_args=None): + if extra_args is None: + extra_args = [] + with tempfile.TemporaryDirectory(f"_train_transformer_seed{seed}") as data_dir: + create_dummy_data(data_dir) + preprocess_translation_data(data_dir) + train_translation_model( + data_dir, + "transformer_iwslt_de_en", + [ + "--encoder-layers", + "3", + "--decoder-layers", + "3", + "--encoder-embed-dim", + "8", + "--decoder-embed-dim", + "8", + "--seed", + str(seed), + ] + + extra_args, + ) + yield os.path.join(data_dir, "checkpoint_last.pt") + + def test_load_model_ensemble_and_task(self): + with contextlib.redirect_stdout(StringIO()): + with self._train_transformer(seed=123) as model1: + with self._train_transformer(seed=456) as model2: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model1, model2] + ) + self.assertEqual(len(ensemble), 2) + + # after Transformer has been migrated to Hydra, this will probably + # become cfg.common.seed + self.assertEqual(ensemble[0].args.seed, 123) + self.assertEqual(ensemble[1].args.seed, 456) + + # the task from the first model should be returned + self.assertEqual(task.args.seed, 123) + + def test_prune_state_dict(self): + with contextlib.redirect_stdout(StringIO()): + extra_args = ["--encoder-layerdrop", "0.01", "--decoder-layerdrop", "0.01"] + with self._train_transformer(seed=1, extra_args=extra_args) as model: + ensemble, cfg, task = checkpoint_utils.load_model_ensemble_and_task( + filenames=[model], + arg_overrides={ + "encoder_layers_to_keep": "0,2", + "decoder_layers_to_keep": "1", + }, + ) + self.assertEqual(len(ensemble), 1) + self.assertEqual(len(ensemble[0].encoder.layers), 2) + self.assertEqual(len(ensemble[0].decoder.layers), 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_iopath.py b/tests/test_iopath.py new file mode 100644 index 0000000000..908261a661 --- /dev/null +++ b/tests/test_iopath.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest import mock + + +class TestIOPath(unittest.TestCase): + + def test_no_iopath(self): + from .test_reproducibility import TestReproducibility + + with mock.patch.dict("sys.modules", {"iopath": None}): + # reuse reproducibility tests, which are e2e tests that should cover + # most checkpoint related functionality + TestReproducibility._test_reproducibility(self, "test_reproducibility") + + def test_no_supports_rename(self): + from .test_reproducibility import TestReproducibility + + with mock.patch("fairseq.file_io.PathManager.supports_rename") as mock_fn: + mock_fn.return_value = False + TestReproducibility._test_reproducibility(self, "test_reproducibility") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_lm_context_window.py b/tests/test_lm_context_window.py new file mode 100644 index 0000000000..7415e86abd --- /dev/null +++ b/tests/test_lm_context_window.py @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from fairseq.data import MonolingualDataset +from fairseq.tasks.language_modeling import LanguageModelingTask, LanguageModelingConfig +from tests import utils as test_utils + + +class TestLMContextWindow(unittest.TestCase): + + def test_eval_dataloader(self): + dictionary = test_utils.dummy_dictionary(10) + assert len(dictionary) == 14 # 4 extra special symbols + assert dictionary.pad() == 1 + + dataset = test_utils.TestDataset([ + torch.tensor([4, 5, 6, 7], dtype=torch.long), + torch.tensor([8, 9, 10, 11], dtype=torch.long), + torch.tensor([12, 13], dtype=torch.long), + ]) + dataset = MonolingualDataset(dataset, sizes=[4, 4, 2], src_vocab=dictionary) + + config = LanguageModelingConfig(tokens_per_sample=4) + task = LanguageModelingTask(config, dictionary) + + eval_dataloader = task.eval_lm_dataloader( + dataset=dataset, + batch_size=1, + context_window=2, + ) + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [4, 5, 6, 7, 1, 1] + assert batch["target"][0].tolist() == [4, 5, 6, 7, 1, 1] + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [6, 7, 8, 9, 10, 11] + assert batch["target"][0].tolist() == [1, 1, 8, 9, 10, 11] + + batch = next(eval_dataloader) + assert batch["net_input"]["src_tokens"][0].tolist() == [10, 11, 12, 13] + assert batch["target"][0].tolist() == [1, 1, 12, 13] + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_reproducibility.py b/tests/test_reproducibility.py index 517e23c39e..405d545593 100644 --- a/tests/test_reproducibility.py +++ b/tests/test_reproducibility.py @@ -26,7 +26,7 @@ def _test_reproducibility( ): def get_last_log_stats_containing_string(log_records, search_string): for log_record in logs.records[::-1]: - if search_string in log_record.msg: + if isinstance(log_record.msg, str) and search_string in log_record.msg: return json.loads(log_record.msg) if extra_flags is None: diff --git a/tests/utils.py b/tests/utils.py index a145aa587d..178df5763e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -345,18 +345,20 @@ def train_translation_model( validate.main(validate_args) -def generate_main(data_dir, extra_flags=None): +def generate_main(data_dir, extra_flags=None, path=None): if extra_flags is None: extra_flags = [ "--print-alignment", ] + if path is None: + path = os.path.join(data_dir, "checkpoint_last.pt") generate_parser = options.get_generation_parser() generate_args = options.parse_args_and_arch( generate_parser, [ data_dir, "--path", - os.path.join(data_dir, "checkpoint_last.pt"), + path, "--beam", "3", "--batch-size", diff --git a/uwr_related/configs/scribblelens_base.yaml b/uwr_related/configs/scribblelens_base.yaml new file mode 100644 index 0000000000..3c860c3025 --- /dev/null +++ b/uwr_related/configs/scribblelens_base.yaml @@ -0,0 +1,78 @@ +# @package _group_ + +common: + fp16: false + log_format: json + log_interval: 20 + tensorboard_logdir: tensorboard + +checkpoint: + keep_last_epochs: 3 + +task: + _name: scribblelens + data: /pio/scratch/2/jch/wav2vec/data/scribblelens + vocab_path: '${env:PWD}/fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json' + enable_padding: True + pad_to_multiples_of: 4 + max_sample_size: 250000 + min_sample_size: 32000 + normalize: false + labels: True + +dataset: + num_workers: 0 + max_tokens: 10000 + skip_invalid_size_inputs_valid_test: true + valid_subset: test + +distributed_training: + distributed_world_size: 1 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 10] + +optimization: + max_update: 400000 + lr: [0.0003] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 20000 + +model: + _name: wav2vec2_scribblelens + conv_feature_layers: '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' + quantize_targets: true + final_dim: 256 + encoder_embed_dim: 768 + + encoder_layerdrop: 0.05 + dropout: 0.1 + attention_dropout: 0.1 + dropout_input: 0.1 + dropout_features: 0.1 + feature_grad_mult: 0.1 + + latent_vars: 320 + latent_groups: 2 + latent_temp: [2,0.5,0.999995] + + probe_defs: + post_extract_proj_mlp: + cls: Conv1DProbe + module_name: post_extract_proj + layer_dims: [768, 512, 73] + kernel_size: 3 + output_selector: 'lambda x: {"output": x.transpose(1, 2)}' + target_selector: 'lambda x: {"target":x["alignments"], "padding_mask": x["net_input"].get("padding_mask")}' diff --git a/uwr_related/configs/scribblelens_base_finetune.yaml b/uwr_related/configs/scribblelens_base_finetune.yaml new file mode 100644 index 0000000000..f2cc52cd48 --- /dev/null +++ b/uwr_related/configs/scribblelens_base_finetune.yaml @@ -0,0 +1,67 @@ +# @package _group_ + +common: + fp16: false + log_format: json + log_interval: 20 + tensorboard_logdir: tensorboard + +checkpoint: + keep_last_epochs: 3 + best_checkpoint_metric: wer + +task: + _name: scribblelens + data: /pio/scratch/2/jch/wav2vec/data/scribblelens + vocab_path: '${env:PWD}/fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json' + enable_padding: True + pad_to_multiples_of: 4 + max_sample_size: 250000 + min_sample_size: 32000 + normalize: false + labels: True + +dataset: + num_workers: 0 + max_tokens: 10000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: test + +distributed_training: + distributed_world_size: 1 + ddp_backend: no_c10d + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + diff --git a/uwr_related/configs/scribblelens_base_onlyprobetrain.yaml b/uwr_related/configs/scribblelens_base_onlyprobetrain.yaml new file mode 100644 index 0000000000..bb83319232 --- /dev/null +++ b/uwr_related/configs/scribblelens_base_onlyprobetrain.yaml @@ -0,0 +1,76 @@ +# @package _group_ + +common: + fp16: false + log_format: json + log_interval: 20 + tensorboard_logdir: tensorboard + +checkpoint: + keep_last_epochs: 3 + +task: + _name: scribblelens + data: /pio/scratch/2/jch/wav2vec/data/scribblelens + vocab_path: '${env:PWD}/fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json' + enable_padding: True + pad_to_multiples_of: 4 + max_sample_size: 250000 + min_sample_size: 32000 + normalize: false + labels: True + +dataset: + num_workers: 0 + max_tokens: 10000 + skip_invalid_size_inputs_valid_test: true + valid_subset: test + +distributed_training: + distributed_world_size: 1 + ddp_backend: no_c10d + +criterion: + _name: probes + +optimization: + max_update: 400000 + lr: [0.0003] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 20000 + +model: + _name: wav2vec2_scribblelens + conv_feature_layers: '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' + quantize_targets: true + final_dim: 256 + encoder_embed_dim: 768 + + encoder_layerdrop: 0.05 + dropout: 0.1 + attention_dropout: 0.1 + dropout_input: 0.1 + dropout_features: 0.1 + feature_grad_mult: 0.1 + + latent_vars: 320 + latent_groups: 2 + latent_temp: [2,0.5,0.999995] + + probe_defs: + post_extract_proj_mlp: + cls: Conv1DProbe + module_name: post_extract_proj + layer_dims: [768, 512, 73] + kernel_size: 3 + backprop_to_main: true + output_selector: 'lambda x: {"output": x.transpose(1, 2)}' + target_selector: 'lambda x: {"target":x["alignments"], "padding_mask": x["net_input"].get("padding_mask")}' diff --git a/uwr_related/configs/scribblelens_base_probetrain.yaml b/uwr_related/configs/scribblelens_base_probetrain.yaml new file mode 100644 index 0000000000..3d1261adb7 --- /dev/null +++ b/uwr_related/configs/scribblelens_base_probetrain.yaml @@ -0,0 +1,79 @@ +# @package _group_ + +common: + fp16: false + log_format: json + log_interval: 20 + tensorboard_logdir: tensorboard + +checkpoint: + keep_last_epochs: 3 + +task: + _name: scribblelens + data: /pio/scratch/2/jch/wav2vec/data/scribblelens + vocab_path: '${env:PWD}/fairseq/data/handwriting/tasman.alphabet.plus.space.mode5.json' + enable_padding: True + pad_to_multiples_of: 4 + max_sample_size: 250000 + min_sample_size: 32000 + normalize: false + labels: True + +dataset: + num_workers: 0 + max_tokens: 10000 + skip_invalid_size_inputs_valid_test: true + valid_subset: test + +distributed_training: + distributed_world_size: 1 + ddp_backend: no_c10d + +criterion: + _name: wav2vec + infonce: true + log_keys: ["prob_perplexity","code_perplexity","temp"] + loss_weights: [0.1, 10] + +optimization: + max_update: 400000 + lr: [0.0003] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-06 + weight_decay: 0.01 + +lr_scheduler: + _name: polynomial_decay + warmup_updates: 20000 + +model: + _name: wav2vec2_scribblelens + conv_feature_layers: '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' + quantize_targets: true + final_dim: 256 + encoder_embed_dim: 768 + + encoder_layerdrop: 0.05 + dropout: 0.1 + attention_dropout: 0.1 + dropout_input: 0.1 + dropout_features: 0.1 + feature_grad_mult: 0.1 + + latent_vars: 320 + latent_groups: 2 + latent_temp: [2,0.5,0.999995] + + probe_defs: + post_extract_proj_mlp: + cls: Conv1DProbe + module_name: post_extract_proj + layer_dims: [768, 512, 73] + kernel_size: 3 + backprop_to_main: true + output_selector: 'lambda x: {"output": x.transpose(1, 2)}' + target_selector: 'lambda x: {"target":x["alignments"], "padding_mask": x["net_input"].get("padding_mask")}' diff --git a/uwr_related/ctc_eval_audio.sh b/uwr_related/ctc_eval_audio.sh new file mode 100644 index 0000000000..e992af1f35 --- /dev/null +++ b/uwr_related/ctc_eval_audio.sh @@ -0,0 +1,12 @@ +# kenlm bib --w2l-decoder kenlm +# raw numbers --w2l-decoder viterbi +# transformer language model --w2l-decoder fairseqlm + +python examples/speech_recognition/infer.py \ +/checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw \ +--task audio_pretraining \ +--nbest 1 --path /path/to/model --gen-subset dev_other \ +--results-path /path/to/save/results/for/sclite --w2l-decoder viterbi \ +--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 \ +--criterion ctc --labels ltr --max-tokens 4000000 \ +--post-process letter \ No newline at end of file diff --git a/uwr_related/ctc_fine_tune_audio.sh b/uwr_related/ctc_fine_tune_audio.sh new file mode 100644 index 0000000000..08155e50a6 --- /dev/null +++ b/uwr_related/ctc_fine_tune_audio.sh @@ -0,0 +1,16 @@ +# omit the --wer-args for no evaluation +# --reset-optiimzer required some times. add it when asked + +python train.py --distributed-world-size 1 ../data/LibriSpeech/ \ + --save-dir ../try1_ctc \ + `# --wer-args '("/path/to/lm/4-gram.bin","/path/to/lexicon",2,-1)'` \ + --post-process letter --valid-subset valid --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 0 \ + --max-update 80000 --sentence-avg --task audio_pretraining --arch wav2vec_ctc \ + --w2v-path ../try1/checkpoint19.pt \ + --labels ltr --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.5 --layerdrop 0.1 \ + --mask-channel-selection static --mask-channel-other 0 --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \ + --feature-grad-mult 0.0 --freeze-finetune-updates 10000 --validate-after-updates 10000 --optimizer adam \ + --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 \ + --decay-steps 40000 --final-lr-scale 0.05 --final-dropout 0.0 --dropout 0.0 --activation-dropout 0.1 --criterion ctc \ + --attention-dropout 0.0 --max-tokens 1280000 --seed 2337 --log-format json --log-interval 500 --ddp-backend no_c10d \ + `#--reset-optimizer` \ No newline at end of file diff --git a/uwr_related/ctc_fine_tune_scribble.sh b/uwr_related/ctc_fine_tune_scribble.sh new file mode 100644 index 0000000000..258ab6fdaa --- /dev/null +++ b/uwr_related/ctc_fine_tune_scribble.sh @@ -0,0 +1,3 @@ +python hydra-train.py hydra.run.dir=/pio/lscratch/1/jch/fairseq/try1/finetune \ + model.w2v_path=/pio/lscratch/1/jch/fairseq/try1/checkpoints/checkpoint_last.pt \ + --config-dir uwr_related/configs --config-name scribblelens_base_finetune diff --git a/uwr_related/experiments/jch/scrib.sh b/uwr_related/experiments/jch/scrib.sh index 0f998977a9..31b6c8d6e6 100644 --- a/uwr_related/experiments/jch/scrib.sh +++ b/uwr_related/experiments/jch/scrib.sh @@ -1,21 +1 @@ -python train.py --distributed-world-size 1 --update-freq 2 \ - /pio/scratch/2/jch/wav2vec/data/scribblelens \ - --save-dir /pio/lscratch/1/jch/fairseq/try_sl2 --num-workers 0 \ - --keep-last-epochs 3 \ - --tensorboard-logdir /pio/scratch/2/jch/wav2vec/runs/try_sl2 --log-format simple \ - --task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \ - --valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \ - --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ - --conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \ - --final-dim 256 \ - --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \ - --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \ - --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \ - --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ - --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ - --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ - --num-negatives 100 --cross-sample-negatives 0 \ - `#--max-sample-size 250000 --min-sample-size 32000` \ - --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \ - --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ - --enable-padding # crashes without that, needs to make all lines same-size \ No newline at end of file +python hydra-train.py hydra.run.dir=/pio/lscratch/1/jch/fairseq/try1 --config-dir uwr_related/configs --config-name scribblelens_base \ No newline at end of file diff --git a/uwr_related/readme.md b/uwr_related/readme.md index cd7645531f..c2ba9248c8 100644 --- a/uwr_related/readme.md +++ b/uwr_related/readme.md @@ -42,6 +42,9 @@ # Fetch part of librispeech wget http://www.openslr.org/resources/12/dev-clean.tar.gz tar zxvf dev-clean.tar.gz + # for CTC (really needed?) + # wget http://www.openslr.org/resources/12/dev-other.tar.gz + # tar zxvf dev-other.tar.gz # Fetch scribblelens mkdir scribblelens @@ -60,3 +63,61 @@ - Scribblelens: TODO + +8. CTC: + - Audio: + 1. Dependencies (needed for CTC evaluation - https://github.com/facebookresearch/wav2letter/wiki/Building-Python-bindings): + + ``` + conda install -c conda-forge fftw + conda install -c conda-forge cmake + conda install -c conda-forge openblas + + cd .. + git clone git@github.com:kpu/kenlm + cd kenlm + mkdir -p build + cd build + cmake .. + make -j 4 + export KENLM_ROOT_DIR=/path/to/kenlm + + cd ../.. + git clone git@github.com:facebookresearch/wav2letter.git + cd wav2letter + git checkout tlikhomanenko-patch-1 + cd bindings/python/ + pip install -e . + ``` + + 2. Download dictionary and generate vocab for train split: + + ``` + wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ../data/LibriSpeech + python examples/wav2vec/libri_labels.py ../data/LibriSpeech/valid.tsv --output-dir ../data/LibriSpeech --output-name valid + ``` + + 3. Get pretrained language model (optional for audio - check `--w2l-decoder` flag in CTC evaluation script): + TODO + ``` + cd .. + mkdir pretrained_models + cd pretrained_models + wget https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.pt + # ??? Be sure to upper-case the language model vocab after downloading it. ??? + wget https://dl.fbaipublicfiles.com/wav2letter/sota/2019/lm/lm_librispeech_word_transformer.dict + + ``` + + 4. Fine-tune a pretrained model with CTC: + + From the top of `fairseq` repo call: `uwr_related/bash ctc_fine_tune_audio.sh` + + 5. Evaluate a CTC model: + From the top of `fairseq` repo call: `uwr_related/bash ctc_eval_audio.sh` + + - Scribblelens: + + ``` + bash uwr_related/ctc_fine_tune_scribble.sh + ``` \ No newline at end of file diff --git a/uwr_related/test_cmd_audio.sh b/uwr_related/test_cmd_audio.sh index 27d763babc..6c6f24d67f 100644 --- a/uwr_related/test_cmd_audio.sh +++ b/uwr_related/test_cmd_audio.sh @@ -15,5 +15,5 @@ python train.py --distributed-world-size 1 --update-freq 2 ../data/LibriSpeech/ --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 250000 --min-sample-size 32000 \ - --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1400000 --max-update 400000 \ + --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 500000 --max-update 400000 \ --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d diff --git a/uwr_related/test_cmd_scribble.sh b/uwr_related/test_cmd_scribble.sh old mode 100644 new mode 100755 index d797baa819..c2dcf61bf7 --- a/uwr_related/test_cmd_scribble.sh +++ b/uwr_related/test_cmd_scribble.sh @@ -1,58 +1,3 @@ -# Changes: -# - not distributed -# - no fp16 (data loader randomly dying) -# - using 0 workers (somehow data loading dies when >0 and multiprocessing) - -# python train.py --distributed-world-size 1 --update-freq 2 /pio/scratch/2/mstyp/wav2vec/data/LibriSpeech \ -# --save-dir /pio/scratch/2/mstyp/wav2vec/try1 --num-workers 0 \ -# --task audio_pretraining --criterion wav2vec --arch wav2vec2 \ -# --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ -# --conv-feature-layers '[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2' --final-dim 256 \ -# --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \ -# --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \ -# --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \ -# --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ -# --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ -# --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ -# --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 250000 --min-sample-size 32000 \ -# --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1400000 --max-update 400000 \ -# --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d - -# [!, need to do this setup first] scp <...>/DistSup/egs/scribblelens/tasman.alphabet.plus.space.mode5.json

-# python train.py --distributed-world-size 1 --update-freq 2 \ -# /home/jch/scratch/wav2vec/data/scribblelens \ -# --save-dir ../try_sl1 --num-workers 0 \ -# --task scribblelens --criterion wav2vec --arch wav2vec2 \ -# --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ -# --conv-feature-layers '[(512, (32,10), 5)] + [(512, (1,3), 2)] * 4 + [(512,(1,2),2)] * 2' --final-dim 256 \ -# --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \ -# --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \ -# --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \ -# --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ -# --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ -# --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ -# --num-negatives 100 --cross-sample-negatives 0 --max-sample-size 250000 --min-sample-size 32000 \ -# --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 1400000 --max-update 400000 \ -# --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ -# --enable_padding # crashes without that, needs to make all lines same-size - - -python train.py --distributed-world-size 1 --update-freq 2 \ - /pio/scratch/1/i283340/MGR/NewSetup/DistSup/data \ - --save-dir ../try_sl1 --num-workers 0 \ - --task scribblelens --criterion wav2vec --arch wav2vec2_scribblelens \ - --valid-subset test --pad-to-multiples-of 4 `#--max-sample-size 256` \ - --log-keys '["prob_perplexity","code_perplexity","temp"]' --quantize-targets --extractor-mode default \ - --conv-feature-layers '[(64, (3, 3), (1, 2), (1, 1)), (128, (5, 5), (2, 2), (2, 2)), (256, (3,3), (1, 1), (1, 1)), (256, (3,3), (1, 2), (1, 1)), (512, (3,3), (1, 1), (1, 1)), (512, (3,3), (1, 2), (1, 1)), (512, (3,2), (2, 1), (1, 0))]' \ - --final-dim 256 \ - --latent-vars 320 --latent-groups 2 --latent-temp '(2,0.5,0.999995)' --infonce \ - --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-06 --lr-scheduler polynomial_decay \ - --total-num-update 400000 --lr 0.0005 --warmup-updates 32000 \ - --mask-length 10 --mask-prob 0.65 --mask-selection static --mask-other 0 \ - --encoder-layerdrop 0.05 --dropout-input 0.1 --dropout-features 0.1 --feature-grad-mult 0.1 \ - --loss-weights '[0.1, 10]' --conv-pos 128 --conv-pos-groups 16 \ - --num-negatives 100 --cross-sample-negatives 0 \ - `#--max-sample-size 250000 --min-sample-size 32000` \ - --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 --max-tokens 10000 --max-update 400000 \ - --skip-invalid-size-inputs-valid-test --ddp-backend no_c10d \ - --enable-padding # crashes without that, needs to make all lines same-size \ No newline at end of file +python hydra-train.py \ + hydra.run.dir=/pio/lscratch/1/jch/fairseq/try1 `# optional, if unspecified will save to outputs/date folder`\ + --config-dir uwr_related/configs --config-name scribblelens_base