💫 Fusion-in-Decoder customization (FiD-Light, FiDO)

jjonhwa/Fusion-in-Decoder-Custom

FiD-Custom

Contents

FiD

(figure: FiD architecture)

FiD-Light

(figure: FiD-Light architecture)
- First-K: when concatenating the per-passage encoder outputs into a single sequence in FiD, only the first K token embeddings of each passage are kept. (Applied)
- Source Pointing: the model returns the indices of the evidence passages from which the answer was extracted; these indices are used to re-rank the passages, and the re-ranked passages are used to extract the final answer. (Not applied)
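The First-K idea can be sketched at the shape level in plain Python (no tensors; passage count, length, and embedding size below are hypothetical):

```python
# Shape-level sketch of FiD-Light's First-K: each retrieved passage is
# encoded into a sequence of token embeddings, but only the first k
# embeddings per passage enter the fused sequence that the decoder
# cross-attends over.

def first_k_concat(passage_encodings, k):
    """passage_encodings: list of per-passage token-embedding sequences
    (each a list of vectors). Returns the fused sequence containing only
    the first k vectors of every passage."""
    fused = []
    for enc in passage_encodings:
        fused.extend(enc[:k])
    return fused

# 10 hypothetical passages, 64 tokens each, 4-dim embeddings
passages = [[[float(p)] * 4 for _ in range(64)] for p in range(10)]
fused = first_k_concat(passages, k=8)
# 10 passages * 8 kept tokens = 80 vectors instead of 640
```

The decoder's cross-attention cost scales with the fused sequence length, which is why truncating each passage to K tokens speeds up evaluation.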

FiDO

(figures: FiDO LSA and MQA illustrations)

- LSA (Layer-Sparse cross-Attention): apply cross-attention only at every n-th decoder layer (n, 2n, 3n, ...). (Applied)
- MQA (Multi-Query Attention): in multi-head attention, all query heads share a single Key/Value head. (Applied as GQA, where each group of query heads shares one Key/Value head)
- Decoder Scaling: scale up the size of the decoder model. (Applied)
- NOTE: decoder scaling was implemented only by swapping in a larger GPT-2-based decoder. To connect other large models, refer to fid/FiDSKT_train.py. With decoder scaling, only First-K and LSA are applicable.
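The two attention changes reduce to simple index rules. A minimal, framework-agnostic sketch (these helpers are illustrative, not the repo's code; the flag names come from the Run section below):

```python
# lsa_layers: which decoder layers keep cross-attention under LSA
# (controlled by --n_cross_layer in this repo's training scripts).
# kv_head_for_query: how GQA maps each query head to a shared K/V head
# (controlled by --kv_heads).

def lsa_layers(n_layers, n):
    """1-based indices of decoder layers that keep cross-attention:
    every n-th layer (n, 2n, 3n, ...); all other layers drop it."""
    return [i for i in range(1, n_layers + 1) if i % n == 0]

def kv_head_for_query(q_head, n_query_heads, n_kv_heads):
    """GQA: query heads are split into n_kv_heads groups; every head in
    a group shares one K/V head. n_kv_heads == 1 recovers MQA."""
    group_size = n_query_heads // n_kv_heads
    return q_head // group_size
```

For example, with a 24-layer decoder and --n_cross_layer 6, only layers 6, 12, 18, and 24 keep cross-attention; with 16 query heads and --kv_heads 4, each block of 4 query heads shares one Key/Value head.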

Retrieval

- If you want to train your own retrieval model, check this repository.
- The retrieval code was written based on the above repository.

Data

```
+- fid
|   +- src
|   +- FiDT5_train.py
|   +- FiDSKT_train.py
+- preprocess (make retrieved dataset)
|   +- fid_data.py
|   +- preprocess.py
+- retrieval (train retrieval model)
|   +- datatset.py
|   +- model.py
|   +- retreiver_train.py
|   +- utils.py
+- inference.py
+- requirements.sh
```

Run

setup

```bash
bash requirements.sh
```

preprocess

- Build the retrieved dataset. (With a large amount of data this takes a very long time; pull requests that optimize the code are welcome.)
- Alternatively, use the prebuilt jjonhwa/SECOND_KQ_V2 dataset, which is already wired into the FiD training code.

```bash
python3 preprocess/fid_data.py
```

fid

```bash
# Original FiD
python3 fid/FiDT5_train.py

# with First-K
# python3 fid/FiDT5_train.py --first_k 8

# with LSA
# python3 fid/FiDT5_train.py --n_cross_layer 6

# with GQA
# python3 fid/FiDT5_train.py --kv_heads 4

# The flags above can be combined, e.g.
# python3 fid/FiDT5_train.py --n_cross_layer 6 --kv_heads 4 --first_k 8

# with decoder scaling
# python3 fid/FiDSKT_train.py

# with decoder scaling & LSA
# python3 fid/FiDSKT_train.py --n_cross_layer 6
```

inference

- To run inference.py, you first need to download the FiD-trained model. (The example model is fine-tuned on a summarization dataset.)
- Note that the model is not well trained: it was trained on the KorQuAD dataset and is slightly overfitted to it.

```bash
# download the fid-trained model
gdown https://drive.google.com/uc?id=1bg8tCSImGuNQGNrGdR99RcY_pFHZJUT0

# run inference
python3 inference.py
```

Experiments

FiDT5 Original

(figure: FiDT5 Original results)

FiDT5 Original, LSA and FiDSKT

(figure: FiDT5 Original, LSA, and FiDSKT results)

FiDT5 LSA and GQA

(figure: FiDT5 LSA and GQA results)

FiDT5 First-K

(figure: FiDT5 First-K results)

Model Info

| Model | EVAL EM | EVAL TIME | MODEL PARAMETERS |
|---|---|---|---|
| FiDT5 Original | 39.63 | 1,519s | 783,019,008 |
| FiDT5 K8 (FiD-Light) | 26.37 | 1,215s | 783,019,008 |
| FiDT5 K32 (FiD-Light) | 26.33 | 1,233s | 783,019,008 |
| FiDT5 LSA6 (FiDO) | 37.12 | 1,268s | 699,112,448 |
| FiDT5 LSA4 (FiDO) | 36.68 | 1,265s | 707,503,104 |
| FiDT5 GQA4 (FiDO) | 23.00 | 1,271s | 783,019,008 |
| FiDSKT Original (FiDO) | . | . | 1,892,694,144 |
| FiDSKT LSA6 (FiDO) | 10.82 | 2,576s | 1,597,551,744 |
- Dataset: KorQuAD 2.0
- EVAL EM: best evaluation exact-match (EM) score
- EVAL TIME: average evaluation time over the three best-performing evaluation steps
- Backbone model (FiDT5): KETI-AIR/ke-t5-large
- Backbone model (FiDSKT): encoder: KETI-AIR/ke-t5-large (encoder), decoder: skt/ko-gpt-trinity-1.2B-v0.5
- n_context: FiDT5: 10, FiDSKT: 5
- Learning rate: 1e-4
- Optimizer: Adam
- batch_size: 2 per GPU

Analysis

- All methods improve evaluation time, but all degrade performance.
- However, LSA degrades performance only slightly while still improving evaluation time substantially.
- The difference between LSA4 and LSA6 is not significant; once the number of cross-attention layers is reduced past a certain point, reducing it further does not change performance much.
- In the case of FiDSKT, performance seems low because the cross-attention parameters are not trained; with additional training, there is room for improvement.
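One way to act on the last point is to train only the cross-attention parameters while freezing the rest of the scaled decoder. A minimal sketch of the selection rule (the parameter names below are hypothetical; in practice they would come from the model's named parameters):

```python
# Sketch: keep only cross-attention parameters trainable, freezing
# everything else, so further training updates the untrained
# cross-attention blocks of the scaled decoder.

def select_trainable(param_names):
    """Return the parameter names to keep trainable: only those
    belonging to cross-attention modules."""
    return [n for n in param_names if "crossattention" in n]

# hypothetical GPT-2-style decoder parameter names
names = [
    "decoder.h.0.attn.c_attn.weight",
    "decoder.h.0.crossattention.c_attn.weight",
    "decoder.h.0.mlp.c_fc.weight",
]
trainable = select_trainable(names)
```

In a PyTorch model, the same rule would be applied by setting `requires_grad` per parameter based on its name.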
