forked from PaddlePaddle/PaddleNLP
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
update embedding trainer (PaddlePaddle#9608)
- Loading branch information
Showing
5 changed files
with
306 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List, Optional | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
|
||
|
||
class SimpleContrastiveLoss(nn.Layer): | ||
def __init__(self, embedding_temperature: float = 0.02): | ||
super().__init__() | ||
self.embedding_temperature = embedding_temperature | ||
self.cross_entropy = nn.CrossEntropyLoss(reduction="mean") | ||
|
||
def forward(self, q_reps, p_reps): | ||
scores = paddle.matmul(q_reps, p_reps.transpose([1, 0])) | ||
scores = scores / self.embedding_temperature | ||
|
||
group_size = p_reps.shape[0] // q_reps.shape[0] | ||
batch_size = q_reps.shape[0] | ||
|
||
target = paddle.arange(batch_size, dtype="int64") | ||
target = target * group_size | ||
|
||
loss = self.cross_entropy(scores, target) | ||
return loss | ||
|
||
|
||
class MatryoshkaContrastiveLoss(nn.Layer): | ||
def __init__(self, embedding_temperature: float = 0.02, embedding_matryoshka_dims: Optional[List[int]] = None): | ||
super().__init__() | ||
self.embedding_temperature = embedding_temperature | ||
if embedding_matryoshka_dims is None: | ||
self.embedding_matryoshka_dims = [] | ||
else: | ||
self.embedding_matryoshka_dims = embedding_matryoshka_dims | ||
self.loss_fn = SimpleContrastiveLoss(embedding_temperature) | ||
|
||
def forward(self, q_reps, p_reps): | ||
if len(self.embedding_matryoshka_dims) > 0: | ||
loss = 0.0 | ||
for dim in self.embedding_matryoshka_dims: | ||
reduced_q_reps = q_reps[:, :dim] | ||
reduced_q_reps = nn.functional.normalize(reduced_q_reps, axis=-1) | ||
|
||
reduced_p_reps = p_reps[:, :dim] | ||
reduced_p_reps = nn.functional.normalize(reduced_p_reps, axis=-1) | ||
|
||
dim_loss = self.loss_fn(reduced_q_reps, reduced_p_reps) | ||
loss += dim_loss | ||
else: | ||
loss = self.loss_fn(q_reps, p_reps) | ||
return loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
from paddle.distributed import fleet | ||
|
||
|
||
def dist_gather_tensor_with_gradient(tensor): | ||
if tensor is None: | ||
return None | ||
|
||
if paddle.distributed.get_world_size() <= 1: | ||
return tensor | ||
|
||
hcg = fleet.get_hybrid_communicate_group() | ||
sharding_group = hcg.get_sharding_parallel_group() | ||
sharding_rank = sharding_group.rank | ||
data_group = hcg.get_data_parallel_group() | ||
data_rank = data_group.rank | ||
|
||
if sharding_group.nranks == 1 and data_group.nranks == 1: | ||
return tensor | ||
|
||
if sharding_group.nranks > 1: | ||
all_tensors = [] | ||
paddle.distributed.all_gather(all_tensors, tensor.contiguous(), group=sharding_group) | ||
all_tensors[sharding_rank] = tensor | ||
all_tensors = paddle.concat(all_tensors, axis=0) | ||
else: | ||
all_tensors = tensor | ||
|
||
if data_group.nranks > 1: | ||
final_tensors = [] | ||
paddle.distributed.all_gather(final_tensors, all_tensors.contiguous(), group=data_group) | ||
final_tensors[data_rank] = all_tensors | ||
final_tensors = paddle.concat(final_tensors, axis=0) | ||
else: | ||
final_tensors = all_tensors | ||
|
||
return final_tensors |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from contextlib import nullcontext | ||
|
||
import paddle | ||
from paddle.base import core | ||
from paddle.distributed import fleet | ||
|
||
from paddlenlp.trainer import Trainer | ||
from paddlenlp.transformers.contrastive_loss import ( | ||
MatryoshkaContrastiveLoss, | ||
SimpleContrastiveLoss, | ||
) | ||
from paddlenlp.transformers.embedding_utils import dist_gather_tensor_with_gradient | ||
|
||
__all__ = ["EmbeddingTrainer"] | ||
|
||
|
||
class EmbeddingTrainer(Trainer): | ||
def __init__(self, model_args, **kwargs): | ||
super().__init__(**kwargs) | ||
|
||
self.model_args = model_args | ||
self.embedding_negatives_cross_device = model_args.embedding_negatives_cross_device | ||
self.accum_data = [] | ||
self.accum_freq = 0 | ||
self.accum_q_features = [] | ||
self.accum_p_features = [] | ||
self.accum_rng_states = {} | ||
self.accum_rng_states["cpu"] = [] | ||
self.accum_rng_states["cuda"] = [] | ||
self.accum_rng_states["hybrid"] = [] | ||
|
||
if model_args.embedding_matryoshka_dims is not None and len(model_args.embedding_matryoshka_dims) > 0: | ||
self.loss_fn = MatryoshkaContrastiveLoss( | ||
model_args.embedding_temperature, model_args.embedding_matryoshka_dims | ||
) | ||
else: | ||
self.loss_fn = SimpleContrastiveLoss(model_args.embedding_temperature) | ||
|
||
def clear_memory(self): | ||
self.accum_q_features.clear() | ||
self.accum_p_features.clear() | ||
paddle.device.cuda.empty_cache() | ||
|
||
def clear_state(self): | ||
self.accum_data.clear() | ||
self.accum_rng_states["cpu"].clear() | ||
self.accum_rng_states["cuda"].clear() | ||
self.accum_rng_states["hybrid"].clear() | ||
self.accum_freq = 0 | ||
|
||
@paddle.no_grad() | ||
def forward_no_grad(self, model, inputs): | ||
# Step1: graph-less forward | ||
self.accum_data.append(inputs) | ||
inputs = self._prepare_inputs(inputs) | ||
with self.autocast_smart_context_manager(): | ||
# collect rand states | ||
self.accum_rng_states["cpu"].append(paddle.framework.core.default_cpu_generator().get_state()) | ||
self.accum_rng_states["cuda"].append(paddle.get_rng_state()) | ||
if self.args.use_hybrid_parallel: | ||
self.accum_rng_states["hybrid"].append( | ||
fleet.meta_parallel.get_rng_state_tracker().get_states_tracker() | ||
) | ||
|
||
query_reps, passage_reps = model(**inputs, return_encode=True) | ||
|
||
if self.embedding_negatives_cross_device: | ||
query_reps = dist_gather_tensor_with_gradient(query_reps) | ||
passage_reps = dist_gather_tensor_with_gradient(passage_reps) | ||
|
||
self.accum_q_features.append(query_reps) | ||
self.accum_p_features.append(passage_reps) | ||
|
||
self.accum_freq += 1 | ||
|
||
def get_current_rng_state(self): | ||
return { | ||
"cpu": [paddle.framework.core.default_cpu_generator().get_state()], | ||
"cuda": [paddle.get_rng_state()], | ||
"hybrid": [fleet.meta_parallel.get_rng_state_tracker().get_states_tracker()] | ||
if self.args.use_hybrid_parallel | ||
else [], | ||
} | ||
|
||
def reset_rng_state(self, states, index=0): | ||
# set random states | ||
if len(states) != 3: | ||
raise ValueError("The length of state should be 3") | ||
cpu_state = states["cpu"][index] | ||
cuda_state = states["cuda"][index] | ||
paddle.framework.core.default_cpu_generator().set_state(cpu_state) | ||
# TODO(daisiming): support xpu and other custom devices. | ||
if core.is_compiled_with_cuda(): | ||
for j in range(core.get_cuda_device_count()): | ||
core.default_cuda_generator(j).set_state(cuda_state[j]) | ||
if self.args.use_hybrid_parallel: | ||
hybrid_state = states["hybrid"][index] | ||
fleet.meta_parallel.get_rng_state_tracker().set_states_tracker(hybrid_state) | ||
|
||
def accum_forward_backward(self, model): | ||
# Step2: representation gradient computation and caching | ||
for i in range(len(self.accum_q_features)): | ||
self.accum_q_features[i].stop_gradient = False | ||
q_reps = paddle.concat(self.accum_q_features, axis=0) | ||
for i in range(len(self.accum_p_features)): | ||
self.accum_p_features[i].stop_gradient = False | ||
p_reps = paddle.concat(self.accum_p_features, axis=0) | ||
|
||
loss = self.loss_fn(q_reps, p_reps) | ||
if self.do_grad_scaling: | ||
self.scaler.scale(loss).backward() | ||
else: | ||
loss.backward() | ||
# get represetation gradient cache | ||
accum_q_grads = [q.grad for q in self.accum_q_features] | ||
accum_p_grads = [p.grad for p in self.accum_p_features] | ||
del q_reps, p_reps | ||
|
||
# clear trash memory | ||
self.clear_memory() | ||
|
||
current_rng_state = self.get_current_rng_state() | ||
# Step3: sub-batch gradient accumulation | ||
for i in range(self.accum_freq): | ||
inputs = self.accum_data[i] | ||
inputs = self._prepare_inputs(inputs) | ||
|
||
sync_context = model.no_sync() if i != self.accum_freq - 1 and hasattr(model, "no_sync") else nullcontext() | ||
with sync_context: | ||
self.reset_rng_state(self.accum_rng_states, index=i) | ||
|
||
with self.autocast_smart_context_manager(): | ||
query_reps, passage_reps = model(**inputs, return_encode=True) | ||
|
||
if self.embedding_negatives_cross_device: | ||
query_reps = dist_gather_tensor_with_gradient(query_reps) | ||
passage_reps = dist_gather_tensor_with_gradient(passage_reps) | ||
|
||
_loss = paddle.dot(query_reps.flatten(), accum_q_grads[i].flatten()) + paddle.dot( | ||
passage_reps.flatten(), accum_p_grads[i].flatten() | ||
) | ||
_loss.backward() | ||
|
||
self.reset_rng_state(current_rng_state) | ||
self.clear_state() | ||
return loss.detach() | ||
|
||
def training_step( | ||
self, | ||
model, | ||
inputs, | ||
step_control=0, | ||
): | ||
if self.args.pipeline_parallel_degree > 1: | ||
raise NotImplementedError("Cannot support pipeline parallel for Embedding training now.") | ||
|
||
if self.args.gradient_accumulation_steps == 1: | ||
return super().training_step(model, inputs) | ||
else: | ||
self.forward_no_grad(model, inputs) | ||
|
||
# if (step_control + 1) % self.args.gradient_accumulation_steps is not zero, move on to next batch. | ||
if (step_control + 1) % self.args.gradient_accumulation_steps != 0: | ||
return 0.0 | ||
|
||
loss = self.accum_forward_backward(model) | ||
return loss |