PLbart support (#709)
FahadEbrahim authored Jul 13, 2024
1 parent dc53695 commit a6387c1
Showing 43 changed files with 1,061 additions and 43 deletions.
20 changes: 20 additions & 0 deletions docs/classes/models/plbart.rst
@@ -0,0 +1,20 @@
PLBART
======

The PLBART model was proposed in `Unified Pre-training for Program Understanding and Generation <https://arxiv.org/abs/2103.06333>`__ by Wasi Uddin Ahmad, Saikat Chakraborty, Baishakhi Ray, and Kai-Wei Chang.
It is a BART-like model that can be used for code summarization, code generation, and code translation tasks. The pre-trained ``plbart-base`` checkpoint was trained with a multilingual denoising objective
on Java, Python, and English.

According to the abstract:

- PLBART is a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks.
- PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding.
- PLBART learns program syntax, style (e.g., identifier naming conventions), and logical flow.


PLBartAdapterModel
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.PLBartAdapterModel
    :members:
    :inherited-members: PLBartPreTrainedModel
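The new `PLBartAdapterModel` plugs into the usual `adapters` workflow. A minimal sketch (the checkpoint name and the adapter/head names below are illustrative, not part of this commit):

```python
from transformers import AutoTokenizer
from adapters import PLBartAdapterModel

# Load a pre-trained PLBart checkpoint into the flexible-head adapter model.
tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
model = PLBartAdapterModel.from_pretrained("uclanlp/plbart-base")

# Add a bottleneck adapter and a seq2seq LM head, e.g. for code summarization.
model.add_adapter("code_sum", config="seq_bn")
model.add_seq2seq_lm_head("code_sum")

# Freeze the base model so only the adapter and head weights are trained.
model.train_adapter("code_sum")
```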
1 change: 1 addition & 0 deletions docs/index.rst
@@ -82,6 +82,7 @@ Currently, we support the PyTorch versions of all models as listed on the `Model
classes/models/llama
classes/models/mbart
classes/models/mt5
classes/models/plbart
classes/models/roberta
classes/models/t5
classes/models/vit
1 change: 1 addition & 0 deletions docs/model_overview.md
@@ -28,6 +28,7 @@ The table below further shows which model architectures support which adaptation
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [PLBart](classes/models/plbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | ✅ |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
2 changes: 1 addition & 1 deletion examples/pytorch/adapterfusion/run_fusion_glue.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on
"""Finetuning the library models for sequence classification on
GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""


1 change: 1 addition & 0 deletions examples/pytorch/dependency-parsing/preprocessing.py
@@ -3,6 +3,7 @@
Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021)
https://arxiv.org/abs/2012.15613
"""

from collections import defaultdict
from typing import List

25 changes: 16 additions & 9 deletions examples/pytorch/dependency-parsing/run_udp.py
@@ -3,6 +3,7 @@
Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021)
https://arxiv.org/abs/2012.15613
"""

import logging
import os
import sys
@@ -156,9 +157,11 @@ def main():
use_fast=model_args.use_fast,
do_lower_case=model_args.do_lower_case,
add_prefix_space=True, # Used e.g. for RoBERTa
mecab_kwargs={"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"}
if model_args.is_japanese
else None,
mecab_kwargs=(
{"mecab_option": f"-r {model_args.mecab_dir} -d {model_args.mecab_dic_dir}"}
if model_args.is_japanese
else None
),
)

# The task name (with prefix)
@@ -244,19 +247,23 @@ def main():
if adapter_args.train_adapter:
adapter_config = AdapterConfig.load(adapter_args.adapter_config, **adapter_config_kwargs)
model.load_adapter(
os.path.join(training_args.output_dir, "best_model", task_name)
if training_args.do_train
else adapter_args.load_adapter,
(
os.path.join(training_args.output_dir, "best_model", task_name)
if training_args.do_train
else adapter_args.load_adapter
),
config=adapter_config,
load_as=task_name,
**adapter_load_kwargs,
)
if adapter_args.load_lang_adapter:
lang_adapter_config = AdapterConfig.load(adapter_args.lang_adapter_config, **adapter_config_kwargs)
lang_adapter_name = model.load_adapter(
os.path.join(training_args.output_dir, "best_model", lang_adapter_name)
if training_args.do_train
else adapter_args.load_lang_adapter,
(
os.path.join(training_args.output_dir, "best_model", lang_adapter_name)
if training_args.do_train
else adapter_args.load_lang_adapter
),
config=lang_adapter_config,
load_as=lang_adapter_name,
**adapter_load_kwargs,
1 change: 1 addition & 0 deletions examples/pytorch/dependency-parsing/utils_udp.py
@@ -3,6 +3,7 @@
Credits: "How Good is Your Tokenizer? On the Monolingual Performance of Multilingual Language Models" (Rust et al., 2021)
https://arxiv.org/abs/2012.15613
"""

import collections
import logging
import os
6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_clm.py
@@ -541,9 +541,9 @@ def compute_metrics(eval_preds):
# Data collator will default to DataCollatorWithPadding, so we change it.
data_collator=default_data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
),
)

# Training
6 changes: 3 additions & 3 deletions examples/pytorch/language-modeling/run_mlm.py
@@ -557,9 +557,9 @@ def compute_metrics(eval_preds):
tokenizer=tokenizer,
data_collator=data_collator,
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
preprocess_logits_for_metrics=preprocess_logits_for_metrics
if training_args.do_eval and not is_torch_tpu_available()
else None,
preprocess_logits_for_metrics=(
preprocess_logits_for_metrics if training_args.do_eval and not is_torch_tpu_available() else None
),
)

# Training
2 changes: 1 addition & 1 deletion examples/pytorch/text-classification/run_glue.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for sequence classification on GLUE."""
"""Finetuning the library models for sequence classification on GLUE."""
# You can also adapt this script on your own text classification task. Pointers for this are left as comments.

import logging
3 changes: 1 addition & 2 deletions examples/pytorch/text-generation/run_generation.py
@@ -14,8 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)
"""
"""Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet)"""


import argparse
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -110,6 +110,7 @@
"models.llama": ["LlamaAdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
"models.plbart": ["PLBartAdapterModel"],
"models.roberta": ["RobertaAdapterModel"],
"models.t5": ["T5AdapterModel"],
"models.vit": ["ViTAdapterModel"],
@@ -217,6 +218,7 @@
from .models.llama import LlamaAdapterModel
from .models.mbart import MBartAdapterModel
from .models.mt5 import MT5AdapterModel
from .models.plbart import PLBartAdapterModel
from .models.roberta import RobertaAdapterModel
from .models.t5 import T5AdapterModel
from .models.vit import ViTAdapterModel
1 change: 1 addition & 0 deletions src/adapters/composition.py
@@ -129,6 +129,7 @@ def __init__(
"bart",
"mbart",
"mt5",
"plbart",
"gpt2",
"gptj",
"t5",
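Adding `"plbart"` to the list above registers PLBart as supporting the `Parallel` composition block. A minimal sketch of what this enables, assuming the standard composition API (adapter names are illustrative):

```python
from adapters import PLBartAdapterModel
from adapters.composition import Parallel

model = PLBartAdapterModel.from_pretrained("uclanlp/plbart-base")
model.add_adapter("summarization")
model.add_adapter("translation")

# Both adapters run side by side in a single forward pass; the input batch
# is replicated internally, once per parallel branch.
model.set_active_adapters(Parallel("summarization", "translation"))
```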
21 changes: 21 additions & 0 deletions src/adapters/head_utils.py
@@ -369,6 +369,27 @@
},
"layers": ["lm_head"],
},
# PLBART
"PLBartForSequenceClassification": {
"config": {
"head_type": "classification",
"layers": 2,
"activation_function": "tanh",
},
"layers": [
None,
"classification_head.dense",
None,
None,
"classification_head.out_proj",
],
},
"PLBartForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
# MT5
"MT5ForConditionalGeneration": {
"config": {
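These mappings let the head-conversion logic translate the static heads of `PLBartForSequenceClassification` (two dense layers with a tanh activation) and `PLBartForConditionalGeneration` (an `lm_head`) into their flexible-head equivalents when such a checkpoint is loaded into `PLBartAdapterModel`. A hedged sketch; the checkpoint name is hypothetical:

```python
from adapters import PLBartAdapterModel

# If this checkpoint was saved from PLBartForSequenceClassification, its
# classification_head.dense / classification_head.out_proj weights are
# converted into a flex head on load. (Checkpoint name is illustrative.)
model = PLBartAdapterModel.from_pretrained("my-org/plbart-defect-detection")
print(model.heads)  # the converted classification head is registered here
```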
1 change: 1 addition & 0 deletions src/adapters/heads/dependency_parsing.py
@@ -2,6 +2,7 @@
Code taken and modified from: https://github.com/Adapter-Hub/hgiyt. Credits: "How Good is Your Tokenizer? On the
Monolingual Performance of Multilingual Language Models" (Rust et al., 2021) https://arxiv.org/abs/2012.15613
"""

from dataclasses import dataclass
from typing import Optional, Tuple

8 changes: 5 additions & 3 deletions src/adapters/methods/bottleneck.py
@@ -195,9 +195,11 @@ def pad_and_concat(self, states: List[BottleneckState]) -> BottleneckState:
torch.cat([state.input_tensor for state in states], dim=0),
torch.cat([state.adapter_residual for state in states], dim=0),
states[0].layer_norm,
torch.cat([state.bottleneck_up for state in states], dim=0)
if states[0].bottleneck_up is not None
else None,
(
torch.cat([state.bottleneck_up for state in states], dim=0)
if states[0].bottleneck_up is not None
else None
),
states[-1].last,
)

8 changes: 5 additions & 3 deletions src/adapters/methods/lora.py
@@ -408,9 +408,11 @@ def repeat(self, state: LoRAState, channels: int) -> LoRAState:
def mean(self, states: List[LoRAState], weights: torch.Tensor) -> LoRAState:
return LoRAState(
states[0].layer_input,
torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
if states[0].hidden_states is not None
else None,
(
torch.mean(torch.stack([s.hidden_states for s in states], dim=0) * weights, dim=0)
if states[0].hidden_states is not None
else None
),
states[0].layer_output,
states[-1].last,
)
12 changes: 11 additions & 1 deletion src/adapters/models/__init__.py
@@ -18,6 +18,12 @@
from .gpt2.mixin_gpt2 import GPT2ModelAdapterMixin
from .gptj.mixin_gptj import GPTJMLPAdaptersMixin, GPTJModelAdapterMixin
from .llama.mixin_llama import LlamaForQuestionAnsweringAdapterMixin, LlamaModelAdapterMixin
from .plbart.mixin_plbart import (
PLBartDecoderAdaptersMixin,
PLBartDecoderWrapperAdaptersMixin,
PLBartEncoderAdaptersMixin,
PLBartModelAdaptersMixin,
)
from .t5.mixin_t5 import (
T5BlockAdaptersMixin,
T5ForCondiditionalGenerationWithHeadsMixin,
@@ -33,8 +39,8 @@
"AlbertModel": AlbertModelAdaptersMixin,
"BartEncoder": BartEncoderAdaptersMixin,
"BartDecoder": BartDecoderAdaptersMixin,
"BartModel": BartModelAdaptersMixin,
"BartDecoderWrapper": BartDecoderWrapperAdaptersMixin,
"BartModel": BartModelAdaptersMixin,
"BeitIntermediate": BeitIntermediateAdaptersMixin,
"BeitOutput": BeitOutputAdaptersMixin,
"BeitModel": BeitModelAdaptersMixin,
Expand All @@ -60,6 +66,10 @@
"MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin,
"MT5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin,
"MT5EncoderModel": T5ModelAdaptersMixin,
"PLBartEncoder": PLBartEncoderAdaptersMixin,
"PLBartDecoder": PLBartDecoderAdaptersMixin,
"PLBartModel": PLBartModelAdaptersMixin,
"PLBartDecoderWrapper": PLBartDecoderWrapperAdaptersMixin,
"GPT2Model": GPT2ModelAdapterMixin,
"GPTJMLP": GPTJMLPAdaptersMixin,
"GPTJModel": GPTJModelAdapterMixin,
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -24,6 +24,7 @@
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mt5", "MT5AdapterModel"),
("plbart", "PLBartAdapterModel"),
("roberta", "RobertaAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
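With this entry, `AutoAdapterModel` resolves PLBart checkpoints to the new class; a quick sketch:

```python
from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("uclanlp/plbart-base")
print(type(model).__name__)  # -> PLBartAdapterModel
```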
2 changes: 1 addition & 1 deletion src/adapters/models/bart/modeling_bart.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch BART model."""
"""PyTorch BART model."""
from typing import Optional, Tuple

import torch
2 changes: 1 addition & 1 deletion src/adapters/models/beit/modeling_beit.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch BEiT model."""
"""PyTorch BEiT model."""


import math
2 changes: 1 addition & 1 deletion src/adapters/models/clip/modeling_clip.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch CLIP model."""
"""PyTorch CLIP model."""


from typing import Optional, Tuple
2 changes: 1 addition & 1 deletion src/adapters/models/deberta/modeling_deberta.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeBERTa model."""
"""PyTorch DeBERTa model."""

import torch
import torch.utils.checkpoint
2 changes: 1 addition & 1 deletion src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeBERTa-v2 model."""
"""PyTorch DeBERTa-v2 model."""

import torch
import torch.utils.checkpoint
4 changes: 2 additions & 2 deletions src/adapters/models/distilbert/modeling_distilbert.py
@@ -14,8 +14,8 @@
# limitations under the License.

"""
PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
PyTorch DistilBERT model adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM) and in
part from HuggingFace PyTorch version of Google AI Bert model (https://github.com/google-research/bert)
"""


src/adapters/models/encoder_decoder/modeling_encoder_decoder.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Classes to support Encoder-Decoder architectures"""
"""Classes to support Encoder-Decoder architectures"""

from transformers.models.encoder_decoder.modeling_encoder_decoder import EncoderDecoderModel

2 changes: 1 addition & 1 deletion src/adapters/models/gptj/modeling_gptj.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch GPT-J model."""
"""PyTorch GPT-J model."""

from typing import Optional, Tuple, Union

2 changes: 1 addition & 1 deletion src/adapters/models/llama/modeling_llama.py
@@ -17,7 +17,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
"""PyTorch LLaMA model."""
import math
import warnings
from typing import Optional, Tuple
2 changes: 1 addition & 1 deletion src/adapters/models/mbart/modeling_mbart.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch MBART model."""
"""PyTorch MBART model."""
from typing import Optional, Tuple

import torch
2 changes: 1 addition & 1 deletion src/adapters/models/mt5/modeling_mt5.py
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch MT5 model."""
"""PyTorch MT5 model."""

import torch
from torch import nn