From 1024c872645f9bd84134ebef42b91e0513d24b3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=87=91=E9=B9=8F=28Peng=2EJ=29?= Date: Wed, 11 Oct 2023 11:55:48 +0800 Subject: [PATCH] banzhaf estimator --- HBI/models/modeling.py | 4 + HBI/models/modeling_estimator.py | 378 +++++++++++++++++++++++++++++++ README.md | 37 ++- banzhaf_estimator.py | 360 +++++++++++++++++++++++++++++ 4 files changed, 777 insertions(+), 2 deletions(-) create mode 100644 HBI/models/modeling_estimator.py create mode 100644 banzhaf_estimator.py diff --git a/HBI/models/modeling.py b/HBI/models/modeling.py index be848e0..a98e0d0 100644 --- a/HBI/models/modeling.py +++ b/HBI/models/modeling.py @@ -225,6 +225,10 @@ def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_ banzhaf = self.banzhafmodel(logits.unsqueeze(1)).squeeze(1) with torch.no_grad(): teacher = self.banzhafteacher(logits.unsqueeze(1).clone().detach()).squeeze(1).detach() + teacher = torch.einsum('btv,bt->btv', [teacher, text_mask]) + teacher = torch.einsum('btv,bv->btv', [teacher, video_mask]) + banzhaf = torch.einsum('btv,bt->btv', [banzhaf, text_mask]) + banzhaf = torch.einsum('btv,bv->btv', [banzhaf, video_mask]) s_loss = self.kl(banzhaf, teacher) + self.kl(banzhaf.T, teacher.T) loss += M_loss + self.config.skl * s_loss diff --git a/HBI/models/modeling_estimator.py b/HBI/models/modeling_estimator.py new file mode 100644 index 0000000..2722098 --- /dev/null +++ b/HBI/models/modeling_estimator.py @@ -0,0 +1,378 @@ +import os +from collections import OrderedDict +from types import SimpleNamespace +import torch +from torch import nn +from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence +import torch.nn.functional as F +from .module_clip import CLIP, convert_weights, _PT_NAME +from .module_cross import CrossModel, Transformer as TransformerClip +from .until_module import LayerNorm, AllGather, AllGather2, CrossEn, MSE, ArcCrossEn, KL +import numpy as np +from .banzhaf import BanzhafModule, BanzhafInteraction +from .cluster import CTM, TCBlock + +allgather = AllGather.apply +allgather2 = AllGather2.apply + + +class ResidualLinear(nn.Module): + def __init__(self, d_int: int): + super(ResidualLinear, self).__init__() + + self.fc_relu = nn.Sequential(nn.Linear(d_int, d_int), + nn.ReLU(inplace=True)) + + def forward(self, x): + x = x + self.fc_relu(x) + return x + + +class HBI(nn.Module): + def __init__(self, config): + super(HBI, self).__init__() + + self.config = config + self.interaction = config.interaction + self.agg_module = getattr(config, 'agg_module', 'meanP') + backbone = getattr(config, 'base_encoder', "ViT-B/32") + + assert backbone in _PT_NAME + model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), _PT_NAME[backbone]) + if os.path.exists(model_path): + FileNotFoundError + try: + # loading JIT archive + model = torch.jit.load(model_path, map_location="cpu").eval() + state_dict = model.state_dict() + except RuntimeError: + state_dict = torch.load(model_path, map_location="cpu") + + vision_width = state_dict["visual.conv1.weight"].shape[0] + vision_layers = len( + [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) + vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] + grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) + image_resolution = vision_patch_size * grid_size + + embed_dim = state_dict["text_projection"].shape[1] + context_length = state_dict["positional_embedding"].shape[0] + vocab_size = state_dict["token_embedding.weight"].shape[0] + transformer_width = state_dict["ln_final.weight"].shape[0] + transformer_heads = transformer_width // 64 + transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) + + self.clip = CLIP(embed_dim, image_resolution, vision_layers, vision_width, vision_patch_size, + context_length, vocab_size, transformer_width, transformer_heads, transformer_layers) + + if torch.cuda.is_available(): + convert_weights(self.clip) # fp16 + + cross_config = SimpleNamespace(**{ + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 512, + "initializer_range": 0.02, + "intermediate_size": 2048, + "max_position_embeddings": 128, + "num_attention_heads": 8, + "num_hidden_layers": 4, + "vocab_size": 512, + "soft_t": 0.07, + }) + cross_config.max_position_embeddings = context_length + cross_config.hidden_size = transformer_width + self.cross_config = cross_config + if self.interaction == 'wti': + self.text_weight_fc = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + self.video_weight_fc = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + + self.text_weight_fc0 = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + self.video_weight_fc0 = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + + self.text_weight_fc1 = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + self.video_weight_fc1 = nn.Sequential( + nn.Linear(transformer_width, 2 * transformer_width), nn.ReLU(inplace=True), + nn.Linear(2 * transformer_width, 1)) + + if self.agg_module in ["seqLSTM", "seqTransf"]: + self.frame_position_embeddings = nn.Embedding(cross_config.max_position_embeddings, + cross_config.hidden_size) + if self.agg_module == "seqTransf": + self.transformerClip = TransformerClip(width=transformer_width, + layers=config.num_hidden_layers, + heads=transformer_heads) + if self.agg_module == "seqLSTM": + self.lstm_visual = nn.LSTM(input_size=cross_config.hidden_size, hidden_size=cross_config.hidden_size, + batch_first=True, bidirectional=False, num_layers=1) + + self.loss_fct = CrossEn(config) + self.loss_arcfct = ArcCrossEn(margin=10) + self.banzhafteacher = BanzhafModule(64) + self.banzhafinteraction = BanzhafInteraction(config.max_words, config.max_frames, 100) + + self.apply(self.init_weights) # random init must before loading pretrain + self.clip.load_state_dict(state_dict, strict=False) + + self.mse = MSE() + self.kl = KL() + + ## ===> Initialization trick [HARD CODE] + new_state_dict = OrderedDict() + + if self.agg_module in ["seqLSTM", "seqTransf"]: + contain_frame_position = False + for key in state_dict.keys(): + if key.find("frame_position_embeddings") > -1: + contain_frame_position = True + break + if contain_frame_position is False: + for key, val in state_dict.items(): + if key == "positional_embedding": + new_state_dict["frame_position_embeddings.weight"] = val.clone() + continue + if self.agg_module in ["seqTransf"] and key.find("transformer.resblocks") == 0: + num_layer = int(key.split(".")[2]) + # cut from beginning + if num_layer < config.num_hidden_layers: + new_state_dict[key.replace("transformer.", "transformerClip.")] = val.clone() + continue + + self.load_state_dict(new_state_dict, strict=False) # only update new state (seqTransf/seqLSTM/tightTransf) + ## <=== End of initialization trick + + for param in self.clip.parameters(): + param.requires_grad = False # not update by gradient + for param in self.transformerClip.parameters(): + param.requires_grad = False # not update by gradient + for param in self.frame_position_embeddings.parameters(): + param.requires_grad = False # not update by gradient + + for param in self.text_weight_fc.parameters(): + param.requires_grad = False # not update by gradient + for param in self.video_weight_fc.parameters(): + param.requires_grad = False # not update by gradient + + + def forward(self, text_ids, text_mask, video, video_mask=None, idx=None, global_step=0): + text_ids = text_ids.view(-1, text_ids.shape[-1]) + text_mask = text_mask.view(-1, text_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + # B x N_v x 3 x H x W - > (B x N_v) x 3 x H x W + video = torch.as_tensor(video).float() + if len(video.size()) == 5: + b, n_v, d, h, w = video.shape + video = video.view(b * n_v, d, h, w) + else: + b, pair, bs, ts, channel, h, w = video.shape + video = video.view(b * pair * bs * ts, channel, h, w) + + text_feat, video_feat, cls = self.get_text_video_feat(text_ids, text_mask, video, video_mask, shaped=True) + + if self.training: + if torch.cuda.is_available(): # batch merge here + idx = allgather(idx, self.config) + text_feat = allgather(text_feat, self.config) + video_feat = allgather(video_feat, self.config) + text_mask = allgather(text_mask, self.config) + video_mask = allgather(video_mask, self.config) + cls = allgather(cls, self.config) + torch.distributed.barrier() # force sync + + idx = idx.view(-1, 1) + idx_all = idx.t() + pos_idx = torch.eq(idx, idx_all).float() + sim_targets = pos_idx / pos_idx.sum(1, keepdim=True) + logit_scale = self.clip.logit_scale.exp() + loss = 0. + + # entity level + logits, text_weight, video_weight = self.entity_level(text_feat, cls, video_feat, + text_mask, video_mask) + logits = torch.diagonal(logits, dim1=0, dim2=1).permute(2, 0, 1) + y = self.banzhafinteraction(logits.clone().detach(), text_mask, video_mask, text_weight, video_weight).detach() + p = self.banzhafteacher(logits.unsqueeze(1)).squeeze(1) + + p = torch.einsum('btv,bt->btv', [p, text_mask]) + p = torch.einsum('btv,bv->btv', [p, video_mask]) + + loss += self.mse(y, p) + + return loss + else: + return None + + def entity_level(self, text_feat, cls, video_feat, text_mask, video_mask): + if self.config.interaction == 'wti': + text_weight = self.text_weight_fc(text_feat).squeeze(2) # B x N_t x D -> B x N_t + text_weight.masked_fill_((1 - text_mask).to(torch.bool), float(-9e15)) + text_weight = torch.softmax(text_weight, dim=-1) # B x N_t + + video_weight = self.video_weight_fc(video_feat).squeeze(2) # B x N_v x D -> B x N_v + video_weight.masked_fill_((1 - video_mask).to(torch.bool), float(-9e15)) + video_weight = torch.softmax(video_weight, dim=-1) # B x N_v + + text_feat = text_feat / text_feat.norm(dim=-1, keepdim=True) + video_feat = video_feat / video_feat.norm(dim=-1, keepdim=True) + + retrieve_logits = torch.einsum('atd,bvd->abtv', [text_feat, video_feat]) + retrieve_logits = torch.einsum('abtv,at->abtv', [retrieve_logits, text_mask]) + retrieve_logits = torch.einsum('abtv,bv->abtv', [retrieve_logits, video_mask]) + + text_sum = text_mask.sum(-1) + video_sum = video_mask.sum(-1) + + if self.config.interaction == 'wti': # weighted token-wise interaction + t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt + t2v_logits = torch.einsum('abt,at->ab', [t2v_logits, text_weight]) + + v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv + v2t_logits = torch.einsum('abv,bv->ab', [v2t_logits, video_weight]) + + _retrieve_logits = (t2v_logits + v2t_logits) / 2.0 + else: + # max for video token + t2v_logits, max_idx1 = retrieve_logits.max(dim=-1) # abtv -> abt + v2t_logits, max_idx2 = retrieve_logits.max(dim=-2) # abtv -> abv + t2v_logits = torch.sum(t2v_logits, dim=2) / (text_sum.unsqueeze(1)) + v2t_logits = torch.sum(v2t_logits, dim=2) / (video_sum.unsqueeze(0)) + _retrieve_logits = (t2v_logits + v2t_logits) / 2.0 + + return retrieve_logits, text_weight, video_weight + + def get_text_feat(self, text_ids, text_mask, shaped=False): + if shaped is False: + text_ids = text_ids.view(-1, text_ids.shape[-1]) + text_mask = text_mask.view(-1, text_mask.shape[-1]) + + bs_pair = text_ids.size(0) + cls, text_feat = self.clip.encode_text(text_ids, return_hidden=True, mask=text_mask) + cls, text_feat = cls.float(), text_feat.float() + text_feat = text_feat.view(bs_pair, -1, text_feat.size(-1)) + cls = cls.view(bs_pair, -1, cls.size(-1)).squeeze(1) + return text_feat, cls + + def get_video_feat(self, video, video_mask, shaped=False): + if shaped is False: + video_mask = video_mask.view(-1, video_mask.shape[-1]) + video = torch.as_tensor(video).float() + if len(video.size()) == 5: + b, n_v, d, h, w = video.shape + video = video.view(b * n_v, d, h, w) + else: + b, pair, bs, ts, channel, h, w = video.shape + video = video.view(b * pair * bs * ts, channel, h, w) + + bs_pair, n_v = video_mask.size() + video_feat = self.clip.encode_image(video, return_hidden=True)[0].float() + video_feat = video_feat.float().view(bs_pair, -1, video_feat.size(-1)) + video_feat = self.agg_video_feat(video_feat, video_mask, self.agg_module) + return video_feat + + def get_text_video_feat(self, text_ids, text_mask, video, video_mask, shaped=False): + if shaped is False: + text_ids = text_ids.view(-1, text_ids.shape[-1]) + text_mask = text_mask.view(-1, text_mask.shape[-1]) + video_mask = video_mask.view(-1, video_mask.shape[-1]) + video = torch.as_tensor(video).float() + if len(video.shape) == 5: + b, n_v, d, h, w = video.shape + video = video.view(b * n_v, d, h, w) + else: + b, pair, bs, ts, channel, h, w = video.shape + video = video.view(b * pair * bs * ts, channel, h, w) + + text_feat, cls = self.get_text_feat(text_ids, text_mask, shaped=True) + video_feat = self.get_video_feat(video, video_mask, shaped=True) + + return text_feat, video_feat, cls + + def get_video_avg_feat(self, video_feat, video_mask): + video_mask_un = video_mask.to(dtype=torch.float).unsqueeze(-1) + video_feat = video_feat * video_mask_un + video_mask_un_sum = torch.sum(video_mask_un, dim=1, dtype=torch.float) + video_mask_un_sum[video_mask_un_sum == 0.] = 1. + video_feat = torch.sum(video_feat, dim=1) / video_mask_un_sum + return video_feat + + def get_text_sep_feat(self, text_feat, text_mask): + text_feat = text_feat.contiguous() + text_feat = text_feat[torch.arange(text_feat.shape[0]), torch.sum(text_mask, dim=-1) - 1, :] + text_feat = text_feat.unsqueeze(1).contiguous() + return text_feat + + def agg_video_feat(self, video_feat, video_mask, agg_module): + video_feat = video_feat.contiguous() + if agg_module == "None": + pass + elif agg_module == "seqLSTM": + # Sequential type: LSTM + video_feat_original = video_feat + video_feat = pack_padded_sequence(video_feat, torch.sum(video_mask, dim=-1).cpu(), + batch_first=True, enforce_sorted=False) + video_feat, _ = self.lstm_visual(video_feat) + if self.training: self.lstm_visual.flatten_parameters() + video_feat, _ = pad_packed_sequence(video_feat, batch_first=True) + video_feat = torch.cat( + (video_feat, video_feat_original[:, video_feat.size(1):, ...].contiguous()), dim=1) + video_feat = video_feat + video_feat_original + elif agg_module == "seqTransf": + # Sequential type: Transformer Encoder + video_feat_original = video_feat + seq_length = video_feat.size(1) + position_ids = torch.arange(seq_length, dtype=torch.long, device=video_feat.device) + position_ids = position_ids.unsqueeze(0).expand(video_feat.size(0), -1) + frame_position_embeddings = self.frame_position_embeddings(position_ids) + video_feat = video_feat + frame_position_embeddings + extended_video_mask = (1.0 - video_mask.unsqueeze(1)) * -1000000.0 + extended_video_mask = extended_video_mask.expand(-1, video_mask.size(1), -1) + video_feat = video_feat.permute(1, 0, 2) # NLD -> LND + video_feat = self.transformerClip(video_feat, extended_video_mask) + video_feat = video_feat.permute(1, 0, 2) # LND -> NLD + video_feat = video_feat + video_feat_original + return video_feat + + @property + def dtype(self): + """ + :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + try: + return next(self.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + def find_tensor_attributes(module: nn.Module): + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + def init_weights(self, module): + """ Initialize the weights. + """ + if isinstance(module, (nn.Linear, nn.Embedding)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=0.02) + elif isinstance(module, LayerNorm): + if 'beta' in dir(module) and 'gamma' in dir(module): + module.beta.data.zero_() + module.gamma.data.fill_(1.0) + else: + module.bias.data.zero_() + module.weight.data.fill_(1.0) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() \ No newline at end of file diff --git a/README.md b/README.md index c36300f..f2c92a2 100644 --- a/README.md +++ b/README.md @@ -39,6 +39,7 @@ If you find this paper useful, please consider staring 🌟 this repo and citing

## 📣 Updates +* Oct 11 2023: Release code for Banzhaf Interaction estimator. * Oct 08 2023: I am working on the code for Banzhaf Interaction estimator, which is expected to be released soon. * Jun 28 2023: Release code for reimplementing the experiments in the paper. * Mar 28 2023: Our **HBI** has been selected as a Highlight paper at CVPR 2023! (Top 2.5% of 9155 submissions). @@ -127,8 +128,40 @@ wget https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702 #### Train the Banzhaf Interaction Estimator -Train the estimator according to the label generated by the BanzhafInteraction in HBI/models/banzhaf.py. -Training code is under preparation... +Train the estimator according to the label generated by the BanzhafInteraction in HBI/models/banzhaf.py. The training code is provided in banzhaf_estimator.py. + +Recommended running parameters will be provided shortly, and we will also release our pre-trained estimator weights. + +
+ +| Models | Google Cloud | Baidu Yun |Peking University Yun| +|:-----------:|:------------:|:---------:|:-----------:| +| Estimator | TODO | TODO | TODO | + +
+ +```shell +CUDA_VISIBLE_DEVICES=0,1,2,3 \ +python -m torch.distributed.launch \ +--master_port 2502 \ +--nproc_per_node=4 \ +banzhaf_estimator.py \ +--do_train 1 \ +--workers 8 \ +--n_display 1 \ +--epochs ${epochs} \ +--lr ${learning rate} \ +--coef_lr 1e-3 \ +--batch_size 128 \ +--batch_size_val 128 \ +--anno_path data/MSR-VTT/anns \ +--video_path ${DATA_PATH}/MSRVTT_Videos \ +--datatype msrvtt \ +--max_words 24 \ +--max_frames 12 \ +--video_framerate 1 \ +--output_dir ${OUTPUT_PATH} +``` ### Text-video Retrieval
diff --git a/banzhaf_estimator.py b/banzhaf_estimator.py new file mode 100644 index 0000000..42292d5 --- /dev/null +++ b/banzhaf_estimator.py @@ -0,0 +1,360 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import unicode_literals +from __future__ import print_function + +import os +import time +import random +import argparse +import numpy as np +from tqdm import tqdm +import datetime +from os.path import join, exists + +import torch + +from HBI.models.tokenization_clip import SimpleTokenizer as ClipTokenizer +from HBI.dataloaders.data_dataloaders import DATALOADER_DICT +from HBI.dataloaders.dataloader_msrvtt_retrieval import MSRVTTDataset +from HBI.models.modeling_estimator import HBI, AllGather +from HBI.models.optimization import BertAdam +from HBI.utils.metrics import compute_metrics, tensor_text_to_video_metrics, tensor_video_to_text_sim + +from HBI.utils.comm import is_main_process, synchronize +from HBI.utils.logger import setup_logger +from HBI.utils.metric_logger import MetricLogger + +allgather = AllGather.apply + +global logger + + +def get_args( + description='Video-Text as Game Players: Hierarchical Banzhaf Interaction for Cross-Modal Representation Learning'): + parser = argparse.ArgumentParser(description=description) + parser.add_argument("--do_train", type=int, default=0, help="Whether to run training.") + parser.add_argument("--do_eval", type=int, default=0, help="Whether to run evaluation.") + + parser.add_argument("--datatype", default="msrvtt", type=str, help="Point the dataset to finetune.") + parser.add_argument('--anno_path', type=str, default='data/MSR-VTT/anns', help='annotation path') + parser.add_argument('--video_path', type=str, default='data/MSR-VTT/videos', help='video path') + + parser.add_argument('--seed', type=int, default=42, help='random seed') + parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)') + parser.add_argument('--lr', type=float, default=1e-4, help='initial learning rate') + parser.add_argument('--coef_lr', type=float, default=1e-3, help='coefficient for bert branch.') + parser.add_argument("--warmup_proportion", default=0.1, type=float, + help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% of training.") + parser.add_argument('--weight_decay', type=float, default=0.2, help='weight decay') + parser.add_argument('--epochs', type=int, default=5, help='upper epoch limit') + parser.add_argument('--batch_size', type=int, default=128, help='batch size') + parser.add_argument('--batch_size_val', type=int, default=128, help='batch size eval') + + parser.add_argument('--max_words', type=int, default=32, help='max text token number') + parser.add_argument('--max_frames', type=int, default=12, help='max key frames') + parser.add_argument('--video_framerate', type=int, default=1, help='framerate to sample video frame') + + parser.add_argument("--device", default='cpu', type=str, help="cpu/cuda") + parser.add_argument("--world_size", default=1, type=int, help="distribted training") + parser.add_argument("--local_rank", default=0, type=int, help="distribted training") + parser.add_argument("--distributed", default=0, type=int, help="multi machine DDP") + + parser.add_argument('--n_display', type=int, default=50, help='Information display frequence') + parser.add_argument("--output_dir", default=None, type=str, required=True, + help="The output directory where the model predictions and checkpoints will be written.") + + parser.add_argument("--base_encoder", default="ViT-B/32", type=str, help="Choose a CLIP version") + parser.add_argument('--agg_module', type=str, default="seqTransf", choices=["None", "seqLSTM", "seqTransf"], + help="choice a feature aggregation module for video.") + parser.add_argument('--interaction', type=str, default='wti', help="interaction type for retrieval.") + parser.add_argument('--num_hidden_layers', type=int, default=4) + + parser.add_argument("--init_model", default=None, type=str, required=False, help="Initial model.") + args = parser.parse_args() + + return args + + +def set_seed_logger(args): + global logger + # predefining random initial seeds + random.seed(args.seed) + os.environ['PYTHONHASHSEED'] = str(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) # if you are using multi-GPU. + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + if torch.cuda.is_available(): + torch.distributed.init_process_group(backend="nccl") + torch.cuda.set_device(args.local_rank) + args.device = torch.device("cuda", args.local_rank) + args.world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1 + if torch.cuda.is_available(): + torch.distributed.barrier() + logger.info("local_rank: {} world_size: {}".format(args.local_rank, args.world_size)) + + if args.batch_size % args.world_size != 0 or args.batch_size_val % args.world_size != 0: + raise ValueError( + "Invalid batch_size/batch_size_val and world_size parameter: {}%{} and {}%{}, should be == 0".format( + args.batch_size, args.world_size, args.batch_size_val, args.world_size)) + + logger.info("Effective parameters:") + for key in sorted(args.__dict__): + logger.info(" <<< {}: {}".format(key, args.__dict__[key])) + + return args + + +def build_model(args): + model = HBI(args) + if args.init_model: + if not exists(args.init_model): + raise FileNotFoundError + model_state_dict = torch.load(args.init_model, map_location='cpu') + model.load_state_dict(model_state_dict, strict=False) + + model.to(args.device) + return model + + +def build_dataloader(args): + ## #################################### + # dataloader loading + ## #################################### + tokenizer = ClipTokenizer() + assert args.datatype in DATALOADER_DICT + + assert DATALOADER_DICT[args.datatype]["test"] is not None or DATALOADER_DICT[args.datatype]["val"] is not None + + test_dataloader, test_length = None, 0 + if DATALOADER_DICT[args.datatype]["test"] is not None: + test_dataloader, test_length = DATALOADER_DICT[args.datatype]["test"](args, tokenizer) + + if DATALOADER_DICT[args.datatype]["val"] is not None: + val_dataloader, val_length = DATALOADER_DICT[args.datatype]["val"](args, tokenizer, subset="val") + else: + val_dataloader, val_length = test_dataloader, test_length + + ## report validation results if the ["test"] is None + if test_dataloader is None: + test_dataloader, test_length = val_dataloader, val_length + + if isinstance(test_length, int): + logger.info("***** Running test *****") + logger.info(" Num examples = %d", test_length) + logger.info(" Batch size = %d", args.batch_size_val) + logger.info(" Num steps = %d", len(test_dataloader)) + logger.info("***** Running val *****") + logger.info(" Num examples = %d", val_length) + elif len(test_length) == 2: + logger.info("***** Running test *****") + logger.info(" Num examples = %dt %dv", test_length[0], test_length[1]) + logger.info(" Batch size = %d", args.batch_size_val) + logger.info(" Num steps = %d %d", len(test_dataloader[0]), len(test_dataloader[1])) + logger.info("***** Running val *****") + logger.info(" Num examples = %dt %dv", val_length[0], val_length[1]) + + if args.do_train: + train_dataloader, train_length, train_sampler = DATALOADER_DICT[args.datatype]["train"](args, tokenizer) + logger.info("***** Running training *****") + logger.info(" Num examples = %d", train_length) + logger.info(" Batch size = %d", args.batch_size) + logger.info(" Num steps = %d", len(train_dataloader) * args.epochs) + else: + train_dataloader, train_sampler = None, None + + return test_dataloader, val_dataloader, train_dataloader, train_sampler + + +def prep_optimizer(args, model, num_train_optimization_steps, local_rank): + if hasattr(model, 'module'): + model = model.module + lr = args.lr # 0.0001 + coef_lr = args.coef_lr # 0.001 + weight_decay = args.weight_decay # 0.2 + warmup_proportion = args.warmup_proportion + param_optimizer = list(model.named_parameters()) + no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight'] + + decay_param_tp = [(n, p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)] + no_decay_param_tp = [(n, p) for n, p in param_optimizer if any(nd in n for nd in no_decay)] + + decay_clip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." in n] + decay_noclip_param_tp = [(n, p) for n, p in decay_param_tp if "clip." not in n] + + no_decay_clip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." in n] + no_decay_noclip_param_tp = [(n, p) for n, p in no_decay_param_tp if "clip." not in n] + + optimizer_grouped_parameters = [ + {'params': [p for n, p in decay_clip_param_tp], 'weight_decay': weight_decay, 'lr': lr * coef_lr}, + {'params': [p for n, p in decay_noclip_param_tp], 'weight_decay': weight_decay}, + {'params': [p for n, p in no_decay_clip_param_tp], 'weight_decay': 0.0, 'lr': lr * coef_lr}, + {'params': [p for n, p in no_decay_noclip_param_tp], 'weight_decay': 0.0} + ] + + scheduler = None + + optimizer = BertAdam(optimizer_grouped_parameters, lr=args.lr, warmup=warmup_proportion, + schedule='warmup_cosine', b1=0.9, b2=0.98, e=1e-6, + t_total=num_train_optimization_steps, weight_decay=weight_decay, + max_grad_norm=1.0) + + if torch.cuda.is_available(): + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank, + find_unused_parameters=True) + return optimizer, scheduler, model + + +def save_model(epoch, args, model, type_name=""): + # Only save the model it-self + model_to_save = model.module.banzhafteacher if hasattr(model, 'module') else model.banzhafteacher + output_model_file = join( + args.output_dir, "pytorch_model.bin.{}{}".format("" if type_name == "" else type_name + ".", epoch)) + torch.save(model_to_save.state_dict(), output_model_file) + logger.info("Model saved to %s", output_model_file) + return output_model_file + + +def reduce_loss(loss, args): + world_size = args.world_size + if world_size < 2: + return loss + with torch.no_grad(): + torch.distributed.reduce(loss, dst=0) + if torch.distributed.get_rank() == 0: + # only main process gets accumulated, so only divide by + # world_size in this case + loss /= world_size + return loss + + +def train_epoch(epoch, args, model, train_dataloader, device, n_gpu, optimizer, + scheduler, global_step, max_steps, val_dataloader): + global logger + global best_score + global meters + + torch.cuda.empty_cache() + model.train() + log_step = args.n_display + total_loss = 0 + + end = time.time() + logit_scale = 0 + for step, batch in enumerate(train_dataloader, start=1): + global_step += 1 + data_time = time.time() - end + + if n_gpu == 1: + # multi-gpu does scattering it-self + batch = tuple(t.to(device=device, non_blocking=True) for t in batch) + + text_ids, text_mask, video, video_mask, inds, idx = batch + loss = model(text_ids, text_mask, video, video_mask, idx, global_step) + + if n_gpu > 1: + # print(loss.shape) + loss = loss.mean() # mean() to average on multi-gpu. + + with torch.autograd.detect_anomaly(): + loss.backward() + + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + + if scheduler is not None: + scheduler.step() # Update learning rate schedule + + optimizer.zero_grad() + + batch_time = time.time() - end + end = time.time() + + reduced_l = reduce_loss(loss, args) + meters.update(time=batch_time, data=data_time, loss=float(reduced_l)) + + eta_seconds = meters.time.global_avg * (max_steps - global_step) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + + if (global_step % log_step == 0 or global_step == 1) and is_main_process(): + logger.info( + meters.delimiter.join( + [ + "eta: {eta}", + "epoch: {epoch}/{max_epoch}", + "iteration: {iteration}/{max_iteration}", + "{meters}", + "lr: {lr}", + "logit_scale: {logit_scale:.2f}" + "max mem: {memory:.0f}", + ] + ).format( + eta=eta_string, + epoch=epoch, + max_epoch=args.epochs, + iteration=global_step, + max_iteration=max_steps, + meters=str(meters), + lr="/".join([str('%.9f' % itm) for itm in sorted(list(set(optimizer.get_lr())))]), + logit_scale=logit_scale, + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0, + ) + ) + + total_loss = total_loss / len(train_dataloader) + return total_loss, global_step + + +def main(): + global logger + global best_score + global meters + + meters = MetricLogger(delimiter=" ") + args = get_args() + if not exists(args.output_dir): + os.makedirs(args.output_dir, exist_ok=True) + logger = setup_logger('tvr', args.output_dir, args.local_rank) + + args = set_seed_logger(args) + + model = build_model(args) + + test_dataloader, val_dataloader, train_dataloader, train_sampler = build_dataloader(args) + + ## #################################### + # train and eval + ## #################################### + if args.do_train: + tic = time.time() + max_steps = len(train_dataloader) * args.epochs + _max_steps = len(train_dataloader) * 5 + optimizer, scheduler, model = prep_optimizer(args, model, _max_steps, args.local_rank) + + best_score = 0.00001 + best_output_model_file = "None" + global_step = 0 + for epoch in range(args.epochs): + if train_sampler is not None: train_sampler.set_epoch(epoch) + synchronize() + torch.cuda.empty_cache() + tr_loss, global_step = train_epoch(epoch, args, model, train_dataloader, + args.device, args.world_size, optimizer, + scheduler, global_step, max_steps, val_dataloader) + + if args.local_rank == 0: + output_model_file = save_model(epoch, args, model, type_name="") + synchronize() + + toc = time.time() - tic + training_time = time.strftime("%Hh %Mmin %Ss", time.gmtime(toc)) + logger.info("*" * 20 + '\n' + f'training finished with {training_time}' + "*" * 20 + '\n') + + +if __name__ == "__main__": + main() \ No newline at end of file