Commit
update embedding trainer (PaddlePaddle#9608)
DesmonDay committed Dec 12, 2024
1 parent 8db699a commit 1d67a39
Showing 5 changed files with 306 additions and 6 deletions.
14 changes: 8 additions & 6 deletions paddlenlp/trainer/trainer.py
@@ -1061,12 +1061,12 @@ def _inner_training_loop(
                 if dp_master_grad:
                     is_no_sync = True
 
-                if is_no_sync:
-                    # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
-                    with model.no_sync():
+                sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
+                with sync_context:
+                    if "step_control" in inspect.signature(self.training_step).parameters:
+                        tr_loss_step = self.training_step(model, inputs, step_control=step_control)
+                    else:
                         tr_loss_step = self.training_step(model, inputs)
-                else:
-                    tr_loss_step = self.training_step(model, inputs)
 
                 tr_loss += tr_loss_step
 
@@ -2195,7 +2195,9 @@ def _enable_delay_scale_loss(self):
         else:
             return False
 
-    def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+    def training_step(
+        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]], step_control=0
+    ) -> paddle.Tensor:
         """
         Perform a training step on a batch of inputs.
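For reference, the inspect.signature check introduced above is what keeps the training loop compatible with subclasses whose training_step may or may not accept step_control. A minimal sketch of that dispatch pattern; the BaseTrainer and EmbeddingLikeTrainer classes below are illustrative stand-ins, not part of this commit:

import inspect


class BaseTrainer:
    def training_step(self, model, inputs):
        return f"base step on {inputs}"


class EmbeddingLikeTrainer(BaseTrainer):
    # mirrors the new signature: an extra step_control argument with a default
    def training_step(self, model, inputs, step_control=0):
        return f"embedding step {step_control} on {inputs}"


def run_step(trainer, model, inputs, step_control):
    # the same check the training loop now performs before calling training_step
    if "step_control" in inspect.signature(trainer.training_step).parameters:
        return trainer.training_step(model, inputs, step_control=step_control)
    return trainer.training_step(model, inputs)


print(run_step(BaseTrainer(), None, "batch0", step_control=3))           # base step on batch0
print(run_step(EmbeddingLikeTrainer(), None, "batch0", step_control=3))  # embedding step 3 on batch0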
65 changes: 65 additions & 0 deletions paddlenlp/transformers/contrastive_loss.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

import paddle
import paddle.nn as nn


class SimpleContrastiveLoss(nn.Layer):
    def __init__(self, embedding_temperature: float = 0.02):
        super().__init__()
        self.embedding_temperature = embedding_temperature
        self.cross_entropy = nn.CrossEntropyLoss(reduction="mean")

    def forward(self, q_reps, p_reps):
        scores = paddle.matmul(q_reps, p_reps.transpose([1, 0]))
        scores = scores / self.embedding_temperature

        group_size = p_reps.shape[0] // q_reps.shape[0]
        batch_size = q_reps.shape[0]

        target = paddle.arange(batch_size, dtype="int64")
        target = target * group_size

        loss = self.cross_entropy(scores, target)
        return loss


class MatryoshkaContrastiveLoss(nn.Layer):
    def __init__(self, embedding_temperature: float = 0.02, embedding_matryoshka_dims: Optional[List[int]] = None):
        super().__init__()
        self.embedding_temperature = embedding_temperature
        if embedding_matryoshka_dims is None:
            self.embedding_matryoshka_dims = []
        else:
            self.embedding_matryoshka_dims = embedding_matryoshka_dims
        self.loss_fn = SimpleContrastiveLoss(embedding_temperature)

    def forward(self, q_reps, p_reps):
        if len(self.embedding_matryoshka_dims) > 0:
            loss = 0.0
            for dim in self.embedding_matryoshka_dims:
                reduced_q_reps = q_reps[:, :dim]
                reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1)

                reduced_p_reps = p_reps[:, :dim]
                reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1)

                dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps)
                loss += dim_loss
        else:
            loss = self.loss_fn(q_reps, p_reps)
        return loss
51 changes: 51 additions & 0 deletions paddlenlp/transformers/embedding_utils.py
@@ -0,0 +1,51 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed import fleet


def dist_gather_tensor_with_gradient(tensor):
    if tensor is None:
        return None

    if paddle.distributed.get_world_size() <= 1:
        return tensor

    hcg = fleet.get_hybrid_communicate_group()
    sharding_group = hcg.get_sharding_parallel_group()
    sharding_rank = sharding_group.rank
    data_group = hcg.get_data_parallel_group()
    data_rank = data_group.rank

    if sharding_group.nranks == 1 and data_group.nranks == 1:
        return tensor

    if sharding_group.nranks > 1:
        all_tensors = []
        paddle.distributed.all_gather(all_tensors, tensor.contiguous(), group=sharding_group)
        all_tensors[sharding_rank] = tensor
        all_tensors = paddle.concat(all_tensors, axis=0)
    else:
        all_tensors = tensor

    if data_group.nranks > 1:
        final_tensors = []
        paddle.distributed.all_gather(final_tensors, all_tensors.contiguous(), group=data_group)
        final_tensors[data_rank] = all_tensors
        final_tensors = paddle.concat(final_tensors, axis=0)
    else:
        final_tensors = all_tensors

    return final_tensors
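The write-back all_tensors[sharding_rank] = tensor is the important detail above: gathered copies arrive without autograd history, so each rank re-inserts its own tensor at its slot to keep gradients flowing into the locally computed sub-batch. A single-process sketch of that trick, using detached random tensors to stand in for the all_gather output (no collective call is issued, and the rank/world_size values are made up):

import paddle

local = paddle.randn([2, 8])
local.stop_gradient = False  # this rank's representations, still attached to the graph

rank, world_size = 1, 4  # pretend ranks for illustration only
gathered = [paddle.randn([2, 8]) for _ in range(world_size)]  # stands in for the all_gather output
gathered[rank] = local  # same move as all_tensors[sharding_rank] = tensor

merged = paddle.concat(gathered, axis=0)
merged.sum().backward()

print(local.grad.shape)  # [2, 8] -- the gradient reaches only the local shard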
1 change: 1 addition & 0 deletions paddlenlp/trl/__init__.py
@@ -14,6 +14,7 @@

 from .dpo_criterion import DPOCriterion
 from .dpo_trainer import DPOTrainer
+from .embedding_trainer import EmbeddingTrainer
 from .kto_criterion import KTOCriterion
 from .kto_trainer import KTOTrainer
 from .trl_data import *
181 changes: 181 additions & 0 deletions paddlenlp/trl/embedding_trainer.py
@@ -0,0 +1,181 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import paddle
from paddle.base import core
from paddle.distributed import fleet

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import (
    MatryoshkaContrastiveLoss,
    SimpleContrastiveLoss,
)
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient

__all__ = ["EmbeddingTrainer"]


class EmbeddingTrainer(Trainer):
    def __init__(self, model_args, **kwargs):
        super().__init__(**kwargs)

        self.model_args = model_args
        self.embedding_negatives_cross_device = model_args.embedding_negatives_cross_device
        self.accum_data = []
        self.accum_freq = 0
        self.accum_q_features = []
        self.accum_p_features = []
        self.accum_rng_states = {}
        self.accum_rng_states["cpu"] = []
        self.accum_rng_states["cuda"] = []
        self.accum_rng_states["hybrid"] = []

        if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0:
            self.loss_fn = MatryoshkaContrastiveLoss(
                model_args.embedding_temperature, model_args.embedding_matryoshka_dims
            )
        else:
            self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature)

    def clear_memory(self):
        self.accum_q_features.clear()
        self.accum_p_features.clear()
        paddle.device.cuda.empty_cache()

    def clear_state(self):
        self.accum_data.clear()
        self.accum_rng_states["cpu"].clear()
        self.accum_rng_states["cuda"].clear()
        self.accum_rng_states["hybrid"].clear()
        self.accum_freq = 0

    @paddle.no_grad()
    def forward_no_grad(self, model, inputs):
        # Step 1: graph-less forward
        self.accum_data.append(inputs)
        inputs = self._prepare_inputs(inputs)
        with self.autocast_smart_context_manager():
            # collect random states so the second forward pass can restore them
            self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state())
            self.accum_rng_states["cuda"].append(paddle.get_rng_state())
            if self.args.use_hybrid_parallel:
                self.accum_rng_states["hybrid"].append(
                    fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
                )

            query_reps, passage_reps = model(**inputs, return_encode=True)

        if self.embedding_negatives_cross_device:
            query_reps = dist_gather_tensor_with_gradient(query_reps)
            passage_reps = dist_gather_tensor_with_gradient(passage_reps)

        self.accum_q_features.append(query_reps)
        self.accum_p_features.append(passage_reps)

        self.accum_freq += 1

    def get_current_rng_state(self):
        return {
            "cpu": [paddle.framework.core.default_cpu_generator().get_state()],
            "cuda": [paddle.get_rng_state()],
            "hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()]
            if self.args.use_hybrid_parallel
            else [],
        }

    def reset_rng_state(self, states, index=0):
        # set random states
        if len(states) != 3:
            raise ValueError("The length of state should be 3")
        cpu_state = states["cpu"][index]
        cuda_state = states["cuda"][index]
        paddle.framework.core.default_cpu_generator().set_state(cpu_state)
        # TODO(daisiming): support xpu and other custom devices.
        if core.is_compiled_with_cuda():
            for j in range(core.get_cuda_device_count()):
                core.default_cuda_generator(j).set_state(cuda_state[j])
        if self.args.use_hybrid_parallel:
            hybrid_state = states["hybrid"][index]
            fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state)

    def accum_forward_backward(self, model):
        # Step 2: representation gradient computation and caching
        for i in range(len(self.accum_q_features)):
            self.accum_q_features[i].stop_gradient = False
        q_reps = paddle.concat(self.accum_q_features, axis=0)
        for i in range(len(self.accum_p_features)):
            self.accum_p_features[i].stop_gradient = False
        p_reps = paddle.concat(self.accum_p_features, axis=0)

        loss = self.loss_fn(q_reps, p_reps)
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        # get representation gradient cache
        accum_q_grads = [q.grad for q in self.accum_q_features]
        accum_p_grads = [p.grad for p in self.accum_p_features]
        del q_reps, p_reps

        # release cached features and empty the CUDA cache
        self.clear_memory()

        current_rng_state = self.get_current_rng_state()
        # Step 3: sub-batch gradient accumulation
        for i in range(self.accum_freq):
            inputs = self.accum_data[i]
            inputs = self._prepare_inputs(inputs)

            sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext()
            with sync_context:
                self.reset_rng_state(self.accum_rng_states, index=i)

                with self.autocast_smart_context_manager():
                    query_reps, passage_reps = model(**inputs, return_encode=True)

                if self.embedding_negatives_cross_device:
                    query_reps = dist_gather_tensor_with_gradient(query_reps)
                    passage_reps = dist_gather_tensor_with_gradient(passage_reps)

                _loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot(
                    passage_reps.flatten(), accum_p_grads[i].flatten()
                )
                _loss.backward()

        self.reset_rng_state(current_rng_state)
        self.clear_state()
        return loss.detach()

    def training_step(
        self,
        model,
        inputs,
        step_control=0,
    ):
        if self.args.pipeline_parallel_degree > 1:
            raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

        if self.args.gradient_accumulation_steps == 1:
            return super().training_step(model, inputs)
        else:
            self.forward_no_grad(model, inputs)

        # if (step_control + 1) % self.args.gradient_accumulation_steps is not zero, move on to the next batch
        if (step_control + 1) % self.args.gradient_accumulation_steps != 0:
            return 0.0

        loss = self.accum_forward_backward(model)
        return loss
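EmbeddingTrainer above follows a gradient-cache style of accumulation: forward_no_grad encodes each sub-batch without building a graph, accum_forward_backward backpropagates the contrastive loss onto the cached representations only, and every sub-batch is then re-encoded (with its saved RNG state restored so dropout repeats) and receives its gradient through paddle.dot(reps.flatten(), cached_grad.flatten()), whose derivative with respect to the recomputed representations is exactly the cached gradient. A minimal single-process sketch of why the two passes reproduce an ordinary full-batch backward; the toy encoder, loss and sizes are illustrative and not part of the commit:

import paddle
import paddle.nn as nn

paddle.seed(0)
encoder = nn.Linear(16, 8)  # toy stand-in for the embedding model
data = [paddle.randn([4, 16]) for _ in range(3)]  # three accumulation sub-batches


def loss_fn(reps):
    return (reps * reps).mean()  # any differentiable loss over the concatenated reps


# reference: one full-batch forward/backward
ref_loss = loss_fn(paddle.concat([encoder(x) for x in data], axis=0))
ref_loss.backward()
ref_grad = encoder.weight.grad.clone()
encoder.clear_gradients()

# gradient-cache style, mirroring forward_no_grad + accum_forward_backward
with paddle.no_grad():  # pass 1: representations only, no graph kept
    cached = [encoder(x) for x in data]
for r in cached:
    r.stop_gradient = False
loss = loss_fn(paddle.concat(cached, axis=0))
loss.backward()  # gradients land on the cached representations
rep_grads = [r.grad for r in cached]

for x, g in zip(data, rep_grads):  # pass 2: re-encode and inject the cached gradients
    reps = encoder(x)
    paddle.dot(reps.flatten(), g.flatten()).backward()

print(paddle.allclose(encoder.weight.grad, ref_grad).item())  # True

The payoff of this scheme is that only one sub-batch needs to keep activations in memory at a time, while the contrastive loss still sees the full set of (optionally cross-device gathered) representations as in-batch negatives.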
