ValueError: operands could not be broadcast together with shapes (38,384) (37,384) #61
Comments
hi @williambarberjr, to resume training, you can try to specify the checkpoint path via `pretrained_model_path`:

```python
from angle_emb import AnglE

angle = AnglE.from_pretrained(
    backbone='your_choice_backbone',
    pretrained_model_path='/checkpoint-1100',
    max_length=512,
    pooling_strategy='cls').cuda()
```
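A quick way to confirm the checkpoint actually loaded is to encode a couple of sentences and inspect the output; a minimal sketch, assuming the standard angle_emb encode call and placeholder sentences:

```python
import numpy as np

# Encode two placeholder sentences with the resumed checkpoint and confirm the
# embeddings have the expected shape and contain no NaNs.
vecs = angle.encode(['hello world', 'another sentence'], to_numpy=True)
print(vecs.shape)            # (2, hidden_size)
print(np.isnan(vecs).any())  # should be False
```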
Thanks so much for your help. I was able to get it started again, but I do have two other questions. I launched the training like this:

So I'm just going to assume it's not a LoRA model, it's just a pretrained_model_path. On a different note, when trying to test the model checkpoint like this:

outputs is printing out as:

Which naturally means that the similarities get printed as:

I tried a few different ways of replicating the code here: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1, but couldn't seem to make the checkpoint work for me. I'm just attempting to validate that the model is getting better, or at least staying the same "in domain", as a sanity check that I haven't messed anything else up.
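For reference, the kind of sanity check being attempted here, written as a minimal sketch against the angle_emb API (the query and documents are placeholders, and `angle` is assumed to be the model loaded from the checkpoint as above):

```python
import numpy as np

# Assumes `angle` is an AnglE model loaded from the fine-tuned checkpoint.
query = 'Represent this sentence for searching relevant passages: A man is eating pasta.'
docs = [
    'A man is eating food.',
    'The girl is carrying a baby.',
]

q_vec = angle.encode([query], to_numpy=True)[0]
d_vecs = angle.encode(docs, to_numpy=True)

# Cosine similarity between the query and each document; the first document
# should score noticeably higher if the checkpoint is healthy.
sims = d_vecs @ q_vec / (np.linalg.norm(d_vecs, axis=1) * np.linalg.norm(q_vec))
print(sims)
```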
hi @williambarberjr, I reproduced your problem. Here is the solution:

```python
angle = AnglE.from_pretrained(
    'mixedbread-ai/mxbai-embed-large-v1',
    torch_dtype=torch.float32,
    max_length=512,
    pooling_strategy='cls').cuda()
```

BTW, here are some tips to improve the training:

```python
ds = ds.map(lambda x: {'text': f'Represent this sentence for searching relevant passages: {x["text"]}'})
train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
...
```

Here, I provide a training example for NLI, hopefully it can help :)
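Putting those two tips together, a minimal end-to-end sketch of the data preparation (the file name `train.jsonl` and its fields are assumptions for illustration):

```python
from datasets import load_dataset
from angle_emb import AnglE, AngleDataTokenizer
import torch

# Assumed: a local JSONL file whose records have a 'text' field
# (plus 'positive'/'negative' fields for contrastive training).
ds = load_dataset('json', data_files='train.jsonl')

# Prepend the retrieval prompt to the query side.
ds = ds.map(lambda x: {'text': f'Represent this sentence for searching relevant passages: {x["text"]}'})

angle = AnglE.from_pretrained(
    'mixedbread-ai/mxbai-embed-large-v1',
    torch_dtype=torch.float32,
    max_length=512,
    pooling_strategy='cls').cuda()

train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
```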
Very helpful. I got it running on a single GPU setup and then on a multi-GPU setup to speed things up. However, I spot tested the most recent checkpoint with two query-positive pairs and noticed that the similarity scores had both gotten worse than the original scores from the "mixedbread-ai/mxbai-embed-large-v1" model. But I'll either do a more comprehensive test or let it run out further before I assume I've done something wrong. I was having trouble getting it running distributed across multiple GPUs, but one small tweak got it working. I'll leave the revised version below in case it's helpful to someone else, or in the event you have some constructive criticism or feedback for me on how to improve it. In the colab you linked, you set w2 to 2. Should I have done that as well? For context, I only intend to use the embeddings for retrieval.

```python
from datasets import load_dataset, DatasetDict, Dataset
from angle_emb import AnglE, AngleDataTokenizer
import os
import torch
import random
import json
from transformers import AutoTokenizer

os.environ['WANDB_MODE'] = 'disabled'

data = []
with open('/root/train_small.jsonl', 'r') as file:
    for line in file:
        obj = json.loads(line)
        obj['text'] = f'Represent this sentence for searching relevant passages: {obj["text"]}'
        data.append(obj)
random.shuffle(data)
train_ds = Dataset.from_list(data)

test_data = []
with open('/root/test.jsonl', 'r') as file:
    for line in file:
        obj = json.loads(line)
        obj['text'] = f'Represent this sentence for searching relevant passages: {obj["text"]}'
        test_data.append(obj)
random.shuffle(test_data)
test_ds = Dataset.from_list(test_data)

validation_data = []
with open('/root/validation.jsonl', 'r') as file:
    for line in file:
        obj = json.loads(line)
        obj['text'] = f'Represent this sentence for searching relevant passages: {obj["text"]}'
        validation_data.append(obj)
random.shuffle(validation_data)
valid_ds = Dataset.from_list(validation_data)

max_len = 512
model_id = 'mixedbread-ai/mxbai-embed-large-v1'
tokenizer = AutoTokenizer.from_pretrained(model_id)
train_ds = train_ds.shuffle().map(AngleDataTokenizer(tokenizer, max_len), num_proc=8)
valid_ds = valid_ds.map(AngleDataTokenizer(tokenizer, max_len), num_proc=8)
test_ds = test_ds.map(AngleDataTokenizer(tokenizer, max_len), num_proc=8)

angle = AnglE.from_pretrained(model_id, max_length=max_len, torch_dtype=torch.float32, pooling_strategy='cls').cuda()
angle.fit(train_ds=train_ds,
          valid_ds=valid_ds,
          output_dir='/root',
          batch_size=3,
          epochs=5,
          learning_rate=1e-5,
          save_steps=100,
          eval_steps=1000,
          warmup_steps=0,
          gradient_accumulation_steps=2050,
          loss_kwargs={
              'w1': 0,
              'w2': 20,
              'w3': 1,
              'cosine_tau': 20,
              'ibn_tau': 20,
              'angle_tau': 20,
          },
          fp16=True,
          logging_steps=10)
```

I'm calling the code with this command on an 8 x RTX 4090 setup:
hi @williambarberjr, I checked it in a multi-GPU setting and did not reproduce your problem. It works fine in my training. I fixed the NaN issue several minutes ago. Could you update angle_emb to v0.3.9 and try to train the model using train_cli.py? Here is an example to train with:

```bash
$ head -3 snli_5k.jsonl
{"text": "A person on a horse jumps over a broken down airplane.", "positive": "A person is outdoors, on a horse.", "negative": "A person is at a diner, ordering an omelette."}
{"text": "Children smiling and waving at camera", "positive": "There are children present", "negative": "The kids are frowning"}
{"text": "A boy is jumping on skateboard in the middle of a red bridge.", "positive": "The boy does a skateboarding trick.", "negative": "The boy skates down the sidewalk."}
```
```bash
NCCL_P2P_DISABLE=1 NCCL_IB_DISABLE=1 WANDB_MODE=disabled CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node=2 --master_port=2345 train_cli.py \
--model_name_or_path mixedbread-ai/mxbai-embed-large-v1 \
--train_name_or_path ./snli_5k.jsonl --save_dir mxbai-snli-ckpts \
--w1 0. --w2 20.0 --w3 1.0 --angle_tau 20.0 --learning_rate 3e-6 --maxlen 64 \
--pooling_strategy cls \
--epochs 1 \
--batch_size 32 \
--logging_steps 100 \
--warmup_steps 200 \
--save_steps 1000 --seed 42 --gradient_accumulation_steps 2 --fp16 1 --torch_dtype 'float32'
```

When the training is done, the model will be saved to `mxbai-snli-ckpts`. You can then load it like this:

```python
from angle_emb import AnglE

angle = AnglE.from_pretrained('mixedbread-ai/mxbai-embed-large-v1', pretrained_model_path='mxbai-snli-ckpts').cuda()
...
```
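As a quick spot check of the fine-tuned checkpoint, a small sketch using sentences from the snli_5k.jsonl sample above (the encode call is the same one used elsewhere in this thread):

```python
import numpy as np

# Encode an anchor, its positive, and its negative from the SNLI sample and
# check that the positive pair scores higher than the negative pair.
texts = [
    'A person on a horse jumps over a broken down airplane.',
    'A person is outdoors, on a horse.',
    'A person is at a diner, ordering an omelette.',
]
vecs = angle.encode(texts, to_numpy=True)
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

print('anchor vs positive:', float(vecs[0] @ vecs[1]))
print('anchor vs negative:', float(vecs[0] @ vecs[2]))
```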
Thanks for your help. The training is running smoothly now distributed across several GPUs. I've got one running with an effective batch size of 30k+ and another with an effective batch size of about 4k. The grad_norm values for the 4k batch size seem to be a bit all over the place - does that suggest an issue to you?
Got this error while using this library to train an embedding model:

I confirmed that `valid_ds` and `train_ds` were of even length, so ultimately I just modified one line of the evaluate method of the AnglE class. After this line: `x_vecs = l2_normalize(x_vecs)`

I added:

Hopefully that doesn't break anything/everything else? Any thoughts on what else might be the source of the issue?
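Purely as an illustration of this kind of workaround (an assumption, not the actual line that was added), trimming both embedding matrices to the same number of rows before the element-wise similarity would avoid the broadcast error:

```python
import numpy as np

def align_rows(x_vecs: np.ndarray, y_vecs: np.ndarray):
    """Trim both embedding matrices to the same number of rows so that shapes
    like (38, 384) and (37, 384) can be broadcast together."""
    n = min(x_vecs.shape[0], y_vecs.shape[0])
    return x_vecs[:n], y_vecs[:n]

x = np.random.rand(38, 384)
y = np.random.rand(37, 384)
x, y = align_rows(x, y)
sims = (x * y).sum(axis=1)  # row-wise dot products over the aligned matrices
print(sims.shape)           # (37,)
```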
Also, I attempted to restart training by running the same `angle.fit()` as I did when I started it, but adjusting the `from_pretrained` to point to the most recent checkpoint: `angle = AnglE.from_pretrained('/checkpoint-1100', max_length=512, pooling_strategy='cls').cuda()`

I don't see a `resume_from_checkpoint=True` argument option anywhere... so it's not clear that it's aware of how many epochs have already been run, etc.