From 1024c872645f9bd84134ebef42b91e0513d24b3a Mon Sep 17 00:00:00 2001
From: 金鹏 (Peng.J)
Date: Wed, 11 Oct 2023 11:55:48 +0800
Subject: [PATCH] banzhaf estimator
---
HBI/models/modeling.py | 4 +
HBI/models/modeling_estimator.py | 378 +++++++++++++++++++++++++++++++
README.md | 37 ++-
banzhaf_estimator.py | 360 +++++++++++++++++++++++++++++
4 files changed, 777 insertions(+), 2 deletions(-)
create mode 100644 HBI/models/modeling_estimator.py
create mode 100644 banzhaf_estimator.py
diff --git a/HBI/models/modeling.py b/HBI/models/modeling.py
index be848e0..a98e0d0 100644
--- a/HBI/models/modeling.py
+++ b/HBI/models/modeling.py
@@ -225,6 +225,10 @@ def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_
banzhaf = self.banzhafmodel(logits.unsqueeze(1)).squeeze(1)
with torch.no_grad():
teacher = self.banzhafteacher(logits.unsqueeze(1).clone().detach()).squeeze(1).detach()
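+ # zero out padded text tokens and padded frames in both the teacher and student
+ # Banzhaf maps so the KL distillation below only covers valid positions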
+ teacher = torch.einsum('btv,bt->btv', [teacher, text_mask])
+ teacher = torch.einsum('btv,bv->btv', [teacher, video_mask])
+ banzhaf = torch.einsum('btv,bt->btv', [banzhaf, text_mask])
+ banzhaf = torch.einsum('btv,bv->btv', [banzhaf, video_mask])
s_loss = self.kl(banzhaf, teacher) + self.kl(banzhaf.T, teacher.T)
loss += M_loss + self.config.skl * s_loss
diff --git a/HBI/models/modeling_estimator.py b/HBI/models/modeling_estimator.py
new file mode 100644
index 0000000..2722098
--- /dev/null
+++ b/HBI/models/modeling_estimator.py
@@ -0,0 +1,378 @@
+import os
+from collections import OrderedDict
+from types import SimpleNamespace
+import torch
+from torch import nn
+from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
+import torch.nn.functional as F
+from .module_clip import CLIP, convert_weights, _PT_NAME
+from .module_cross import CrossModel, Transformer as TransformerClip
+from .until_module import LayerNorm, AllGather, AllGather2, CrossEn, MSE, ArcCrossEn, KL
+import numpy as np
+from .banzhaf import BanzhafModule, BanzhafInteraction
+from .cluster import CTM, TCBlock
+
+allgather = AllGather.apply
+allgather2 = AllGather2.apply
+
+
+class ResidualLinear(nn.Module):
+ def __init__(self, d_int: int):
+ super(ResidualLinear, self).__init__()
+
+ self.fc_relu = nn.Sequential(nn.Linear(d_int, d_int),
+ nn.ReLU(inplace=True))
+
+ def forward(self, x):
+ x = x + self.fc_relu(x)
+ return x
+
+
+class HBI(nn.Module):
+ def __init__(self, config):
+ super(HBI, self).__init__()
+
+ self.config = config
+ self.interaction = config.interaction
+ self.agg_module = getattr(config, 'agg_module', 'meanP')
+ backbone = getattr(config, 'base_encoder', "ViT-B/32")
+
+ assert backbone in _PT_NAME
+ model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[backbone])
+ if not os.path.exists(model_path):
+ raise FileNotFoundError(model_path)
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location="cpu").eval()
+ state_dict = model.state_dict()
+ except RuntimeError:
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len(
+ [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ self.clip = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers)
+
+ if torch.cuda.is_available():
+ convert_weights(self.clip) # fp16
+
+ cross_config = SimpleNamespace(**{
+ "attention_probs_dropout_prob": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 512,
+ "initializer_range": 0.02,
+ "intermediate_size": 2048,
+ "max_position_embeddings": 128,
+ "num_attention_heads": 8,
+ "num_hidden_layers": 4,
+ "vocab_size": 512,
+ "soft_t": 0.07,
+ })
+ cross_config.max_position_embeddings = context_length
+ cross_config.hidden_size = transformer_width
+ self.cross_config = cross_config
+ if self.interaction == 'wti':
+ self.text_weight_fc = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+ self.video_weight_fc = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+
+ self.text_weight_fc0 = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+ self.video_weight_fc0 = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+
+ self.text_weight_fc1 = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+ self.video_weight_fc1 = nn.Sequential(
+ nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True),
+ nn.Linear(2 * transformer_width, 1))
+
+ if self.agg_module in ["seqLSTM", "seqTransf"]:
+ self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings,
+ cross_config.hidden_size)
+ if self.agg_module == "seqTransf":
+ self.transformerClip = TransformerClip(width=transformer_width,
+ layers=config.num_hidden_layers,
+ heads=transformer_heads)
+ if self.agg_module == "seqLSTM":
+ self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size,
+ batch_first=True, bidirectional=False, num_layers=1)
+
+ self.loss_fct = CrossEn(config)
+ self.loss_arcfct = ArcCrossEn(margin=10)
+ self.banzhafteacher = BanzhafModule(64)
+ self.banzhafinteraction = BanzhafInteraction(config.max_words, config.max_frames, 100)
+
+ self.apply(self.init_weights) # random init must happen before loading the pretrained weights
+ self.clip.load_state_dict(state_dict, strict=False)
+
+ self.mse = MSE()
+ self.kl = KL()
+
+ ## ===> Initialization trick [HARD CODE]
+ new_state_dict = OrderedDict()
+
+ if self.agg_module in ["seqLSTM", "seqTransf"]:
+ contain_frame_position = False
+ for key in state_dict.keys():
+ if key.find("frame_position_embeddings") > -1:
+ contain_frame_position = True
+ break
+ if contain_frame_position is False:
+ for key, val in state_dict.items():
+ if key == "positional_embedding":
+ new_state_dict["frame_position_embeddings.weight"] = val.clone()
+ continue
+ if self.agg_module in ["seqTransf"] and key.find("transformer.resblocks") == 0:
+ num_layer = int(key.split(".")[2])
+ # cut from beginning
+ if num_layer < config.num_hidden_layers:
+ new_state_dict[key.replace("transformer.", "transformerClip.")] = val.clone()
+ continue
+
+ self.load_state_dict(new_state_dict, strict=False) # only update new state (seqTransf/seqLSTM/tightTransf)
+ ## <=== End of initialization trick
+
+ for param in self.clip.parameters():
+ param.requires_grad = False # frozen; not updated by gradient
+ for param in self.transformerClip.parameters():
+ param.requires_grad = False # frozen; not updated by gradient
+ for param in self.frame_position_embeddings.parameters():
+ param.requires_grad = False # frozen; not updated by gradient
+
+ for param in self.text_weight_fc.parameters():
+ param.requires_grad = False # frozen; not updated by gradient
+ for param in self.video_weight_fc.parameters():
+ param.requires_grad = False # frozen; not updated by gradient
+
+
+ def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_step=0):
+ text_ids = text_ids.view(-1, text_ids.shape[-1])
+ text_mask = text_mask.view(-1, text_mask.shape[-1])
+ video_mask = video_mask.view(-1, video_mask.shape[-1])
+ # B x N_v x 3 x H x W - > (B x N_v) x 3 x H x W
+ video = torch.as_tensor(video).float()
+ if len(video.size()) == 5:
+ b, n_v, d, h, w = video.shape
+ video = video.view(b * n_v, d, h, w)
+ else:
+ b, pair, bs, ts, channel, h, w = video.shape
+ video = video.view(b * pair * bs * ts, channel, h, w)
+
+ text_feat, video_feat, cls = self.get_text_video_feat(text_ids, text_mask, video, video_mask, shaped=True)
+
+ if self.training:
+ if torch.cuda.is_available(): # batch merge here
+ idx = allgather(idx, self.config)
+ text_feat = allgather(text_feat, self.config)
+ video_feat = allgather(video_feat, self.config)
+ text_mask = allgather(text_mask, self.config)
+ video_mask = allgather(video_mask, self.config)
+ cls = allgather(cls, self.config)
+ torch.distributed.barrier() # force sync
+
+ idx = idx.view(-1, 1)
+ idx_all = idx.t()
+ pos_idx = torch.eq(idx, idx_all).float()
+ sim_targets = pos_idx / pos_idx.sum(1, keepdim=True)
+ logit_scale = self.clip.logit_scale.exp()
+ loss = 0.
+
+ # entity level
+ logits, text_weight, video_weight = self.entity_level(text_feat, cls, video_feat,
+ text_mask, video_mask)
+ logits = torch.diagonal(logits, dim1=0, dim2=1).permute(2, 0, 1)
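+ # y: Banzhaf Interaction labels computed on the fly by BanzhafInteraction (detached, no gradient)
+ # p: the estimator's prediction, regressed onto y with an MSE loss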
+ y = self.banzhafinteraction(logits.clone().detach(), text_mask, video_mask, text_weight, video_weight).detach()
+ p = self.banzhafteacher(logits.unsqueeze(1)).squeeze(1)
+
+ p = torch.einsum('btv,bt->btv', [p, text_mask])
+ p = torch.einsum('btv,bv->btv', [p, video_mask])
+
+ loss += self.mse(y, p)
+
+ return loss
+ else:
+ return None
+
+ def entity_level(self, text_feat, cls, video_feat, text_mask, video_mask):
+ if self.config.interaction == 'wti':
+ text_weight = self.text_weight_fc(text_feat).squeeze(2) # B x N_t x D -> B x N_t
+ text_weight.masked_fill_((1 - text_mask).to(torch.bool), float(-9e15))
+ text_weight = torch.softmax(text_weight, dim=-1) # B x N_t
+
+ video_weight = self.video_weight_fc(video_feat).squeeze(2) # B x N_v x D -> B x N_v
+ video_weight.masked_fill_((1 - video_mask).to(torch.bool), float(-9e15))
+ video_weight = torch.softmax(video_weight, dim=-1) # B x N_v
+
+ text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True)
+ video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True)
+
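+ # token-wise cosine similarities between every text token and every video frame;
+ # the two mask einsums below zero out padded positions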
+ retrieve_logits = torch.einsum('atd,bvd->abtv', [text_feat, video_feat])
+ retrieve_logits = torch.einsum('abtv,at->abtv', [retrieve_logits, text_mask])
+ retrieve_logits = torch.einsum('abtv,bv->abtv', [retrieve_logits, video_mask])
+
+ text_sum = text_mask.sum(-1)
+ video_sum = video_mask.sum(-1)
+
+ if self.config.interaction == 'wti': # weighted token-wise interaction
+ t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt
+ t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight])
+
+ v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv
+ v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight])
+
+ _retrieve_logits = (t2v_logits + v2t_logits) / 2.0
+ else:
+ # max for video token
+ t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt
+ v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv
+ t2v_logits = torch.sum(t2v_logits, dim=2) / (text_sum.unsqueeze(1))
+ v2t_logits = torch.sum(v2t_logits, dim=2) / (video_sum.unsqueeze(0))
+ _retrieve_logits = (t2v_logits + v2t_logits) / 2.0
+
+ return retrieve_logits, text_weight, video_weight
+
+ def get_text_feat(self, text_ids, text_mask, shaped=False):
+ if shaped is False:
+ text_ids = text_ids.view(-1, text_ids.shape[-1])
+ text_mask = text_mask.view(-1, text_mask.shape[-1])
+
+ bs_pair = text_ids.size(0)
+ cls, text_feat = self.clip.encode_text(text_ids, return_hidden=True, mask=text_mask)
+ cls, text_feat = cls.float(), text_feat.float()
+ text_feat = text_feat.view(bs_pair, -1, text_feat.size(-1))
+ cls = cls.view(bs_pair, -1, cls.size(-1)).squeeze(1)
+ return text_feat, cls
+
+ def get_video_feat(self, video, video_mask, shaped=False):
+ if shaped is False:
+ video_mask = video_mask.view(-1, video_mask.shape[-1])
+ video = torch.as_tensor(video).float()
+ if len(video.size()) == 5:
+ b, n_v, d, h, w = video.shape
+ video = video.view(b * n_v, d, h, w)
+ else:
+ b, pair, bs, ts, channel, h, w = video.shape
+ video = video.view(b * pair * bs * ts, channel, h, w)
+
+ bs_pair, n_v = video_mask.size()
+ video_feat = self.clip.encode_image(video, return_hidden=True)[0].float()
+ video_feat = video_feat.float().view(bs_pair, -1, video_feat.size(-1))
+ video_feat = self.agg_video_feat(video_feat, video_mask, self.agg_module)
+ return video_feat
+
+ def get_text_video_feat(self, text_ids, text_mask, video, video_mask, shaped=False):
+ if shaped is False:
+ text_ids = text_ids.view(-1, text_ids.shape[-1])
+ text_mask = text_mask.view(-1, text_mask.shape[-1])
+ video_mask = video_mask.view(-1, video_mask.shape[-1])
+ video = torch.as_tensor(video).float()
+ if len(video.shape) == 5:
+ b, n_v, d, h, w = video.shape
+ video = video.view(b * n_v, d, h, w)
+ else:
+ b, pair, bs, ts, channel, h, w = video.shape
+ video = video.view(b * pair * bs * ts, channel, h, w)
+
+ text_feat, cls = self.get_text_feat(text_ids, text_mask, shaped=True)
+ video_feat = self.get_video_feat(video, video_mask, shaped=True)
+
+ return text_feat, video_feat, cls
+
+ def get_video_avg_feat(self, video_feat, video_mask):
+ video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1)
+ video_feat = video_feat * video_mask_un
+ video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float)
+ video_mask_un_sum[video_mask_un_sum == 0.] = 1.
+ video_feat = torch.sum(video_feat, dim=1) / video_mask_un_sum
+ return video_feat
+
+ def get_text_sep_feat(self, text_feat, text_mask):
+ text_feat = text_feat.contiguous()
+ text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :]
+ text_feat = text_feat.unsqueeze(1).contiguous()
+ return text_feat
+
+ def agg_video_feat(self, video_feat, video_mask, agg_module):
+ video_feat = video_feat.contiguous()
+ if agg_module == "None":
+ pass
+ elif agg_module == "seqLSTM":
+ # Sequential type: LSTM
+ video_feat_original = video_feat
+ video_feat = pack_padded_sequence(video_feat, torch.sum(video_mask, dim=-1).cpu(),
+ batch_first=True, enforce_sorted=False)
+ video_feat, _ = self.lstm_visual(video_feat)
+ if self.training: self.lstm_visual.flatten_parameters()
+ video_feat, _ = pad_packed_sequence(video_feat, batch_first=True)
+ video_feat = torch.cat(
+ (video_feat, video_feat_original[:, video_feat.size(1):, ...].contiguous()), dim=1)
+ video_feat = video_feat + video_feat_original
+ elif agg_module == "seqTransf":
+ # Sequential type: Transformer Encoder
+ video_feat_original = video_feat
+ seq_length = video_feat.size(1)
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=video_feat.device)
+ position_ids = position_ids.unsqueeze(0).expand(video_feat.size(0), -1)
+ frame_position_embeddings = self.frame_position_embeddings(position_ids)
+ video_feat = video_feat + frame_position_embeddings
+ extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0
+ extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1)
+ video_feat = video_feat.permute(1, 0, 2) # NLD -> LND
+ video_feat = self.transformerClip(video_feat, extended_video_mask)
+ video_feat = video_feat.permute(1, 0, 2) # LND -> NLD
+ video_feat = video_feat + video_feat_original
+ return video_feat
+
+ @property
+ def dtype(self):
+ """
+ :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ try:
+ return next(self.parameters()).dtype
+ except StopIteration:
+ # For nn.DataParallel compatibility in PyTorch 1.5
+ def find_tensor_attributes(module: nn.Module):
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+ return tuples
+
+ gen = self._named_members(get_members_fn=find_tensor_attributes)
+ first_tuple = next(gen)
+ return first_tuple[1].dtype
+
+ def init_weights(self, module):
+ """ Initialize the weights.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=0.02)
+ elif isinstance(module, LayerNorm):
+ if 'beta' in dir(module) and 'gamma' in dir(module):
+ module.beta.data.zero_()
+ module.gamma.data.fill_(1.0)
+ else:
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
\ No newline at end of file
diff --git a/README.md b/README.md
index c36300f..f2c92a2 100644
--- a/README.md
+++ b/README.md
@@ -39,6 +39,7 @@ If you find this paper useful, please consider staring 🌟 this repo and citing
## 📣 Updates
+* Oct 11 2023: Release code for the Banzhaf Interaction estimator.
* Oct 08 2023: I am working on the code for Banzhaf Interaction estimator, which is expected to be released soon.
* Jun 28 2023: Release code for reimplementing the experiments in the paper.
* Mar 28 2023: Our **HBI** has been selected as a Highlight paper at CVPR 2023! (Top 2.5% of 9155 submissions).
@@ -127,8 +128,40 @@ wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702
#### Train the Banzhaf Interaction Estimator
-Train the estimator according to the label generated by the BanzhafInteraction in HBI/models/banzhaf.py.
-Training code is under preparation...
+Train the estimator on the labels generated by the `BanzhafInteraction` module in `HBI/models/banzhaf.py`. The training code is provided in `banzhaf_estimator.py`.
+
+Recommended training hyperparameters will be provided shortly, and we will also release our pre-trained estimator weights.
+
+
+
+| Models | Google Cloud | Baidu Yun | Peking University Yun |
+|:-----------:|:------------:|:---------:|:-----------:|
+| Estimator | TODO | TODO | TODO |
+
+
+
+```shell
+CUDA_VISIBLE_DEVICES=0,1,2,3 \
+python -m torch.distributed.launch \
+--master_port 2502 \
+--nproc_per_node=4 \
+banzhaf_estimator.py \
+--do_train 1 \
+--workers 8 \
+--n_display 1 \
+--epochs ${EPOCHS} \
+--lr ${LEARNING_RATE} \
+--coef_lr 1e-3 \
+--batch_size 128 \
+--batch_size_val 128 \
+--anno_path data/MSR-VTT/anns \
+--video_path ${DATA_PATH}/MSRVTT_Videos \
+--datatype msrvtt \
+--max_words 24 \
+--max_frames 12 \
+--video_framerate 1 \
+--output_dir ${OUTPUT_PATH}
+```
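+
+After training, `save_model` saves only the estimator weights (the `banzhafteacher`, a `BanzhafModule`) to `${OUTPUT_PATH}/pytorch_model.bin.{epoch}`. The snippet below is a minimal loading sketch, not part of the released code: it assumes the `BanzhafModule(64)` construction used in `HBI/models/modeling_estimator.py`, and the checkpoint path is only a placeholder.
+
+```python
+import torch
+
+from HBI.models.banzhaf import BanzhafModule
+
+# hidden size 64 mirrors modeling_estimator.py; the path below is a placeholder
+estimator = BanzhafModule(64)
+state_dict = torch.load("outputs/estimator/pytorch_model.bin.4", map_location="cpu")
+estimator.load_state_dict(state_dict)
+estimator.eval()
+```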
### Text-video Retrieval
diff --git a/banzhaf_estimator.py b/banzhaf_estimator.py
new file mode 100644
index 0000000..42292d5
--- /dev/null
+++ b/banzhaf_estimator.py
@@ -0,0 +1,360 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import unicode_literals
+from __future__ import print_function
+
+import os
+import time
+import random
+import argparse
+import numpy as np
+from tqdm import tqdm
+import datetime
+from os.path import join, exists
+
+import torch
+
+from HBI.models.tokenization_clip import SimpleTokenizer as ClipTokenizer
+from HBI.dataloaders.data_dataloaders import DATALOADER_DICT
+from HBI.dataloaders.dataloader_msrvtt_retrieval import MSRVTTDataset
+from HBI.models.modeling_estimator import HBI, AllGather
+from HBI.models.optimization import BertAdam
+from HBI.utils.metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim
+
+from HBI.utils.comm import is_main_process, synchronize
+from HBI.utils.logger import setup_logger
+from HBI.utils.metric_logger import MetricLogger
+
+allgather = AllGather.apply
+
+global logger
+
+
+def get_args(
+ description='Video-Text as Game Players: Hierarchical Banzhaf Interaction for Cross-Modal Representation Learning'):
+ parser = argparse.ArgumentParser(description=description)
+ parser.add_argument("--do_train", type=int, default=0, help="Whether to run training.")
+ parser.add_argument("--do_eval", type=int, default=0, help="Whether to run evaluation.")
+
+ parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.")
+ parser.add_argument('--anno_path', type=str, default='data/MSR-VTT/anns', help='annotation path')
+ parser.add_argument('--video_path', type=str, default='data/MSR-VTT/videos', help='video path')
+
+ parser.add_argument('--seed', type=int, default=42, help='random seed')
+ parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
+ parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate')
+ parser.add_argument('--coef_lr', type=float, default=1e-3, help='coefficient for bert branch.')
+ parser.add_argument("--warmup_proportion", default=0.1, type=float,
+ help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% of training.")
+ parser.add_argument('--weight_decay', type=float, default=0.2, help='weight decay')
+ parser.add_argument('--epochs', type=int, default=5, help='upper epoch limit')
+ parser.add_argument('--batch_size', type=int, default=128, help='batch size')
+ parser.add_argument('--batch_size_val', type=int, default=128, help='batch size eval')
+
+ parser.add_argument('--max_words', type=int, default=32, help='max text token number')
+ parser.add_argument('--max_frames', type=int, default=12, help='max key frames')
+ parser.add_argument('--video_framerate', type=int, default=1, help='framerate to sample video frame')
+
+ parser.add_argument("--device", default='cpu', type=str, help="cpu/cuda")
+ parser.add_argument("--world_size", default=1, type=int, help="distribted training")
+ parser.add_argument("--local_rank", default=0, type=int, help="distribted training")
+ parser.add_argument("--distributed", default=0, type=int, help="multi machine DDP")
+
+ parser.add_argument('--n_display', type=int, default=50, help='Information display frequency')
+ parser.add_argument("--output_dir", default=None, type=str, required=True,
+ help="The output directory where the model predictions and checkpoints will be written.")
+
+ parser.add_argument("--base_encoder", default="ViT-B/32", type=str, help="Choose a CLIP version")
+ parser.add_argument('--agg_module', type=str, default="seqTransf", choices=["None", "seqLSTM", "seqTransf"],
+ help="choose a feature aggregation module for video.")
+ parser.add_argument('--interaction', type=str, default='wti', help="interaction type for retrieval.")
+ parser.add_argument('--num_hidden_layers', type=int, default=4)
+
+ parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.")
+ args = parser.parse_args()
+
+ return args
+
+
+def set_seed_logger(args):
+ global logger
+ # predefining random initial seeds
+ random.seed(args.seed)
+ os.environ['PYTHONHASHSEED'] = str(args.seed)
+ np.random.seed(args.seed)
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU.
+ torch.backends.cudnn.benchmark = False
+ torch.backends.cudnn.deterministic = True
+
+ if torch.cuda.is_available():
+ torch.distributed.init_process_group(backend="nccl")
+ torch.cuda.set_device(args.local_rank)
+ args.device = torch.device("cuda", args.local_rank)
+ args.world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1
+ if torch.cuda.is_available():
+ torch.distributed.barrier()
+ logger.info("local_rank: {} world_size: {}".format(args.local_rank, args.world_size))
+
+ if args.batch_size % args.world_size != 0 or args.batch_size_val % args.world_size != 0:
+ raise ValueError(
+ "Invalid batch_size/batch_size_val and world_size parameter: {}%{} and {}%{}, should be == 0".format(
+ args.batch_size, args.world_size, args.batch_size_val, args.world_size))
+
+ logger.info("Effective parameters:")
+ for key in sorted(args.__dict__):
+ logger.info(" <<< {}: {}".format(key, args.__dict__[key]))
+
+ return args
+
+
+def build_model(args):
+ model = HBI(args)
+ if args.init_model:
+ if not exists(args.init_model):
+ raise FileNotFoundError
+ model_state_dict = torch.load(args.init_model, map_location='cpu')
+ model.load_state_dict(model_state_dict, strict=False)
+
+ model.to(args.device)
+ return model
+
+
+def build_dataloader(args):
+ ## ####################################
+ # dataloader loading
+ ## ####################################
+ tokenizer = ClipTokenizer()
+ assert args.datatype in DATALOADER_DICT
+
+ assert DATALOADER_DICT[args.datatype]["test"] is not None or DATALOADER_DICT[args.datatype]["val"] is not None
+
+ test_dataloader, test_length = None, 0
+ if DATALOADER_DICT[args.datatype]["test"] is not None:
+ test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer)
+
+ if DATALOADER_DICT[args.datatype]["val"] is not None:
+ val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val")
+ else:
+ val_dataloader, val_length = test_dataloader, test_length
+
+ ## report validation results when ["test"] is None
+ if test_dataloader is None:
+ test_dataloader, test_length = val_dataloader, val_length
+
+ if isinstance(test_length, int):
+ logger.info("***** Running test *****")
+ logger.info(" Num examples = %d", test_length)
+ logger.info(" Batch size = %d", args.batch_size_val)
+ logger.info(" Num steps = %d", len(test_dataloader))
+ logger.info("***** Running val *****")
+ logger.info(" Num examples = %d", val_length)
+ elif len(test_length) == 2:
+ logger.info("***** Running test *****")
+ logger.info(" Num examples = %dt %dv", test_length[0], test_length[1])
+ logger.info(" Batch size = %d", args.batch_size_val)
+ logger.info(" Num steps = %d %d", len(test_dataloader[0]), len(test_dataloader[1]))
+ logger.info("***** Running val *****")
+ logger.info(" Num examples = %dt %dv", val_length[0], val_length[1])
+
+ if args.do_train:
+ train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer)
+ logger.info("***** Running training *****")
+ logger.info(" Num examples = %d", train_length)
+ logger.info(" Batch size = %d", args.batch_size)
+ logger.info(" Num steps = %d", len(train_dataloader) * args.epochs)
+ else:
+ train_dataloader, train_sampler = None, None
+
+ return test_dataloader, val_dataloader, train_dataloader, train_sampler
+
+
+def prep_optimizer(args, model, num_train_optimization_steps, local_rank):
+ if hasattr(model, 'module'):
+ model = model.module
+ lr = args.lr # 0.0001
+ coef_lr = args.coef_lr # 0.001
+ weight_decay = args.weight_decay # 0.2
+ warmup_proportion = args.warmup_proportion
+ param_optimizer = list(model.named_parameters())
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
+
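+ # parameters are split four ways: CLIP vs. non-CLIP (CLIP gets lr * coef_lr)
+ # and decay vs. no-decay (bias / LayerNorm parameters get weight_decay = 0)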
+ decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)]
+ no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)]
+
+ decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n]
+ decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n]
+
+ no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n]
+ no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n]
+
+ optimizer_grouped_parameters = [
+ {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': lr * coef_lr},
+ {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay},
+ {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': lr * coef_lr},
+ {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0}
+ ]
+
+ scheduler = None
+
+ optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=warmup_proportion,
+ schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6,
+ t_total=num_train_optimization_steps, weight_decay=weight_decay,
+ max_grad_norm=1.0)
+
+ if torch.cuda.is_available():
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank,
+ find_unused_parameters=True)
+ return optimizer, scheduler, model
+
+
+def save_model(epoch, args, model, type_name=""):
+ # Only save the Banzhaf estimator (banzhafteacher) itself
+ model_to_save = model.module.banzhafteacher if hasattr(model, 'module') else model.banzhafteacher
+ output_model_file = join(
+ args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch))
+ torch.save(model_to_save.state_dict(), output_model_file)
+ logger.info("Model saved to %s", output_model_file)
+ return output_model_file
+
+
+def reduce_loss(loss, args):
+ world_size = args.world_size
+ if world_size < 2:
+ return loss
+ with torch.no_grad():
+ torch.distributed.reduce(loss, dst=0)
+ if torch.distributed.get_rank() == 0:
+ # only main process gets accumulated, so only divide by
+ # world_size in this case
+ loss /= world_size
+ return loss
+
+
+def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer,
+ scheduler, global_step, max_steps, val_dataloader):
+ global logger
+ global best_score
+ global meters
+
+ torch.cuda.empty_cache()
+ model.train()
+ log_step = args.n_display
+ total_loss = 0
+
+ end = time.time()
+ logit_scale = 0
+ for step, batch in enumerate(train_dataloader, start=1):
+ global_step += 1
+ data_time = time.time() - end
+
+ if n_gpu == 1:
+ # multi-gpu does the scattering itself
+ batch = tuple(t.to(device=device, non_blocking=True) for t in batch)
+
+ text_ids, text_mask, video, video_mask, inds, idx = batch
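+ # the forward pass returns the MSE between the estimator's prediction and the
+ # sampled Banzhaf Interaction labels (see HBI/models/modeling_estimator.py)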
+ loss = model(text_ids, text_mask, video, video_mask, idx, global_step)
+
+ if n_gpu > 1:
+ # print(loss.shape)
+ loss = loss.mean() # mean() to average on multi-gpu.
+
+ with torch.autograd.detect_anomaly():
+ loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+ optimizer.step()
+
+ if scheduler is not None:
+ scheduler.step() # Update learning rate schedule
+
+ optimizer.zero_grad()
+
+ batch_time = time.time() - end
+ end = time.time()
+
+ total_loss += float(loss)
+ reduced_l = reduce_loss(loss, args)
+ meters.update(time=batch_time, data=data_time, loss=float(reduced_l))
+
+ eta_seconds = meters.time.global_avg * (max_steps - global_step)
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+
+ if (global_step % log_step == 0 or global_step == 1) and is_main_process():
+ logger.info(
+ meters.delimiter.join(
+ [
+ "eta: {eta}",
+ "epoch: {epoch}/{max_epoch}",
+ "iteration: {iteration}/{max_iteration}",
+ "{meters}",
+ "lr: {lr}",
+ "logit_scale: {logit_scale:.2f}"
+ "max mem: {memory:.0f}",
+ ]
+ ).format(
+ eta=eta_string,
+ epoch=epoch,
+ max_epoch=args.epochs,
+ iteration=global_step,
+ max_iteration=max_steps,
+ meters=str(meters),
+ lr="/".join([str('%.9f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]),
+ logit_scale=logit_scale,
+ memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0,
+ )
+ )
+
+ total_loss = total_loss / len(train_dataloader)
+ return total_loss, global_step
+
+
+def main():
+ global logger
+ global best_score
+ global meters
+
+ meters = MetricLogger(delimiter=" ")
+ args = get_args()
+ if not exists(args.output_dir):
+ os.makedirs(args.output_dir, exist_ok=True)
+ logger = setup_logger('tvr', args.output_dir, args.local_rank)
+
+ args = set_seed_logger(args)
+
+ model = build_model(args)
+
+ test_dataloader, val_dataloader, train_dataloader, train_sampler = build_dataloader(args)
+
+ ## ####################################
+ # train and eval
+ ## ####################################
+ if args.do_train:
+ tic = time.time()
+ max_steps = len(train_dataloader) * args.epochs
+ _max_steps = len(train_dataloader) * 5
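+ # note: the warmup/cosine schedule horizon passed to the optimizer is fixed at
+ # 5 epochs' worth of steps, independent of args.epochs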
+ optimizer, scheduler, model = prep_optimizer(args, model, _max_steps, args.local_rank)
+
+ best_score = 0.00001
+ best_output_model_file = "None"
+ global_step = 0
+ for epoch in range(args.epochs):
+ if train_sampler is not None: train_sampler.set_epoch(epoch)
+ synchronize()
+ torch.cuda.empty_cache()
+ tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader,
+ args.device, args.world_size, optimizer,
+ scheduler, global_step, max_steps, val_dataloader)
+
+ if args.local_rank == 0:
+ output_model_file = save_model(epoch, args, model, type_name="")
+ synchronize()
+
+ toc = time.time() - tic
+ training_time = time.strftime("%Hh %Mmin %Ss", time.gmtime(toc))
+ logger.info("*" * 20 + '\n' + f'training finished with {training_time}' + "*" * 20 + '\n')
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file