Hi 👋🏽
I'm fine-tuning and evaluating Chronos on a large dataset, and I'm wondering how to improve Chronos' inference throughput. I'm using an EC2 instance with multiple GPUs.
I'm doing something like this:
```python
import torch
from tqdm import tqdm

from chronos import ChronosPipeline

# dataset shape: [number of samples, common sample size] (I've left-padded already)
dataset = load_dataset()

chronos = ChronosPipeline.from_pretrained(
    pretrained_model_name_or_path=model_name,
    device_map=device,
    torch_dtype=torch.bfloat16,
)

batch_size = 128
forecasts = []
for ii in tqdm(range(0, len(dataset), batch_size)):
    # forecast one batch of contexts at a time
    batch_forecasts = chronos.predict(
        context=torch.tensor(dataset[ii : ii + batch_size, :]),
        prediction_length=prediction_length,
        num_samples=num_samples,
        temperature=1.0,
        top_k=50,
        top_p=1.0,
    )
    forecasts.append(batch_forecasts)
```
I've noticed that the code above runs much faster on a single GPU (e.g. device_map="cuda:0") than with device_map="auto", which spreads the model across all available GPUs. I guess this is due to inefficient parallelisation across GPUs: device_map="auto" seems to shard the model's layers across devices (naive model parallelism), so each forward pass pays for cross-GPU transfers while most GPUs sit idle, instead of running batches in parallel.

I've done some googling, and the tools people normally use to speed up inference (e.g. vLLM) can't be used here because they don't support T5 models.
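The workaround I'm leaning towards is plain data parallelism: one full copy of the pipeline per GPU, with the samples sharded across the copies. Below is a rough sketch only, not something I've benchmarked; it reuses `model_name`, `dataset`, `prediction_length` and `num_samples` from the snippet above, assumes one replica fits on a single GPU, and the `predict_shard` helper is just my own name for the per-GPU loop.

```python
from concurrent.futures import ThreadPoolExecutor

import torch
from chronos import ChronosPipeline

num_gpus = torch.cuda.device_count()
batch_size = 128

# One full replica of the model per GPU, so there is no cross-GPU traffic at inference time.
pipelines = [
    ChronosPipeline.from_pretrained(
        pretrained_model_name_or_path=model_name,
        device_map=f"cuda:{rank}",
        torch_dtype=torch.bfloat16,
    )
    for rank in range(num_gpus)
]

def predict_shard(rank):
    # Round-robin shard of the (already left-padded) samples for this GPU.
    shard = dataset[rank::num_gpus]
    forecasts = []
    for ii in range(0, len(shard), batch_size):
        forecasts.append(
            pipelines[rank].predict(
                context=torch.tensor(shard[ii : ii + batch_size, :]),
                prediction_length=prediction_length,
                num_samples=num_samples,
                temperature=1.0,
                top_k=50,
                top_p=1.0,
            )
        )
    return torch.cat(forecasts)

# Threads should be enough here because the heavy work runs inside CUDA kernels,
# which release the GIL; each thread drives its own GPU replica.
with ThreadPoolExecutor(max_workers=num_gpus) as pool:
    per_gpu_forecasts = list(pool.map(predict_shard, range(num_gpus)))
```

(One caveat: the round-robin sharding changes the ordering, so the forecasts would need to be re-interleaved to match the original dataset order.) Is there a better-supported way to do this, or is replicating the pipeline per GPU the expected approach?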
Any ideas welcome 😃