minimal impl. of gpu_idx generation for moe model
hodlen committed Apr 8, 2024
1 parent 5b4c22b commit 09d79ec
Showing 2 changed files with 32 additions and 9 deletions.
4 changes: 3 additions & 1 deletion powerinfer-py/powerinfer/__main__.py
@@ -13,6 +13,7 @@
parser.add_argument('--capacity', type=int, help='Max neurons that can be stored in VRAM.')
parser.add_argument('--layer', type=int, default=59, help='Total number of layers in the neural network.')
parser.add_argument('--vram-capacity', type=int, help='Total VRAM capacity (Bytes) available for splitting')
parser.add_argument('--moe', type=bool, default=False, help='Flag to indicate whether the model is a MoE model.')
parser.add_argument('--batch', type=int, default=256, help='Batch size for processing.')
parser.add_argument('--threshold', type=int, default=0, help='Threshold for splitting a layer across multiple GPUs.')
parser.add_argument('--output', type=str, required=True, help='File path for the output pickle file.')
@@ -36,7 +37,8 @@
activations_path=args.activation,
output_path=args.output,
solved_list=solved,
vram_capacity=args.vram_capacity
vram_capacity=args.vram_capacity,
for_moe=args.moe,
)

print(f"Exported to {args.output}")
37 changes: 29 additions & 8 deletions powerinfer-py/powerinfer/export_split.py
@@ -17,7 +17,14 @@ def load_activation_weights(models_base: Path):
activation_files.sort()
return [torch.load(models_base / f) for f in activation_files]

def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) -> None:

def append_gpu_idx(
gguf: GGUFWriter,
i_layer: int,
activation: torch.Tensor,
select_count: int,
skip_bucket=False,
):
_, indices = torch.topk(activation, k=int(select_count))
gpu_idx = torch.zeros_like(activation)
gpu_idx[indices] = 1
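
The refactored append_gpu_idx keeps the mask construction itself unchanged: the select_count most-activated neurons get a 1 in a zero tensor shaped like the activation. A minimal standalone sketch with made-up values:

import torch

activation = torch.tensor([0.9, 0.1, 0.7, 0.3, 0.8, 0.2])  # toy scores for 6 neurons
select_count = 3

_, indices = torch.topk(activation, k=int(select_count))
gpu_idx = torch.zeros_like(activation)
gpu_idx[indices] = 1   # mark the hottest neurons as GPU-resident
print(gpu_idx)         # tensor([1., 0., 1., 0., 1., 0.])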
@@ -32,7 +39,12 @@ def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) ->
raw_shape=gpu_idx.shape[::-1],
raw_dtype=GGMLQuantizationType.I32,
)
if skip_bucket:
return
append_gpu_bucket(gguf, i_layer, indices)


def append_gpu_bucket(gguf: GGUFWriter, i_layer: int, indices: torch.Tensor):
indices = indices.numpy().astype(np.int32)
gpu_bucket = np.sort(indices)
key = f"blk.{i_layer}.gpu_bucket"
@@ -46,14 +58,24 @@ def append_gpu_idx(gguf: GGUFWriter, i_layer: int, activation, select_count) ->
raw_dtype=GGMLQuantizationType.I32,
)

def export_split(activations_path: str, output_path: str, solved_list: list[int], vram_capacity: int):
predictors = load_activation_weights(Path(activations_path)) # predictor => activation count
gguf_out = GGUFWriter(output_path, "generic.gpu_index")
for i, (activation, selected_count) in enumerate(zip(predictors, solved_list)):
append_gpu_idx(gguf_out, i, activation, selected_count)

def export_split(
activations_path: str,
output_path: str,
solved_list: list[int],
vram_capacity: int,
for_moe=False,
):
activations = load_activation_weights(Path(activations_path))
gguf_out = GGUFWriter(
output_path, "moe.gpu_idx" if for_moe else "generic.gpu_index"
)
for i, (activation, selected_count) in enumerate(zip(activations, solved_list)):
# MoE models do not have remapping of neurons, so skip the bucket
append_gpu_idx(gguf_out, i, activation, selected_count, skip_bucket=not for_moe)

# set kvs
gguf_out.add_block_count(len(predictors))
gguf_out.add_block_count(len(activations))
# TODO: better to save the actual capacity that split neurons require
gguf_out.add_uint64(gguf.Keys.Split.VRAM_CAPACITY, vram_capacity)
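
Putting the pieces together, a hypothetical call into the new code path (the paths, layer count, and per-layer neuron counts below are made up for illustration):

export_split(
    activations_path="./activation",       # directory of per-layer activation files
    output_path="./model.gpu_idx.gguf",
    solved_list=[1024] * 32,               # VRAM-resident neurons per layer
    vram_capacity=8 * 1024 ** 3,           # 8 GiB
    for_moe=True,                          # arch "moe.gpu_idx", no gpu_bucket tensors
)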

@@ -69,4 +91,3 @@ def export_split(activations_path: str, output_path: str, solved_list: list[int]
fout.write(struct.pack("<I", 3))

print(f"exported GPU index to {output_path}")
