
deepspeed setup for requiring grads on the input (explainability) without huge increase in memory over all gpus #6798

GonyRosenman opened this issue Nov 27, 2024

I am using DeepSpeed with ZeRO optimization (stage 2) to train a custom model on multiple GPUs, and I want to compute gradients with respect to the input for explainability. However, I am running into problems integrating input-gradient computation in this setup: memory usage increases significantly, and I lose the memory savings DeepSpeed normally provides.
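
Conceptually, what I am after is the usual input-gradient saliency pattern. A minimal plain-PyTorch sketch (no DeepSpeed involved; model and batch are placeholders for the objects defined in the repro below, and I freeze the parameters since only the input gradient is needed):

# Plain-PyTorch sketch of input-gradient saliency; no DeepSpeed involved.
# "model" and "batch" stand in for the real objects from the repro below.
for p in model.parameters():
    p.requires_grad_(False)                  # parameters frozen: no parameter grads needed

aux = batch["aux"].requires_grad_(True)      # only the auxiliary input needs gradients
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    aux=aux,
    labels=batch["labels"],
)
outputs.loss.backward()
saliency = aux.grad.abs()                    # per-feature attribution for aux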

Below is the relevant DeepSpeed configuration I use, passed to the Hugging Face Trainer via the deepspeed argument:

DeepSpeed JSON Configuration (./scripts/zero2.json):
{
  "fp16": {
    "enabled": "auto",
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "initial_scale_power": 16,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "bf16": {
    "enabled": "auto"
  },
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto"
  }
}
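
If it matters, my understanding is that the "auto" entries are resolved by the Hugging Face Trainer from the TrainingArguments at launch time; roughly, I expect the mapping to look like this (values are illustrative, not what I actually pass):

# Illustrative only: how I understand the "auto" fields get filled in from TrainingArguments.
training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=8,   # -> train_micro_batch_size_per_gpu
    gradient_accumulation_steps=1,   # -> gradient_accumulation_steps
    bf16=True,                       # -> bf16.enabled ("auto" resolves to this flag)
    deepspeed="./scripts/zero2.json",
)
# train_batch_size then becomes micro_batch * grad_accum * world_size.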

Below is a minimal example to reproduce the issue:

To launch the script I use the following command:
/home/user/miniconda3/envs/proejct/bin/python -m deepspeed.launcher.launch --world_info=eyIxMjcuMC4wLjEiOiBbMCwgMSwgMiwgMywgNCwgNSwgNiwgN119 --master_addr=127.0.0.1 --master_port=4242 --no_local_rank /home/user/project/explain/explain.py --deepspeed ./scripts/zero2.json
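
I believe this is equivalent to launching through the standard deepspeed CLI (the path and GPU count are just my setup):

deepspeed --num_gpus 8 /home/user/project/explain/explain.py --deepspeed ./scripts/zero2.json

The script itself (explain.py):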

import torch
from torch import nn
from transformers import BertForSequenceClassification, BertTokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class ExampleDataset(Dataset):
    def __init__(self, tokenizer, size=10000, max_length=128):
        self.tokenizer = tokenizer
        self.texts = [f"This is example text {i}" for i in range(size)]
        self.labels = torch.randint(0, 2, (size,))  # Binary classification
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        tokenized = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        aux = torch.rand((5,)) * 100  # Random auxiliary input

        return {
            "aux": aux,
            "input_ids": input_ids,
            "attention_mask": tokenized["attention_mask"].squeeze(0),
            "labels": self.labels[idx],
        }

class CustomModel(nn.Module):
    def __init__(self, base_model_name="bert-base-uncased", num_labels=2):
        super(CustomModel, self).__init__()
        self.bert = BertForSequenceClassification.from_pretrained(base_model_name, num_labels=num_labels)
        self.embedding_dim = self.bert.config.hidden_size
        self.aux_linear = nn.Linear(5, self.embedding_dim)

    def forward(self, input_ids, attention_mask, aux, labels=None):
        aux_embedded = self.aux_linear(aux)  # Embed auxiliary input
        input_embeddings = self.bert.bert.embeddings(input_ids)
        input_embeddings[:, 0, :] += aux_embedded  # Modify embeddings

        outputs = self.bert(
            inputs_embeds=input_embeddings,
            attention_mask=attention_mask,
            labels=labels,
        )
        return outputs

def compute_saliency_maps(trainer, loader, device, repeat_factor=1000):
    """
    Compute gradients for the auxiliary input tensor (aux) in a DeepSpeed-enabled setting.
    """
    model = trainer._wrap_model(trainer.model, training=False, dataloader=loader)
    model.eval()

    for _ in range(repeat_factor):
        for batch in tqdm(loader, desc="Computing Saliency Maps"):
            batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
            batch["aux"].requires_grad_(True)  # Enable gradients for aux

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                aux=batch["aux"],
                labels=batch["labels"],
            )
            loss = outputs.loss
            loss.backward()  # Compute gradients
            grads = batch["aux"].grad
            print(f"Gradient norm: {grads.norm().item()}")

if __name__ == "__main__":
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = CustomModel()

    dataset = ExampleDataset(tokenizer, size=200)
    loader = DataLoader(dataset, batch_size=1000)

    training_args = TrainingArguments(
        output_dir="./results",
        per_device_eval_batch_size=8,
        do_train=False,
        do_eval=True,
        logging_dir="./logs",
        deepspeed="./scripts/zero2.json",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        eval_dataset=dataset,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    compute_saliency_maps(trainer, loader, device, repeat_factor=3)
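
One variant I have been wondering about is requesting only the input gradient via torch.autograd.grad, so that parameter .grad buffers are never populated. A rough sketch of the inner loop (untested with DeepSpeed; I am not sure how it interacts with the ZeRO stage 2 reduction hooks):

# Alternative inner loop: ask autograd only for d(loss)/d(aux),
# so parameter .grad buffers should never be materialized.
aux = batch["aux"].requires_grad_(True)
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    aux=aux,
    labels=batch["labels"],
)
(aux_grad,) = torch.autograd.grad(outputs.loss, aux)
print(f"Gradient norm: {aux_grad.norm().item()}")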

Observations:
Without Gradient Computation for aux: The model works as expected, and DeepSpeed successfully reduces memory usage.
With Gradient Computation for aux: Memory usage increases significantly, negating the benefits of Zero Optimization Stage 2.
Increased Memory Usage in Multi-GPU Setting: While the toy example fits in memory, my actual model OOMs when gradients are enabled for aux.
