- (Unofficial) PyTorch implementation of FiD: Leveraging Passage Retrieval with Generative Models for Open Domain Question Answering, FiD-Light: Efficient and Effective Retrieval-Augmented Text Generation, and FiDO: Fusion-in-Decoder optimized for stronger performance and faster inference
- The following repositories were also referenced while writing the code: fid-official and GQA
- First-K: When concatenating the per-passage encoder outputs into a single sequence in FiD, only the first K tokens of each passage are kept. (Applied; see the First-K sketch below)
- Source Pointing: Have the model return the index of the evidence passage from which the answer was extracted -> use it to re-rank the passages, then extract the answer from the re-ranked passages. (Not Applied)
- LSA (Layer-Sparse cross-Attention): Apply cross-attention only at every n-th decoder layer (n, 2n, 3n, ...). (Applied; see the GQA/LSA sketch below)
- MQA (Multi-Query Attention): In multi-head attention, all Query heads share a single Key/Value head. (Applied as GQA, where Key/Value heads are shared per group of Query heads; see the GQA/LSA sketch below)
- Decoder Scaling: Scale up the decoder model's size. (Applied)
- NOTE: Decoder scaling was implemented only by swapping in a larger GPT-2-based decoder. If you want to connect a different large decoder model, refer to `fid/FiDSKT_train.py`. In addition, with decoder scaling only First-K and LSA are applicable.
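The First-K idea can be illustrated with a minimal sketch, assuming per-passage encoder outputs of shape `[batch, n_passages, passage_len, hidden]`; `first_k_concat` is an illustrative helper, not a function in this repo:

```python
import torch

def first_k_concat(encoder_outputs: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the first K token embeddings of each passage, then flatten all
    passages into the single sequence that the decoder cross-attends to."""
    batch, n_passages, passage_len, hidden = encoder_outputs.shape
    truncated = encoder_outputs[:, :, :k, :]                 # [B, N, K, H]
    return truncated.reshape(batch, n_passages * k, hidden)  # [B, N*K, H]

# e.g. 10 passages of 256 tokens each, keeping only the first 8 tokens per passage:
enc_out = torch.randn(2, 10, 256, 768)
print(first_k_concat(enc_out, k=8).shape)  # torch.Size([2, 80, 768])
```

This shrinks the sequence the decoder attends over from `n_passages * passage_len` to `n_passages * k`, which is where the evaluation-time savings reported below come from.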
- If you want to train your own retrieval model, check this repository.
- The retrieval-related code was written based on the above repository.
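The GQA and LSA items above can likewise be sketched with toy code. This is a conceptual sketch assuming a generic transformer decoder; `GroupedQueryCrossAttention` and `cross_attention_schedule` are illustrative names, not this repo's classes, which wire the equivalent `--kv_heads` / `--n_cross_layer` options into the T5/GPT-2 modules directly:

```python
import torch
import torch.nn as nn

class GroupedQueryCrossAttention(nn.Module):
    """GQA cross-attention: n_heads Query heads share kv_heads Key/Value heads
    (kv_heads == 1 recovers MQA, kv_heads == n_heads recovers standard MHA)."""
    def __init__(self, hidden: int, n_heads: int, kv_heads: int):
        super().__init__()
        assert n_heads % kv_heads == 0 and hidden % n_heads == 0
        self.n_heads, self.kv_heads = n_heads, kv_heads
        self.head_dim = hidden // n_heads
        self.q_proj = nn.Linear(hidden, n_heads * self.head_dim)
        self.k_proj = nn.Linear(hidden, kv_heads * self.head_dim)  # fewer K/V heads
        self.v_proj = nn.Linear(hidden, kv_heads * self.head_dim)
        self.o_proj = nn.Linear(n_heads * self.head_dim, hidden)

    def forward(self, x: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        b, tq, _ = x.shape
        tm = memory.shape[1]
        q = self.q_proj(x).view(b, tq, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(memory).view(b, tm, self.kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(memory).view(b, tm, self.kv_heads, self.head_dim).transpose(1, 2)
        # broadcast each K/V head to its group of Query heads
        group = self.n_heads // self.kv_heads
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        ctx = torch.softmax(scores, dim=-1) @ v
        return self.o_proj(ctx.transpose(1, 2).reshape(b, tq, -1))

# LSA: only every n_cross_layer-th decoder layer keeps a cross-attention block;
# the other layers run self-attention + FFN only.
def cross_attention_schedule(n_layers: int, n_cross_layer: int):
    return [(i + 1) % n_cross_layer == 0 for i in range(n_layers)]

print(cross_attention_schedule(12, 6))  # cross-attention only in layers 6 and 12
x, mem = torch.randn(2, 5, 768), torch.randn(2, 80, 768)
print(GroupedQueryCrossAttention(768, 12, kv_heads=4)(x, mem).shape)  # torch.Size([2, 5, 768])
```

Setting `kv_heads=1` recovers MQA; GQA keeps a few Key/Value heads as a middle ground between MQA and full MHA, which is the variant applied in this repo.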
```
+- fid
|  +- src
|  +- FiDT5_train.py
|  +- FiDSKT_train.py
+- preprocess (make retrieved dataset)
|  +- fid_data.py
|  +- preprocess.py
+- retrieval (train retrieval model)
|  +- datatset.py
|  +- model.py
|  +- retreiver_train.py
|  +- utils.py
+- inference.py
+- requirements.sh
```
```bash
bash requirements.sh
```
- Make the retrieved dataset. (If you have a lot of data, this takes a very long time. If someone can optimize that code, please send a pull request.)
- You can use the already prepared `jjonhwa/SECOND_KQ_V2` dataset. It is linked with the `fid` training code.
```bash
python3 preprocess/fid_data.py
```
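Conceptually, each record of the retrieved dataset pairs a question with its top-N retrieved passages (N = `n_context`). The sketch below assumes a DPR/FiD-style layout; the actual field names produced by `preprocess/fid_data.py` and used in `jjonhwa/SECOND_KQ_V2` may differ:

```python
# Illustrative record layout only; check preprocess/fid_data.py for the real schema.
example_record = {
    "question": "한국의 수도는 어디인가?",   # the question
    "answers": ["서울"],                      # gold answer(s)
    "ctxs": [                                 # top-N retrieved passages
        {"title": "서울특별시", "text": "서울은 대한민국의 수도이자 최대 도시이다. ..."},
        # ... one entry per retrieved passage
    ],
}
```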
```bash
# Original FiD
python3 fid/FiDT5_train.py

# with First-K
# python3 fid/FiDT5_train.py --first_k 8

# with LSA
# python3 fid/FiDT5_train.py --n_cross_layer 6

# with GQA
# python3 fid/FiDT5_train.py --kv_heads 4

# The options above can be combined, e.g.:
# python3 fid/FiDT5_train.py --n_cross_layer 6 --kv_heads 4 --first_k 8

# with decoder scaling
# python3 fid/FiDSKT_train.py

# with decoder scaling & LSA
# python3 fid/FiDSKT_train.py --n_cross_layer 6
```
- If you want to try `inference.py`, you need to download the FiD-trained model (the example model is fine-tuned on a summarization dataset). However, that model is not well trained: it was trained on the KorQuAD dataset and is slightly overfitted to that data.
```bash
# download the fid-trained model
gdown https://drive.google.com/uc?id=1bg8tCSImGuNQGNrGdR99RcY_pFHZJUT0

# run inference
python3 inference.py
```
| Model | EVAL EM | EVAL TIME | MODEL PARAMETERS |
|---|---|---|---|
| FiDT5 Original | 39.63 | 1,519s | 783,019,008 |
| FiDT5 K8 (FiD-Light) | 26.37 | 1,215s | 783,019,008 |
| FiDT5 K32 (FiD-Light) | 26.33 | 1,233s | 783,019,008 |
| FiDT5 LSA6 (FiDO) | 37.12 | 1,268s | 699,112,448 |
| FiDT5 LSA4 (FiDO) | 36.68 | 1,265s | 707,503,104 |
| FiDT5 GQA4 (FiDO) | 23.00 | 1,271s | 783,019,008 |
| FiDSKT Original (FiDO) | . | . | 1,892,694,144 |
| FiDSKT LSA6 (FiDO) | 10.82 | 2,576s | 1,597,551,744 |
- Dataset: KorQuAD 2.0
- EVAL EM: best evaluation exact-match score
- EVAL TIME: average evaluation time over the 3 best-performing evaluation steps
- Backbone model (FiDT5): `KETI-AIR/ke-t5-large`
- Backbone model (FiDSKT): encoder `KETI-AIR/ke-t5-large` (encoder only), decoder `skt/ko-gpt-trinity-1.2B-v0.5`
- n_context: FiDT5 - 10, FiDSKT - 5
- Learning rate: `1e-4`
- Optimizer: Adam
- Batch size: 2 per GPU
- All methods improve evaluation time, but they degrade EM.
- However, LSA degrades EM only slightly while still improving evaluation time substantially.
- The difference between LSA4 and LSA6 is not significant; beyond a certain point, further reducing the number of cross-attention layers makes little difference to performance.
- In the case of FiDSKT, performance appears low because the cross-attention parameters are not pretrained; with additional training, there is room for improvement.