Commit

update configuration of chatglm and add support in coati
yingliu-hpc committed Aug 14, 2023
1 parent ff83679 commit 2f492fa
Showing 7 changed files with 2,070 additions and 9 deletions.
41 changes: 35 additions & 6 deletions applications/Chat/coati/dataset/sft_dataset.py
@@ -19,7 +19,7 @@
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
from colossalai.logging import get_dist_logger

from .utils import is_rank_0, jload
@@ -71,6 +71,28 @@ def _preprocess(sources: Sequence[str],
    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]


def _preprocess_chatglm(sources: Sequence[str],
                        targets: Sequence[str],
                        tokenizer: PreTrainedTokenizer,
                        max_length: int,
                        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Preprocess the data by tokenizing."""
    sequences = [s + "[gMASK]" + tokenizer.bos_token + t for s, t in zip(sources, targets)]
    sequences_token = tokenizer(sequences,
                                max_length=max_length,
                                padding="max_length",
                                truncation=True,
                                return_tensors="pt")

    labels = copy.deepcopy(sequences_token["input_ids"])
    for i in range(labels.shape[0]):
        assert tokenizer.padding_side == "left", "chatglm's tokenizer should be padded at left"
        context_len = torch.nonzero(sequences_token["input_ids"][i] == tokenizer.bos_token_id).squeeze()[0]
        labels[i][:context_len] = IGNORE_INDEX

    return sequences_token["input_ids"], labels, sequences_token["attention_mask"]


class SFTDataset(Dataset):
    """
    Dataset for sft model
@@ -94,9 +116,12 @@ def __init__(self,
            data["completion"] + tokenizer.eos_token
            for data in tqdm(dataset, disable=not is_rank_0())
        ]

-        self.input_ids, self.labels, self.attention_mask = \
-            _preprocess(sources, targets, tokenizer, max_length)
+        if isinstance(tokenizer, ChatGLMTokenizer):
+            self.input_ids, self.labels, self.attention_mask = \
+                _preprocess_chatglm(sources, targets, tokenizer, max_length)
+        else:
+            self.input_ids, self.labels, self.attention_mask = \
+                _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
@@ -137,8 +162,12 @@ def __init__(self,
        ]

        logger.info("Tokenizing inputs... This may take some time...")
-        self.input_ids, self.labels, self.attention_mask = \
-            _preprocess(sources, targets, tokenizer, max_length)
+        if isinstance(tokenizer, ChatGLMTokenizer):
+            self.input_ids, self.labels, self.attention_mask = \
+                _preprocess_chatglm(sources, targets, tokenizer, max_length)
+        else:
+            self.input_ids, self.labels, self.attention_mask = \
+                _preprocess(sources, targets, tokenizer, max_length)

    def __len__(self):
        length = self.input_ids.shape[0]
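Taken together, the sft_dataset.py changes route any ChatGLMTokenizer through a dedicated preprocessing path: each training sequence is built as prompt + "[gMASK]" + the tokenizer's bos token + completion, and every label position before the bos token is set to IGNORE_INDEX so only the completion contributes to the loss. A minimal sketch of calling the new helper directly is shown below; the checkpoint name and example strings are illustrative assumptions, not part of this commit.

# Hypothetical usage sketch; "THUDM/chatglm-6b" and the example texts are placeholders.
from coati.dataset.sft_dataset import _preprocess_chatglm
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer

tokenizer = ChatGLMTokenizer.from_pretrained("THUDM/chatglm-6b")
tokenizer.padding_side = "left"  # _preprocess_chatglm asserts left padding

sources = ["Human: What is ColossalAI?\nAssistant: "]
targets = ["A unified deep learning system for large-scale training." + tokenizer.eos_token]

input_ids, labels, attention_mask = _preprocess_chatglm(sources, targets, tokenizer, max_length=256)
# Prompt tokens (everything before the bos token) are IGNORE_INDEX in labels,
# so the supervised loss covers only the completion.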
3 changes: 3 additions & 0 deletions applications/Chat/coati/models/chatglm/__init__.py
@@ -0,0 +1,3 @@
from .chatglm_actor import ChatGLMActor

__all__ = ['ChatGLMActor']
34 changes: 34 additions & 0 deletions applications/Chat/coati/models/chatglm/chatglm_actor.py
@@ -0,0 +1,34 @@
from typing import Optional

import torch
from .configuration_chatglm import ChatGLMConfig
from .modeling_chatglm import ChatGLMForConditionalGeneration

from ..base import Actor


class ChatGLMActor(Actor):
    """
    ChatGLM Actor model.

    Args:
        pretrained (str): Pretrained model name or path.
        config (ChatGLMConfig): Model config.
        checkpoint (bool): Enable gradient checkpointing.

    do not support lora for now.
    """

    def __init__(self,
                 pretrained: str = None,
                 config: Optional[ChatGLMConfig] = None,
                 checkpoint: bool = False) -> None:
        if pretrained is not None:
            model = ChatGLMForConditionalGeneration.from_pretrained(pretrained)
        elif config is not None:
            model = ChatGLMForConditionalGeneration(config)
        else:
            model = ChatGLMForConditionalGeneration(ChatGLMConfig())
        if checkpoint:
            model.gradient_checkpointing_enable()
        super().__init__(model, lora_rank=0, lora_train_bias='none')
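The new actor wraps ChatGLMForConditionalGeneration behind coati's Actor interface: it builds the model from a pretrained path, from an explicit ChatGLMConfig, or from the default config, optionally enables gradient checkpointing, and fixes lora_rank to 0 since LoRA is not supported for this model yet. A minimal construction sketch follows; the pretrained path is an illustrative assumption, not part of this commit.

# Hypothetical usage sketch; the pretrained path is a placeholder.
from coati.models.chatglm import ChatGLMActor

# Load pretrained weights and enable gradient checkpointing to save activation memory.
actor = ChatGLMActor(pretrained="THUDM/chatglm-6b", checkpoint=True)

# Alternatively, build an untrained model from the default ChatGLMConfig:
# actor = ChatGLMActor()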