diff --git a/train.py b/train.py index d1685ce..31cbe55 100644 --- a/train.py +++ b/train.py @@ -60,7 +60,6 @@ def main(args): device_map=accelerator.device, torch_dtype=torch.bfloat16, rope_theta=args.rope_theta, - sliding_window=None, _attn_implementation="flash_attention_2", )