diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 1b7b2a39b67..364141c8e40 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -332,10 +332,11 @@ In that case, you should see a warning message and we will fall back to the (slo -By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.backends.cuda.sdp_kernel`](https://pytorch.org/docs/master/backends.html#torch.backends.cuda.sdp_kernel) as a context manager: +By default, SDPA selects the most performant kernel available but you can check whether a backend is available in a given setting (hardware, problem size) with [`torch.nn.attention.sdpa_kernel`](https://pytorch.org/docs/stable/generated/torch.nn.attention.sdpa_kernel.html) as a context manager: ```diff import torch ++ from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AutoModelForCausalLM, AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") @@ -344,7 +345,7 @@ model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=to input_text = "Hello my dog is cute and" inputs = tokenizer(input_text, return_tensors="pt").to("cuda") -+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): ++ with sdpa_kernel(SDPBackend.FLASH_ATTENTION): outputs = model.generate(**inputs) print(tokenizer.decode(outputs[0], skip_special_tokens=True)) @@ -518,6 +519,7 @@ It is often possible to combine several of the optimization techniques described ```py import torch +from torch.nn.attention import SDPBackend, sdpa_kernel from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig # load model in 4-bit @@ -536,7 +538,7 @@ input_text = "Hello my dog is cute and" inputs = tokenizer(input_text, return_tensors="pt").to("cuda") # enable FlashAttention -with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): +with sdpa_kernel(SDPBackend.FLASH_ATTENTION): outputs = model.generate(**inputs) print(tokenizer.decode(outputs[0], skip_special_tokens=True))