add embedding trainer
DesmonDay committed Nov 27, 2024
1 parent 8fd33a9 commit 4d974e5
Showing 4 changed files with 203 additions and 3 deletions.
8 changes: 5 additions & 3 deletions paddlenlp/trainer/trainer.py
@@ -1093,9 +1093,9 @@ def _inner_training_loop(
                 if is_no_sync:
                     # Avoid unnecessary DDP synchronization since there will be no backward pass on this example.
                     with model.no_sync():
-                        tr_loss_step = self.training_step(model, inputs)
+                        tr_loss_step = self.training_step(model, inputs, step_control=step_control)
                 else:
-                    tr_loss_step = self.training_step(model, inputs)
+                    tr_loss_step = self.training_step(model, inputs, step_control=step_control)

                 tr_loss += tr_loss_step

@@ -2267,7 +2267,9 @@ def _enable_delay_scale_loss(self):
         else:
             return False

-    def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
+    def training_step(
+        self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]], step_control=0
+    ) -> paddle.Tensor:
         """
         Perform a training step on a batch of inputs.
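The change in trainer.py is pure plumbing: the micro-batch counter step_control is now forwarded into training_step, so a subclass can tell whether the current call is the last micro-batch of a gradient-accumulation cycle (the EmbeddingTrainer added below relies on exactly this test). A minimal, self-contained illustration of that boundary check, using a made-up accumulation setting of 4:

accumulation_steps = 4  # illustrative value; the trainer reads args.gradient_accumulation_steps

for step_control in range(8):
    last_in_cycle = (step_control + 1) % accumulation_steps == 0
    action = "run the cached backward" if last_in_cycle else "stash the forward pass, return 0.0"
    print(f"step_control={step_control}: {action}")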
36 changes: 36 additions & 0 deletions paddlenlp/transformers/contrastive_loss.py
@@ -0,0 +1,36 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle


class SimpleContrastiveLoss(paddle.nn.Layer):
    def __init__(self, embedding_temperature: float = 0.02):
        super().__init__()
        self.embedding_temperature = embedding_temperature
        self.cross_entropy = paddle.nn.CrossEntropyLoss(reduction="mean")

    def forward(self, q_reps, p_reps):
        scores = paddle.matmul(q_reps, p_reps.transpose([1, 0]))
        scores = scores / self.embedding_temperature

        group_size = p_reps.shape[0] // q_reps.shape[0]
        batch_size = q_reps.shape[0]

        target = paddle.arange(batch_size, dtype="int64")
        target = target * group_size

        loss = self.cross_entropy(scores, target)
        return loss
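As a usage note, SimpleContrastiveLoss assumes the i-th query's positive passage sits at row i * group_size of p_reps; the temperature-scaled score matrix is then handed to a standard cross-entropy over all passages. A small sketch with made-up shapes (4 queries, 1 positive plus 1 negative passage per query, 8-dimensional embeddings):

import paddle

from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss

batch_size, group_size, hidden = 4, 2, 8  # illustrative sizes only
q_reps = paddle.randn([batch_size, hidden])               # query embeddings
p_reps = paddle.randn([batch_size * group_size, hidden])  # passages; positives at rows 0, 2, 4, 6

loss_fn = SimpleContrastiveLoss(embedding_temperature=0.02)
loss = loss_fn(q_reps, p_reps)  # scores has shape [4, 8]; targets are [0, 2, 4, 6]
print(loss.item())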
1 change: 1 addition & 0 deletions paddlenlp/trl/__init__.py
@@ -14,6 +14,7 @@

 from .dpo_criterion import DPOCriterion
 from .dpo_trainer import DPOTrainer
+from .embedding_trainer import EmbeddingTrainer
 from .kto_criterion import KTOCriterion
 from .kto_trainer import KTOTrainer
 from .sft_trainer import *
161 changes: 161 additions & 0 deletions paddlenlp/trl/embedding_trainer.py
@@ -0,0 +1,161 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import paddle
from paddle.base import core
from paddle.distributed import fleet

from paddlenlp.trainer import Trainer
from paddlenlp.transformers.contrastive_loss import SimpleContrastiveLoss

__all__ = ["EmbeddingTrainer"]


class EmbeddingTrainer(Trainer):
    def __init__(self, model_args, use_gradient_cache=False, **kwargs):
        super().__init__(**kwargs)

        self.model_args = model_args
        self.use_gradient_cache = use_gradient_cache
        self.accum_data = []
        self.accum_freq = 0
        self.accum_q_features = []
        self.accum_p_features = []
        self.accum_rng_states = {}
        self.accum_rng_states["cpu"] = []
        self.accum_rng_states["cuda"] = []
        self.accum_rng_states["hybrid"] = []
        self.loss_fn = SimpleContrastiveLoss(self.model_args.embedding_temperature)

    def clear_memory(self):
        self.accum_q_features.clear()
        self.accum_p_features.clear()
        paddle.device.cuda.empty_cache()

    def clear_state(self):
        self.accum_data.clear()
        self.accum_rng_states["cpu"].clear()
        self.accum_rng_states["cuda"].clear()
        self.accum_rng_states["hybrid"].clear()
        self.accum_freq = 0

    @paddle.no_grad()
    def forward_no_grad(self, model, inputs):
        # Step1: graph-less forward
        self.accum_data.append(inputs)
        inputs = self._prepare_inputs(inputs)
        with self.autocast_smart_context_manager():
            # collect rand states
            self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state())
            self.accum_rng_states["cuda"].append(paddle.get_rng_state())
            if self.args.use_hybrid_parallel:
                self.accum_rng_states["hybrid"].append(
                    fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()
                )

            query_reps, passage_reps = model(**inputs, return_encode=True)

        self.accum_q_features.append(query_reps)
        self.accum_p_features.append(passage_reps)

        self.accum_freq += 1

    def get_current_rng_state(self):
        return {
            "cpu": [paddle.framework.core.default_cpu_generator().get_state()],
            "cuda": [paddle.get_rng_state()],
            "hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()],
        }

    def reset_rng_state(self, states, index=0):
        # set random states
        if len(states) != 3:
            raise ValueError("The length of state should be 3")
        cpu_state = states["cpu"][index]
        cuda_state = states["cuda"][index]
        hybrid_state = states["hybrid"][index]
        paddle.framework.core.default_cpu_generator().set_state(cpu_state)
        # TODO(daisiming): support xpu and other custom devices.
        if core.is_compiled_with_cuda():
            for j in range(core.get_cuda_device_count()):
                core.default_cuda_generator(j).set_state(cuda_state[j])
        if self.args.use_hybrid_parallel:
            fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state)

    def accum_forward_backward(self, model):
        # Step2: representation gradient computation and caching
        for i in range(len(self.accum_q_features)):
            self.accum_q_features[i].stop_gradient = False
        q_reps = paddle.concat(self.accum_q_features, axis=0)
        for i in range(len(self.accum_p_features)):
            self.accum_p_features[i].stop_gradient = False
        p_reps = paddle.concat(self.accum_p_features, axis=0)

        loss = self.loss_fn(q_reps, p_reps)
        if self.do_grad_scaling:
            self.scaler.scale(loss).backward()
        else:
            loss.backward()
        # get representation gradient cache
        accum_q_grads = [q.grad for q in self.accum_q_features]
        accum_p_grads = [p.grad for p in self.accum_p_features]
        del q_reps, p_reps

        # clear trash memory
        self.clear_memory()

        current_rng_state = self.get_current_rng_state()
        # Step3: sub-batch gradient accumulation
        for i in range(self.accum_freq):
            inputs = self.accum_data[i]
            inputs = self._prepare_inputs(inputs)

            sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext()
            with sync_context:
                self.reset_rng_state(self.accum_rng_states, index=i)

                with self.autocast_smart_context_manager():
                    query_reps, passage_reps = model(**inputs, return_encode=True)

                _loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot(
                    passage_reps.flatten(), accum_p_grads[i].flatten()
                )
                _loss.backward()

        self.reset_rng_state(current_rng_state)
        self.clear_state()
        return loss.detach()

    def training_step(
        self,
        model,
        inputs,
        step_control=0,
    ):
        if self.args.pipeline_parallel_degree > 1:
            raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.")

        if self.args.gradient_accumulation_steps == 1 or not self.use_gradient_cache:
            return super().training_step(model, inputs)
        else:
            self.forward_no_grad(model, inputs)

            # if (step_control + 1) % self.args.gradient_accumulation_steps is not zero, move on to next batch.
            if (step_control + 1) % self.args.gradient_accumulation_steps != 0:
                return 0.0

            loss = self.accum_forward_backward(model)
        return loss
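Taken together, the three numbered steps implement a gradient-cache style of gradient accumulation for contrastive training: sub-batches are first encoded without a graph, the contrastive loss is backpropagated only as far as the cached embeddings, and each sub-batch is then re-encoded under its saved RNG state with the surrogate loss paddle.dot(reps.flatten(), cached_grad.flatten()), whose backward pass pushes exactly the cached gradient into the encoder. A self-contained toy check of that equivalence, using a stand-in linear encoder instead of the real embedding model (everything here is illustrative, not part of the commit):

import paddle

paddle.seed(0)
encoder = paddle.nn.Linear(4, 3)  # toy stand-in for the embedding model
x = paddle.randn([2, 4])

# Reference: backpropagate a toy loss straight through the encoder.
reps = encoder(x)
(reps**2).sum().backward()
ref_grad = encoder.weight.grad.clone()
encoder.clear_gradients()

# Gradient-cache route, Steps 1-2: compute d(loss)/d(reps) against graph-less embeddings.
cached = encoder(x).detach()
cached.stop_gradient = False
(cached**2).sum().backward()
cached_grad = cached.grad

# Step 3: re-encode and backpropagate the surrogate dot(reps, cached_grad).
reps2 = encoder(x)
paddle.dot(reps2.flatten(), cached_grad.flatten()).backward()

print(paddle.allclose(ref_grad, encoder.weight.grad))  # True: identical parameter gradient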
