RoPE Frequency Base and Frequency Scale Support #262

Open
ChrisCates opened this issue Aug 28, 2023 · 3 comments


ChrisCates commented Aug 28, 2023

As of now, there is no way to modify the RoPE frequency base or frequency scale.

We would need to edit rope.cu to support parameters for frequency base and scale:

    __global__ void rope_cuda_kernel
    (
        half* __restrict__ x,
        const half* __restrict__ sin,
        const half* __restrict__ cos,
        int rows_per_batch,
        int head_dim,
        int num_heads,
        int past_len
    )
    {
        // ... (kernel body elided)
    }

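Note that the kernel itself only consumes precomputed sin/cos tables, so base and scale could also be applied where those tables are built on the Python side. Here is a minimal sketch of how that computation usually looks (build_rope_tables is a hypothetical helper, not existing exllama code; the freq_scale semantics follow llama.cpp's linear position scaling):

    import torch

    def build_rope_tables(head_dim, max_seq_len, freq_base = 10000.0, freq_scale = 1.0, device = "cuda"):
        # Standard RoPE frequencies: inv_freq[i] = freq_base ** (-2i / head_dim)
        inv_freq = 1.0 / (freq_base ** (torch.arange(0, head_dim, 2, device = device).float() / head_dim))
        # Linear scaling compresses positions by freq_scale (e.g. 0.25 stretches a 2k model to 8k)
        positions = torch.arange(max_seq_len, device = device).float() * freq_scale
        angles = torch.outer(positions, inv_freq)  # shape: (max_seq_len, head_dim / 2)
        return torch.sin(angles).half(), torch.cos(angles).half()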
We would also need to add arguments in model_init.py to support the RoPE frequency base and scale:

exllama/model_init.py, lines 29 to 30 in 21f4a12:

    parser.add_argument("-rnnh2", "--rmsnorm_no_half2", action = "store_true", help = "Don't use half2 in RMS norm kernel")
    parser.add_argument("-rpnh2", "--rope_no_half2", action = "store_true", help = "Don't use half2 in RoPE kernel")

Here are the proposed arguments to be added to the existing model_init.py (both are floats rather than ints, since useful scale values are fractional, e.g. 0.25):

    parser.add_argument("--rope-freq-base", type = float, help = "The frequency base for the RoPE kernel", default = 10000.0)
    parser.add_argument("--rope-freq-scale", type = float, help = "The frequency scale for the RoPE kernel", default = 1.0)

Note that this is important for resolving issues like #261 and #260 when the context length is larger during inference.

Ph0rk0z (Contributor) commented Aug 29, 2023

It's exactly the same as alpha. BTW, CodeLlama's base corresponds to about alpha 100.

ChrisCates (Author) commented

@Ph0rk0z thanks man! I was wondering why I couldn't find the relevant source, but I just found it:

exllama/model.py, lines 126 to 127 in 21f4a12:

    def calculate_rotary_embedding_base(self):
        self.rotary_embedding_base = self.rotary_embedding_base * self.alpha_value ** (self.head_dim / (self.head_dim - 2))
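Plugging numbers into that formula confirms @Ph0rk0z's point above: with head_dim = 128 (as in the Llama models) and the default base of 10000, alpha = 100 lands close to the 1e6 rope theta that CodeLlama actually uses. A quick check:

    head_dim = 128
    base = 10000.0
    alpha = 100.0
    print(base * alpha ** (head_dim / (head_dim - 2)))  # ~1.08e6, vs. CodeLlama's theta of 1e6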

ChrisCates (Author) commented Sep 9, 2023

As per the discussion in issue #270, this issue is being reopened. The following is a fairly informal proposal for @turboderp to review:

First, instead of replacing the current rotary embedding calculation, we keep both options: rope_alpha and rope_theta for the first calculation, and rope_base and rope_frequency for the second. We should rename the --alpha flag to --rope-alpha for extra clarity, and use something like --use-rope-alpha and --use-rope-base to select the calculation type (see the sketch at the end of this comment).

Second, let's just pull the RoPE calculations from the llama.cpp repo. This will be easier and faster, and given how rotary embeddings work, it should not be a problem.

Third, while not strictly necessary, an additional testing script for perplexity (PPL), and maybe for reviewing sample outputs, would be nice, just to see what the optimal alpha, theta, or base and frequency values are. This is up for discussion and should be a separate PR.

I'd be happy to formalize this into a spec now. In terms of implementation, I will take a deep dive in a couple of weeks, assuming no one else is working on it.
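For concreteness, here is an informal sketch of how the dispatch between the two calculation modes could look (the flag and attribute names follow the proposal above and are not existing exllama code):

    def calculate_rope_base(config, args):
        if args.use_rope_alpha:
            # Mode 1: NTK-style alpha scaling, as exllama does today
            return args.rope_theta * args.rope_alpha ** (config.head_dim / (config.head_dim - 2))
        if args.use_rope_base:
            # Mode 2: explicit base, as in llama.cpp's --rope-freq-base
            return args.rope_base
        return 10000.0  # default Llama rope theta

The frequency scale would then be handled separately, at the point where the sin/cos tables are built, as in the sketch in the opening comment.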

ChrisCates reopened this Sep 9, 2023