-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
bb2bc44
commit 7142003
Showing
12 changed files
with
3,019 additions
and
5 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
Copyright 2023 Meta | ||
|
||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: | ||
|
||
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. | ||
|
||
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. | ||
|
||
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. | ||
|
||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
# gpt-fast on Modal | ||
|
||
This is a demo of https://github.com/pytorch-labs/gpt-fast running on | ||
[Modal](https://modal.com). It demonstrates how to use speculative sampling, | ||
quantized models, and pytorch compilation to achieve upwards of 125 tokens/s | ||
with batch sizes of 1 (i.e. no vLLM-style continuous batching), on 7B models | ||
running on individual A100 80GB GPUs. It's a multi-file Modal app that | ||
integrates into an existing codebase (files other than `modal.py` were mostly | ||
taken as-is from `pytorch-labs/gpt-fast`), makes of container-lifecyle | ||
primitives, streams responses, and is also able to invoke already-deployed | ||
functions. | ||
|
||
TODO: | ||
- [ ] Make use of GPU checkpointing to avoid long cold starts. | ||
- [ ] Doc-ify modal.py, publish to website. | ||
- [ ] Make use of draft models for speculative sampling. | ||
- [ ] Run them on secondary GPUs? | ||
- [ ] Make use of tensor parallelism. | ||
- [ ] Fix (gpt-fast?) bug where subsequent generations end up generating | ||
using the prompt used to compile the model itself, or earlier prompt. | ||
Maybe some internal tensor getting recycled? | ||
|
||
To run one-off inference: | ||
``` | ||
۩ modal run gpt-fast.modal::main --prompt "Implement fibonacci in python" | ||
\ --no-compile-model | ||
... | ||
Loading model weights ... | ||
Using int8 weight-only quantization! | ||
Loading model weights took 11.08 seconds | ||
Starting inference for prompt = 'Implement fibonacci in python' | ||
with memoization. | ||
The time complexity should be O(n) | ||
The space complexity should be O(n) | ||
""" | ||
def fibonacci(n, mem=dict()): | ||
if n == 0: | ||
return 0 | ||
if n == 1: | ||
return 1 | ||
if n in mem: | ||
return mem[n] | ||
Time for inference 1: 13.24 sec total, 7.55 tokens/sec | ||
Bandwidth achieved: 51.91 GB/s | ||
... | ||
``` | ||
|
||
Compile the model for faster inference, at the cost of much longer cold-starts: | ||
``` | ||
۩ modal run gpt-fast.modal::main --prompt "Implement fibonacci in python" \ | ||
--compile-model | ||
... | ||
Running warmup inference ... | ||
Model compilation time: 298.49 seconds | ||
Starting inference for prompt = 'Implement fibonacci in python' | ||
... | ||
Time for inference 1: 0.81 sec total, 123.54 tokens/sec | ||
Bandwidth achieved: 856.83 GB/s | ||
``` | ||
|
||
Deploy the model and run inference against a container that's already compiled | ||
the pytorch model: | ||
``` | ||
۩ modal deploy gpt-fast.modal | ||
# Should happen instantaneously once deployed model is fully compiled, at | ||
# upwards of 125 tokens/sec. | ||
۩ modal run gpt-fast.modal::main --lookup-existing \ | ||
--prompt "Add two numbers in python" --num-samples 10 | ||
``` | ||
|
||
Run a web-version of the app using: | ||
``` | ||
۩ modal serve gpt-fast.app | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from pathlib import Path | ||
|
||
from modal import Function, Image, Mount, Stub, asgi_app | ||
|
||
stub = Stub("gpt-fast-app", image=Image.debian_slim()) | ||
|
||
|
||
@stub.function( | ||
mounts=[ | ||
Mount.from_local_dir( | ||
Path(__file__).parent.parent / "llm-frontend", | ||
remote_path="/assets", | ||
), | ||
], | ||
allow_concurrent_inputs=10, | ||
timeout=10 * 60, | ||
) | ||
@asgi_app(label="gpt-fast-app") | ||
def app(): | ||
import json | ||
from urllib.parse import unquote | ||
|
||
import fastapi | ||
import fastapi.staticfiles | ||
from fastapi.responses import StreamingResponse | ||
|
||
web_app = fastapi.FastAPI() | ||
|
||
@web_app.get("/model") | ||
async def model(): | ||
return {"name": "Llama-2-7b-chat-hf"} | ||
|
||
@web_app.get("/stats") | ||
async def stats(): | ||
stats = await Function.lookup( | ||
"gpt-fast", "Model.generate" | ||
).get_current_stats.aio() | ||
return { | ||
"backlog": stats.backlog, | ||
"num_total_runners": stats.num_total_runners, | ||
} | ||
|
||
@web_app.get("/completion/{question}") | ||
async def completion(question: str): | ||
async def generate(): | ||
fn = Function.lookup("gpt-fast", "Model.generate") | ||
for generated in fn.remote_gen(unquote(question)): | ||
yield f"data: {json.dumps(dict(text=generated), ensure_ascii=False)}\n\n" | ||
|
||
return StreamingResponse(generate(), media_type="text/event-stream") | ||
|
||
web_app.mount( | ||
"/", fastapi.staticfiles.StaticFiles(directory="/assets", html=True) | ||
) | ||
return web_app |
Oops, something went wrong.