-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
doc-doc
committed
May 17, 2021
1 parent
a93ab41
commit 2744b57
Showing
32 changed files
with
11,624 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.idea | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,68 @@ | ||
# NExT-OE | ||
open-ended QA | ||
# NExT-QA <img src="images/logo.png" height="64" width="128"> | ||
|
||
We reproduce some SOTA VideoQA methods to provide benchmark results for our NExT-QA dataset published on CVPR2021 (with 1 Strong Accept and 2 Weak Accepts). | ||
|
||
NExT-QA is a VideoQA benchmark targeting the explanation of video contents. It challenges QA models to reason about the causal and temporal actions and understand the rich object interactions in daily activities. We set up both multi-choice and open-ended QA tasks on the dataset. This repo. provides resources for open-ended QA; multi-choice QA is found in [NExT-QA](https://github.com/doc-doc/NExT-QA). For more details, please refer to our [dataset](https://doc-doc.github.io/junbin.github.io/docs/nextqa.html) page. | ||
|
||
## Environment | ||
|
||
Anaconda 4.8.4, python 3.6.8, pytorch 1.6 and cuda 10.2. For other libs, please refer to the file requirements.txt. | ||
|
||
## Install | ||
Please create an env for this project using anaconda (should install [anaconda](https://docs.anaconda.com/anaconda/install/linux/) first) | ||
``` | ||
>conda create -n videoqa python=3.6.8 | ||
>conda activate videoqa | ||
>git clone https://github.com/doc-doc/NExT-OE.git | ||
>pip install -r requirements.txt | ||
``` | ||
## Data Preparation | ||
Please download the pre-computed features and QA annotations from [here](https://drive.google.com/drive/folders/1gKRR2es8-gRTyP25CvrrVtV6aN5UxttF?usp=sharing). There are 3 zip files: | ||
- ```['vid_feat.zip']```: Appearance and motion feature for video representation. | ||
- ```['nextqa.zip']```: Annotations of QAs and GloVe Embeddings. | ||
- ```['models.zip']```: Learned HGA model. | ||
|
||
After downloading the data, please create a folder ```['data']``` at the same directory as ```['NExT-OE']```, then unzip the video and QA features into it. You will have directories like ```['data/vid_feat/', and 'NExT-OE/']``` in your workspace. Please unzip the files in ```['nextqa.zip']``` into ```['NExT-OE/dataset/nextqa']``` and ```['models.zip']``` into ```['NExT-OE/models/']```. | ||
|
||
|
||
## Usage | ||
Once the data is ready, you can easily run the code. First, to test the environment and code, we provide the prediction and model of the SOTA approach (i.e., HGA) on NExT-QA. | ||
You can get the results reported in the paper by running: | ||
``` | ||
>python eval_oe.py | ||
``` | ||
The command above will load the prediction file under ['results/'] and evaluate it. | ||
You can also obtain the prediction by running: | ||
``` | ||
>./main.sh 0 val #Test the model with GPU id 0 | ||
``` | ||
The command above will load the model under ['models/'] and generate the prediction file. | ||
If you want to train the model, please run | ||
``` | ||
>./main.sh 0 train # Train the model with GPU id 0 | ||
``` | ||
It will train the model and save to ['models']. (*The results may be slightly different depending on the environments*) | ||
## Results | ||
| Methods | Text Rep. | Acc_C | Acc_T | Acc_D | Acc | | ||
| -------------------------| --------: | ----: | ----: | ----: | ---:| | ||
| BlindQA | GloVe | 12.14 | 14.85 | 40.41 | 18.88 | | ||
| [STVQA](https://openaccess.thecvf.com/content_cvpr_2017/papers/Jang_TGIF-QA_Toward_Spatio-Temporal_CVPR_2017_paper.pdf) [CVPR17] | GloVe | 12.52 | 14.57 | 45.64 | 20.08 | | ||
| [UATT](https://ieeexplore.ieee.org/document/8017608) [TIP17] | GloVe | 13.62 | **16.23** | 43.41 | 20.65 | | ||
| [HME](https://openaccess.thecvf.com/content_CVPR_2019/papers/Fan_Heterogeneous_Memory_Enhanced_Multimodal_Attention_Model_for_Video_Question_Answering_CVPR_2019_paper.pdf) [CVPR19] | GloVe | 12.83 | 14.76 | 45.13 | 20.18 | | ||
| [HCRN](https://openaccess.thecvf.com/content_CVPR_2020/papers/Le_Hierarchical_Conditional_Relation_Networks_for_Video_Question_Answering_CVPR_2020_paper.pdf) [CVPR20] | GloVe | 12.53 | 15.37 | 45.29 | 20.25 | | ||
| [HGA](https://ojs.aaai.org//index.php/AAAI/article/view/6767) [AAAI20] | GloVe | **14.76** | 14.90 | **46.60** | **21.48** | | ||
|
||
## Multi-choice QA v.s Open-ended QA | ||
![vis mc_oe](./images/res-mc-oe.png) | ||
## Citation | ||
``` | ||
@inproceedings{Xiao2021NExT-QA, | ||
title={NExT-QA: Next Phase of Question-Answering to Explaining Temporal Actions}, | ||
author={Xiao, Junbin and Shang, Xindi and Angela Yao and Chua, Tat-Seng}, | ||
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, | ||
year={2021}, | ||
organization={IEEE} | ||
} | ||
``` | ||
## Acknowledgement | ||
Our reproduction of the methods are based on the respective official repositories, we thanks the authors to release their code. If you use the related part, please cite the corresponding paper commented in the code. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import nltk | ||
# nltk.download('punkt') | ||
import pickle | ||
import argparse | ||
from utils import load_file, save_file | ||
from collections import Counter | ||
|
||
|
||
|
||
class Vocabulary(object): | ||
"""Simple vocabulary wrapper.""" | ||
def __init__(self): | ||
self.word2idx = {} | ||
self.idx2word = {} | ||
self.idx = 0 | ||
|
||
def add_word(self, word): | ||
if not word in self.word2idx: | ||
self.word2idx[word] = self.idx | ||
self.idx2word[self.idx] = word | ||
self.idx += 1 | ||
|
||
def __call__(self, word): | ||
if not word in self.word2idx: | ||
return self.word2idx['<unk>'] | ||
return self.word2idx[word] | ||
|
||
def __len__(self): | ||
return len(self.word2idx) | ||
|
||
|
||
|
||
def build_vocab(anno_file, threshold): | ||
"""Build a simple vocabulary wrapper.""" | ||
|
||
annos = load_file(anno_file) | ||
print('total QA pairs', len(annos)) | ||
counter = Counter() | ||
|
||
for rid, (qns, ans) in enumerate(zip(annos['question'], annos['answer'])): | ||
# qns, ans = vqa['question'], vqa['answer'] | ||
text = qns +' ' +ans | ||
tokens = nltk.tokenize.word_tokenize(text.lower()) | ||
counter.update(tokens) | ||
|
||
counter = sorted(counter.items(), key=lambda item:item[1], reverse=True) | ||
|
||
# If the word frequency is less than 'threshold', then the word is discarded. | ||
words = [item[0] for item in counter if item[1] >= threshold] | ||
|
||
# Create a vocab wrapper and add some special tokens. | ||
vocab = Vocabulary() | ||
vocab.add_word('<pad>') | ||
vocab.add_word('<start>') | ||
vocab.add_word('<end>') | ||
vocab.add_word('<unk>') | ||
|
||
# Add the words to the vocabulary. | ||
for i, word in enumerate(words): | ||
vocab.add_word(word) | ||
|
||
return vocab | ||
|
||
|
||
def main(args): | ||
vocab = build_vocab(args.caption_path, args.threshold) | ||
vocab_path = args.vocab_path | ||
with open(vocab_path, 'wb') as f: | ||
pickle.dump(vocab, f) | ||
print("Total vocabulary size: {}".format(len(vocab))) | ||
print("Saved the vocabulary wrapper to '{}'".format(vocab_path)) | ||
|
||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('--anno_path', type=str, | ||
default='dataset/nextqa/all.csv', | ||
help='path for train annotation file') | ||
parser.add_argument('--vocab_path', type=str, default='dataset/nextqa/vocab.pkl', | ||
help='path for saving vocabulary wrapper') | ||
parser.add_argument('--threshold', type=int, default=5, | ||
help='minimum word count threshold') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# ==================================================== | ||
# @Time : 15/5/20 3:48 PM | ||
# @Author : Xiao Junbin | ||
# @Email : [email protected] | ||
# @File : __init__.py | ||
# ==================================================== | ||
from .sample_loader import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
# ==================================================== | ||
# @Time : 19/5/20 10:42 PM | ||
# @Author : Xiao Junbin | ||
# @Email : [email protected] | ||
# @File : sample_loader.py | ||
# ==================================================== | ||
import torch | ||
from torch.utils.data import Dataset, DataLoader | ||
from .util import load_file, pkdump, pkload | ||
import os.path as osp | ||
import numpy as np | ||
import nltk | ||
import h5py | ||
|
||
class VidQADataset(Dataset): | ||
"""load the dataset in dataloader""" | ||
|
||
def __init__(self, video_feature_path, video_feature_cache, sample_list_path, vocab_qns, vocab_ans, mode): | ||
self.video_feature_path = video_feature_path | ||
self.vocab_qns = vocab_qns | ||
self.vocab_ans = vocab_ans | ||
sample_list_file = osp.join(sample_list_path, '{}.csv'.format(mode)) | ||
self.sample_list = load_file(sample_list_file) | ||
self.video_feature_cache = video_feature_cache | ||
self.use_frame = True | ||
self.use_mot = True | ||
self.frame_feats = {} | ||
self.mot_feats = {} | ||
vid_feat_file = osp.join(video_feature_path, 'vid_feat/app_mot_{}.h5'.format(mode)) | ||
with h5py.File(vid_feat_file, 'r') as fp: | ||
vids = fp['ids'] | ||
feats = fp['feat'] | ||
for id, (vid, feat) in enumerate(zip(vids, feats)): | ||
if self.use_frame: | ||
self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048) | ||
if self.use_mot: | ||
self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048) | ||
|
||
|
||
def __len__(self): | ||
return len(self.sample_list) | ||
|
||
|
||
def get_video_feature(self, video_name): | ||
""" | ||
""" | ||
if self.use_frame: | ||
app_feat = self.frame_feats[video_name] | ||
video_feature = app_feat # (16, 2048) | ||
if self.use_mot: | ||
mot_feat = self.mot_feats[video_name] | ||
video_feature = np.concatenate((video_feature, mot_feat), axis=1) #(16, 4096) | ||
|
||
return torch.from_numpy(video_feature).type(torch.float32) | ||
|
||
|
||
def get_word_idx(self, text, src='qns'): | ||
""" | ||
convert relation to index sequence | ||
:param relation: | ||
:return: | ||
""" | ||
if src=='qns': vocab = self.vocab_qns | ||
elif src=='ans': vocab = self.vocab_ans | ||
tokens = nltk.tokenize.word_tokenize(str(text).lower()) | ||
text = [] | ||
text.append(vocab('<start>')) | ||
text.extend([vocab(token) for i,token in enumerate(tokens) if i < 23]) | ||
#text.append(vocab('<end>')) | ||
target = torch.Tensor(text) | ||
|
||
return target | ||
|
||
|
||
def __getitem__(self, idx): | ||
""" | ||
""" | ||
|
||
sample = self.sample_list.loc[idx] | ||
video_name, qns, ans = sample['video'], sample['question'], sample['answer'] | ||
qid, qtype = sample['qid'], sample['type'] | ||
video_name = str(video_name) | ||
qns, ans, qid, qtype = str(qns), str(ans), str(qid), str(qtype) | ||
|
||
|
||
#video_feature = torch.tensor([0]) | ||
video_feature = self.get_video_feature(video_name) | ||
|
||
qns2idx = self.get_word_idx(qns, 'qns') | ||
ans2idx = self.get_word_idx(ans, 'ans') | ||
|
||
return video_feature, qns2idx, ans2idx, video_name, qid, qtype | ||
|
||
|
||
class QALoader(): | ||
def __init__(self, batch_size, num_worker, video_feature_path, video_feature_cache, | ||
sample_list_path, vocab_qns, vocab_ans, train_shuffle=True, val_shuffle=False): | ||
self.batch_size = batch_size | ||
self.num_worker = num_worker | ||
self.video_feature_path = video_feature_path | ||
self.video_feature_cache = video_feature_cache | ||
self.sample_list_path = sample_list_path | ||
self.vocab_qns = vocab_qns | ||
self.vocab_ans = vocab_ans | ||
self.train_shuffle = train_shuffle | ||
self.val_shuffle = val_shuffle | ||
|
||
|
||
def run(self, mode=''): | ||
if mode != 'train': | ||
train_loader = '' | ||
val_loader = self.validate(mode) | ||
else: | ||
train_loader = self.train('train') | ||
val_loader = self.validate('val') | ||
return train_loader, val_loader | ||
|
||
|
||
def train(self, mode): | ||
|
||
training_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path, | ||
self.vocab_qns, self.vocab_ans, mode) | ||
|
||
print('Eligible QA pairs for training : {}'.format(len(training_set))) | ||
train_loader = DataLoader( | ||
dataset=training_set, | ||
batch_size=self.batch_size, | ||
shuffle=self.train_shuffle, | ||
num_workers=self.num_worker, | ||
collate_fn=collate_fn) | ||
|
||
return train_loader | ||
|
||
def validate(self, mode): | ||
|
||
validation_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path, | ||
self.vocab_qns, self.vocab_ans, mode) | ||
|
||
print('Eligible QA pairs for validation : {}'.format(len(validation_set))) | ||
val_loader = DataLoader( | ||
dataset=validation_set, | ||
batch_size=self.batch_size, | ||
shuffle=self.val_shuffle, | ||
num_workers=self.num_worker, | ||
collate_fn=collate_fn) | ||
|
||
return val_loader | ||
|
||
|
||
def collate_fn (data): | ||
""" | ||
""" | ||
data.sort(key=lambda x : len(x[1]), reverse=True) | ||
videos, qns2idx, ans2idx, video_names, qids, qtypes = zip(*data) | ||
|
||
#merge videos | ||
videos = torch.stack(videos, 0) | ||
|
||
#merge relations | ||
qns_lengths = [len(qns) for qns in qns2idx] | ||
targets_qns = torch.zeros(len(qns2idx), max(qns_lengths)).long() | ||
for i, qns in enumerate(qns2idx): | ||
end = qns_lengths[i] | ||
targets_qns[i, :end] = qns[:end] | ||
|
||
ans_lengths = [len(ans) for ans in ans2idx] | ||
targets_ans = torch.zeros(len(ans2idx), max(ans_lengths)).long() | ||
for i, ans in enumerate(ans2idx): | ||
end = ans_lengths[i] | ||
targets_ans[i, :end] = ans[:end] | ||
|
||
return videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes |
Oops, something went wrong.