Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc][LoRA] Move the implementation of lora bias to punica.py #10829

Merged
merged 8 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 27 additions & 33 deletions tests/lora/test_llama_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
return generated_texts


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)

def generate_and_test(llm, sql_lora_files):
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

Expand All @@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files):
print("removing lora")


@fork_new_process_for_each_test
def test_llama_lora(sql_lora_files):

llm = vllm.LLM(MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=1)
generate_and_test(llm, sql_lora_files)


@fork_new_process_for_each_test
def test_llama_lora_warmup(sql_lora_files):
"""Test that the LLM initialization works with a warmup LORA path and
Expand Down Expand Up @@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files):
max_loras=4,
tensor_parallel_size=4,
)

print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT

print("removing lora")
generate_and_test(llm, sql_lora_files)


@multi_gpu_test(num_gpus=4)
Expand All @@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files):
tensor_parallel_size=4,
fully_sharded_loras=True,
)
print("lora adapter created")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 1")
assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT
generate_and_test(llm, sql_lora_files)

print("no lora")
assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT

print("lora 2")
assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT
@multi_gpu_test(num_gpus=4)
@fork_new_process_for_each_test
def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files):

print("removing lora")
llm = vllm.LLM(
MODEL_PATH,
enable_lora=True,
max_num_seqs=16,
max_loras=4,
tensor_parallel_size=4,
fully_sharded_loras=True,
enable_lora_bias=True,
)
generate_and_test(llm, sql_lora_files)
41 changes: 12 additions & 29 deletions vllm/lora/fully_sharded_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def apply(self, x: torch.Tensor,
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_stacked,
add_input=True)
# now have column partitioned output

Expand Down Expand Up @@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
layer.lora_a_stacked[idx], 1.0)

buffers = tensor_model_parallel_all_gather(buffers)
left_offset = 0
for idx in range(n):
shard_size = layer.lora_b_stacked[idx].shape[2]

if layer.bias_stacked is not None:
bias = layer.bias_stacked[idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[layer.punica_wrapper.token_lora_indices]
bias[layer.punica_wrapper.token_lora_indices == -1] = 0
output[:, left_offset:left_offset + shard_size] += bias

layer.punica_wrapper.add_expand_slice(
output,
buffers[idx],
layer.lora_b_stacked[idx],
left_offset,
shard_size,
add_input=True,
)
left_offset += shard_size
layer.punica_wrapper.add_expand_packed_nslice(
output,
buffers,
layer.lora_b_stacked,
layer.bias_stacked,
1.0,
layer.output_slices,
)

output = output.view(*out_orig_shape)
# now have column partitioned and packed output
Expand Down Expand Up @@ -234,6 +222,7 @@ def apply(self, x: torch.Tensor,
self.punica_wrapper.add_expand(output,
buffer,
self.lora_b_stacked,
self.bias_all,
add_input=True)
# now have column partitioned output
output = output.view(*out_orig_shape)
Expand Down Expand Up @@ -350,15 +339,9 @@ def apply(self, x: torch.Tensor) -> torch.Tensor:
# reduced before being used
shard_size = self.lora_b_stacked.shape[2]
start_idx = self.tp_rank * shard_size

if self.bias_stacked is not None:
bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1])
bias = bias[self.punica_wrapper.token_lora_indices]
bias[self.punica_wrapper.token_lora_indices == -1] = 0
output += bias

self.punica_wrapper.add_expand_slice(output, buffer,
self.lora_b_stacked, start_idx,
self.lora_b_stacked,
self.bias_stacked, start_idx,
shard_size)
output = output.view(*out_orig_shape)
return output
Expand Down
113 changes: 12 additions & 101 deletions vllm/lora/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,63 +67,6 @@ def dec(*args, **kwargs):
return dec


def apply_bias(
indices: torch.Tensor,
output: torch.Tensor,
bias_stacked: torch.Tensor,
):
"""Applies bias to output
Input shapes:
bias_stacked: (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, output_dim)
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)

bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1])
bias_stacked = bias_stacked[indices]
bias_stacked[indices == -1] = 0
output += bias_stacked

return output.view_as(org_output)


def apply_bias_packed_nslice(
indices: torch.Tensor,
output: torch.Tensor,
output_slices: Tuple[int, ...],
bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
):
"""Applies bias to output
Input shapes:
bias_stacked: 3 element tuple of (num_loras, output_dim)
indices: (batch_size)
output: (batch_size, q_slice_size + 2*kv_slice_size)
output_slices: n-1 element tuple of (slice_size...),
where n is number of slices
"""
org_output = output
output = output.view(-1, output.shape[-1])
indices = indices.view(-1)

offset_left = 0
for slice_idx, slice in enumerate(output_slices):
bias = bias_stacked[slice_idx]
if bias is not None:
bias = bias.view(-1, bias.shape[-1])
bias = bias[indices]
bias[indices == -1] = 0
output[:, offset_left:offset_left + slice] += bias

offset_left += slice

return output.view_as(org_output)


@dataclass
class LoRAMapping(AdapterMapping):
is_prefill: bool = False
Expand Down Expand Up @@ -311,6 +254,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.punica_wrapper.add_expand(full_output,
full_lora_a_embeddings,
self.lora_b_stacked,
bias_all=None,
add_input=True)
return full_output.view_as(full_output_org)

Expand Down Expand Up @@ -399,15 +343,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
Expand Down Expand Up @@ -576,15 +514,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
Expand Down Expand Up @@ -687,8 +619,8 @@ def create_lora_weights(
) for _ in range(n_slices))
else:
self.bias_stacked = None

self.output_dim = self.lora_b_stacked[0].shape[2]
self.output_slices = (self.output_dim, self.output_dim)

def reset_lora(self, index: int):
self.lora_a_stacked[0][index] = 0
Expand Down Expand Up @@ -772,17 +704,9 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
(self.output_dim, self.output_dim),
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
(self.output_dim, self.output_dim))
output, x, self.lora_a_stacked, self.lora_b_stacked,
self.bias_stacked, 1.0, (self.output_dim, self.output_dim))
return output

@classmethod
Expand Down Expand Up @@ -1129,17 +1053,10 @@ def set_lora(
def apply(self, x: torch.Tensor,
bias: Optional[torch.Tensor]) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias_packed_nslice(
self.indices,
output,
self.output_slices,
self.bias_stacked,
)
self.punica_wrapper.add_lora_packed_nslice(output, x,
self.lora_a_stacked,
self.lora_b_stacked, 1.0,
self.lora_b_stacked,
self.bias_stacked, 1.0,
self.output_slices)
return output

Expand Down Expand Up @@ -1264,15 +1181,9 @@ def set_lora(

def apply(self, x: torch.Tensor) -> torch.Tensor:
output = self.base_layer.quant_method.apply(self.base_layer, x)
if self.bias_stacked is not None:
self.indices = self.punica_wrapper.token_lora_indices
output = apply_bias(
self.indices,
output,
self.bias_stacked,
)
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
self.lora_b_stacked, 1.0)
self.lora_b_stacked, self.bias_stacked,
1.0)
return output

def forward(self, input_):
Expand Down
Loading
Loading