positional_embedding.py
import torch
import torch.nn as nn

from typing import Any


class PositionalEmbedding(nn.Module):
    """
    PositionalEmbedding is a module that generates positional embeddings for sequence positions.

    Attributes:
    -----------
    max_sequence_length : int
        The maximum length of the sequences.
    embedding_dim : int
        The dimension of the embedding vectors.
    embedding : nn.Parameter
        The learnable positional embedding matrix, initialised with a standard normal distribution.

    Methods:
    --------
    forward(position: int) -> torch.Tensor:
        Retrieves the positional embeddings for the given positions.
    """

    def __init__(self, max_sequence_length: int, embedding_dim: int, **kwargs: Any) -> None:
        """
        Initialises the PositionalEmbedding module.

        Parameters:
        -----------
        max_sequence_length : int
            The maximum length of the sequences.
        embedding_dim : int
            The dimension of the embedding vectors.
        **kwargs : Any
            Additional keyword arguments (not used in this implementation).
        """
        super().__init__()
        self.max_sequence_length = max_sequence_length
        self.embedding_dim = embedding_dim
        self.embedding = nn.Parameter(torch.normal(0, 1, (self.max_sequence_length, self.embedding_dim)))
    def forward(self, position: int) -> torch.Tensor:
        """
        Forward pass for the PositionalEmbedding module.

        Retrieves the positional embeddings for the first `position` positions.

        Parameters:
        -----------
        position : int
            The number of positions for which embeddings are retrieved.

        Returns:
        --------
        torch.Tensor
            A tensor of shape (position, embedding_dim) containing the positional embeddings.
        """
        return self.embedding[:position, :]
if __name__ == "__main__":
    embedding_dim = 32
    max_sequence_length = 100

    positional_embedding = PositionalEmbedding(max_sequence_length, embedding_dim)

    # Look up the positional embeddings for every position in the sequence.
    embeddings = positional_embedding(max_sequence_length)
    print(embeddings.shape)  # (max_sequence_length, embedding_dim) -> torch.Size([100, 32])
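As a quick illustration of how a module like this is typically used, here is a minimal sketch of adding the learned positional embeddings to token embeddings for a batch of sequences. The token embedding layer, vocabulary size, and random token ids below are hypothetical and not part of this file; the sketch assumes PositionalEmbedding is importable from this module.

    import torch
    import torch.nn as nn

    # Hypothetical setup (illustrative only, not defined in positional_embedding.py).
    vocab_size, embedding_dim, max_sequence_length = 1000, 32, 100
    token_embedding = nn.Embedding(vocab_size, embedding_dim)
    positional_embedding = PositionalEmbedding(max_sequence_length, embedding_dim)

    # A batch of 4 token-id sequences of length 20 (random ids for illustration).
    token_ids = torch.randint(0, vocab_size, (4, 20))

    # Token embeddings: (4, 20, embedding_dim); positional embeddings: (20, embedding_dim).
    # Broadcasting adds the same positional vector to every sequence in the batch.
    x = token_embedding(token_ids) + positional_embedding(token_ids.size(1))
    print(x.shape)  # torch.Size([4, 20, 32])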