Improve cache manager handling to reuse one allocated static cache
fiedorowicz1 committed Nov 27, 2024
1 parent 0991125 commit 838eb98
Showing 2 changed files with 16 additions and 8 deletions.
chat_server.py: 8 changes (3 additions, 5 deletions)
@@ -44,7 +44,7 @@
 app = FastAPI()
 
 # Global variables
-inputs = model = tokenizer = gen_queue = cache_manager = None
+inputs = model = tokenizer = gen_queue = None
 device = torch.device("cuda:0")
 
 
@@ -171,6 +171,7 @@ def master_loop(inputs, device, gen_queue, interval_minutes=5):
 
 
 def worker_loop():
+    cache_manager = KVCacheManager(model)
     info: ControlInfo = chat_synchronize_ranks(inputs, device)
     while not info.exit:
         if not info.keepalive:
@@ -224,13 +225,10 @@ def main():
     global inputs
     inputs = torch.full((1, 131072), 128002, dtype=torch.long, device=device)
 
-    global cache_manager
-    cache_manager = KVCacheManager(model)
-
     if args.compile:
         model.model.original_forward = model.model.forward
         model.model.compiled_forward = torch.compile(
-            model.model.forward, mode="reduce-overhead", fullgraph=True
+            model.model.forward, mode="reduce-overhead", dynamic=True
        )
 
     # Run the uvicorn server
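A note on the torch.compile change above: dropping fullgraph=True (which would error on any graph break) and passing dynamic=True asks the compiler to treat shapes such as the sequence length as dynamic, so calls with different prompt lengths can reuse one compiled artifact instead of triggering recompilation, while mode="reduce-overhead" keeps the launch-overhead reduction. A minimal standalone sketch of the same pattern (TinyModel is a toy stand-in, not the repository's DistributedLlama):

import torch


class TinyModel(torch.nn.Module):
    """Toy stand-in; the real server compiles DistributedLlama's inner model."""

    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)


model = TinyModel()
# Mirror the commit's pattern: keep the eager forward around and store a
# compiled variant alongside it, with dynamic shapes enabled.
model.original_forward = model.forward
model.compiled_forward = torch.compile(
    model.forward, mode="reduce-overhead", dynamic=True
)

# Different sequence lengths can reuse the same compiled artifact.
for seq_len in (8, 16, 32):
    out = model.compiled_forward(torch.randn(1, seq_len, 64))
    print(out.shape)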
llama/chat_utils.py: 16 changes (13 additions, 3 deletions)
@@ -155,8 +155,13 @@ class KVCacheManager:
 
     def __init__(self, model: DistributedLlama):
         self.model = model
+        self.compiled = hasattr(self.model.model, "compiled_forward")
+
+        if self.compiled:
+            self.static_cache = PipelineStaticCache(self.model)
+
         self.cached_tokens = None
-        self.kv_cache = PipelineStaticCache(self.model)
+        self.kv_cache = self.static_cache if self.compiled else PipelineDynamicCache()
 
     def get_cache(self, inputs, input_len, max_new_tokens):
         # Check if the cache can be reused
@@ -168,7 +173,11 @@ def get_cache(self, inputs, input_len, max_new_tokens):
         ):
             print("Cache miss")
             self.cached_tokens = None
-            self.kv_cache = PipelineStaticCache(self.model)
+            if self.compiled:
+                self.kv_cache = self.static_cache
+                self.kv_cache.reset()
+            else:
+                self.kv_cache = PipelineDynamicCache()
         else:
             print("Cache hit")
 
@@ -178,10 +187,11 @@ def get_cache(self, inputs, input_len, max_new_tokens):
         ) and self.kv_cache.max_cache_len < (input_len + max_new_tokens):
             print("Switching to dynamic cache")
             self.cached_tokens = None
+            self.kv_cache.reset()
             self.kv_cache = PipelineDynamicCache()
 
         # Switch to compiled forward if available
-        if hasattr(self.model.model, "compiled_forward"):
+        if self.compiled:
             if isinstance(self.kv_cache, PipelineStaticCache):
                 self.model.model.forward = self.model.model.compiled_forward
             else:
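Taken together, the KVCacheManager changes mean at most one PipelineStaticCache is allocated per worker: it is created once in __init__ when a compiled forward exists and is only reset() on later cache misses, while the uncompiled path uses PipelineDynamicCache throughout. A minimal sketch of that reuse pattern with toy classes (StaticCache, DynamicCache, and CacheManagerSketch are illustrative stand-ins, not the repository's API):

class StaticCache:
    """Stand-in for a preallocated KV cache with a fixed capacity."""

    def __init__(self, max_cache_len: int = 131072):
        self.max_cache_len = max_cache_len

    def reset(self):
        # Invalidate the preallocated key/value buffers in place instead of
        # freeing and reallocating them.
        pass


class DynamicCache:
    """Stand-in; the real dynamic cache grows as tokens are appended."""


class CacheManagerSketch:
    def __init__(self, compiled: bool):
        self.compiled = compiled
        # Allocate the static cache exactly once, and only if it will be used.
        self.static_cache = StaticCache() if compiled else None
        self.kv_cache = self.static_cache if compiled else DynamicCache()

    def on_cache_miss(self):
        if self.compiled:
            # Reuse the single allocation; just clear its contents.
            self.kv_cache = self.static_cache
            self.kv_cache.reset()
        else:
            self.kv_cache = DynamicCache()

Allocating the static buffers once per worker rather than on every miss is what the commit title refers to; it also keeps the cache tensors at stable addresses, which is presumably friendlier to the graphs captured under mode="reduce-overhead".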
