diff --git a/ppcls/arch/backbone/__init__.py b/ppcls/arch/backbone/__init__.py index 669547c8aa..e334bf607b 100644 --- a/ppcls/arch/backbone/__init__.py +++ b/ppcls/arch/backbone/__init__.py @@ -70,6 +70,8 @@ from .model_zoo.peleenet import PeleeNet from .model_zoo.convnext import ConvNeXt_tiny from .model_zoo.cae import cae_base_patch16_224, cae_large_patch16_224 +from .model_zoo.VL_LTR.VL_LTR_finetune import LGR_r50,LGR_vit16 +from .model_zoo.VL_LTR.VL_LTR_pretrain import CVLP_vit16,CVLP_r50 from .variant_models.resnet_variant import ResNet50_last_stage_stride1 from .variant_models.vgg_variant import VGG19Sigmoid diff --git a/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_finetune.py b/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_finetune.py new file mode 100644 index 0000000000..5be6fcc315 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_finetune.py @@ -0,0 +1,369 @@ +import os.path as osp +import os +from typing import Tuple, Union + +import paddle.distributed as dist +import numpy as np +import paddle +from paddle.nn import functional as F +from paddle import nn +from paddle.static.nn import sequence_pad as pad_sequence + +from .utils import trunc_normal_, interpolate_pos_embed +from .VL_LTR_pretrain import ModifiedResNet, VisionTransformer +from paddle.nn.initializer import Assign, Normal, Constant,TruncatedNormal + +__all__ = [ + 'LGR_vit16', + ] + + + +class QuickGELU(nn.Layer): + def forward(self, x: paddle.Tensor): + return x * F.sigmoid(1.702 * x) + +class Mlp(nn.Layer): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + +def masked_fill(x, mask, value): + y = paddle.full(x.shape, value, x.dtype) + return paddle.where(mask, y, x) + +class Attention(nn.Layer): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.norm1q = nn.LayerNorm(dim) + self.norm1k = nn.LayerNorm(dim) + + self.wq = nn.Linear(dim, dim, bias_attr=qkv_bias) + self.wk = nn.Linear(dim, dim, bias_attr=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + + def forward(self, qx: paddle.Tensor, kx: paddle.Tensor, key_padding_mask: paddle.Tensor = None): + assert qx.shape[-1] == kx.shape[-1] and qx.shape[1] == 1 + Bq, _, C = qx.shape + Bk, Nk, _ = kx.shape + q = self.wq(self.norm1q(qx)) + q = paddle.reshape(q,(Bq, 1, self.num_heads, C // self.num_heads)) + q = paddle.transpose(q,(0, 2, 1, 3)) + + k = self.wk(self.norm1k(kx)) + k = paddle.reshape(k,(Bk, Nk, self.num_heads, C // self.num_heads)) + k = paddle.transpose(k,(0, 2, 1, 3)) + + + v = paddle.unsqueeze(kx,axis=1) + attn = paddle.einsum('qhoc,khnc->qkhn', q, k) * self.scale + if key_padding_mask is not None: + attn = masked_fill(attn, paddle.unsqueeze(paddle.unsqueeze(key_padding_mask,axis=0),axis=2),float('-inf')) + attn = F.softmax(attn,axis=-1) + attn = self.attn_drop(attn) + + x = paddle.einsum('khnc,qkhn->qkhc', v, attn) + x = paddle.reshape(x,(Bq, Bk, C)) + + return x + + +class Block(nn.Layer): + + def __init__(self,
dim, num_heads, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + op_type='two_branch', num_classes=0, use_constant_norm=False, v_detach=False): + super().__init__() + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, + qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + self.op_type = op_type + self.use_constant_norm = use_constant_norm + self.v_detach = v_detach + if self.op_type == 'concat': + self.fc = nn.Linear(in_features=dim * 2, out_features=1, bias=True) + elif self.op_type == 'add': + self.fc = nn.Linear(in_features=dim, out_features=1, bias=True) + elif self.op_type == 'cosine': + self.fc = None + elif self.op_type == 'two_branch': + self.cos = nn.CosineSimilarity(axis=2, eps=1e-6) + self.visual_fc = nn.Sequential( + nn.Linear(dim, 4 * dim), + nn.ReLU(), + nn.Linear(4 * dim, num_classes)) + else: + self.fc = None + + def forward(self, qx: paddle.Tensor, kx: paddle.Tensor, key_padding_mask: paddle.Tensor = None, logit_scale=None): + v = self.attn(qx, kx, key_padding_mask=key_padding_mask) + if self.op_type == 'concat': + x = paddle.expand(qx,(qx.shape[0], kx.shape[0], qx.shape[-1])) + x = paddle.concat((x,v),axis=-1) + x = self.fc(x) # [Bq, Bk, 1] + elif self.op_type == 'cosine': + if logit_scale is not None: + qx_ = F.normalize(qx, p=2, axis=-1) + if self.v_detach: + v_buff = paddle.linalg.norm(v,axis=-1,keepdim=True).detach() + v_ = v /v_buff + else: + v_ = F.normalize(v, p=2, axis=-1) + x = paddle.einsum('qkc,qoc->qk', v_, qx_) * paddle.exp(logit_scale) + else: + x = paddle.einsum('qkc,qoc->qk', v, qx) + elif self.op_type == 'add': + x = paddle.expand(qx,(qx.shape[0], kx.shape[0], qx.shape[-1])) + x = x + v + x = self.fc(x) # [Bq, Bk, 1] + elif self.op_type == 'two_branch': + x1 = self.visual_fc(paddle.squeeze(qx,axis=1)) + + if logit_scale is not None: + if self.use_constant_norm: + qx_ = F.normalize(qx, p=2, axis=-1) + v_ = v / 21.1578 + x2 = paddle.einsum('qkc,qoc->qk', v_, qx_) * paddle.exp(logit_scale) + else: + qx_ = F.normalize(qx, p=2, axis=-1) + if self.v_detach: + v_buff = paddle.linalg.norm(v,axis=-1,keepdim=True).detach() + v_ = v / v_buff + else: + v_ = F.normalize(v, p=2, axis=-1) + x2 = paddle.einsum('qkc,qoc->qk', v_, qx_) * paddle.exp(logit_scale) + else: + x2 = paddle.einsum('qkc,qoc->qk', v, qx) + + + return x1, x2 + + return paddle.squeeze(x,axis=-1) + + +class LGR(nn.Layer): + def __init__(self, + num_classes: int, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + sent_length: int, + attn_heads: int, + sent_idxs=None, + op_type="two_branch", + use_norm=False, + use_constant_norm=False, + v_detach=False, + img_grad=True, + attn_grad=True, + select_sent=None, + sent_offset=0, + ): + super().__init__() + self.num_classes = num_classes + + self.sent_offset = sent_offset + self.sent_length = sent_length + self.sent_idxs = sent_idxs + self.select_sent = select_sent + + self.use_norm = use_norm + self.img_grad = img_grad + self.attn_grad = attn_grad + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, output_dim=embed_dim, + heads=vision_heads, input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, layers=vision_layers, + heads=vision_heads, 
output_dim=embed_dim + ) + + if op_type is None: + print("do not use text features") + self.text_embeddings = None + self.text_block = None + self.text_padding_mask = None + self.fc = nn.Linear(embed_dim, num_classes) + else: + self.fc = None + if self.use_norm: + self.logit_scale = self.create_parameter( + (1,), + default_initializer=Assign(paddle.ones([1])* np.log(1 / 0.07)) + ) + self.logit_scale.stop_gradient = True + else: + self.logit_scale = None + self.text_embeddings = self.create_parameter( + (self.num_classes, self.sent_length, embed_dim) + ) + self.text_block = Block(dim=embed_dim, num_heads=attn_heads, + qkv_bias=False, qk_scale=None, drop=0, + attn_drop=0, + op_type=op_type, num_classes=num_classes, + use_constant_norm=use_constant_norm, + v_detach=v_detach + ) + self.text_padding_mask = self.build_key_padding_mask(paddle.to_tensor(self.sent_idxs)) + + if self.img_grad is False: + print('freeze visual norm') + for m in self.visual.parameters(): + if isinstance(m, nn.LayerNorm): + m.eval() + if isinstance(m, nn.BatchNorm2D): + m.eval() + if self.attn_grad is False: + print('freeze attn norm') + for m in self.text_block.attn.parameters(): + if isinstance(m, nn.LayerNorm): + m.eval() + if isinstance(m, nn.BatchNorm2D): + m.eval() + self.initialize_parameters() + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def _init_weights(self, m): + + if isinstance(m, nn.Linear): + TruncatedNormal(std=0.02)(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + Constant(0.0)(m.bias) + + elif isinstance(m, nn.LayerNorm): + Constant(0.0)(m.bias) + Constant(1.0)(m.weight) + + if self.text_block is not None: + Assign(paddle.eye(self.text_block.attn.wq.weight.shape[0]))(self.text_block.attn.wq.weight) + Assign(paddle.eye(self.text_block.attn.wk.weight.shape[0]))(self.text_block.attn.wk.weight) + + + def initialize_parameters(self): + self.apply(self._init_weights) + + def build_key_padding_mask(self, idxs: paddle.Tensor): + #avoid the influence of pad 0 + mask = paddle.arange(0,self.sent_length) + mask = paddle.cast(mask,idxs.dtype) + mask = paddle.unsqueeze(mask,axis=0) + #mask = torch.arange(0, self.sent_length).type_as(idxs).unsqueeze(0) + mask = paddle.expand(mask,(idxs.shape[0], self.sent_length)) + mask = paddle.greater_than(mask, paddle.unsqueeze(idxs,axis=1)-1) + #mask = mask.expand(idxs.shape[0], self.sent_length).gt(idxs.unsqueeze(1) - 1) + return mask + + def load_pretrained_model(self, vis_backbone_path=None, img_grad=True, + attn_grad=True): + + if vis_backbone_path is not None: + self._load_vis_backbone(vis_backbone_path) + self.visual.stop_gradient = not img_grad + self.text_block.attn.stop_gradient = not attn_grad + + + def _load_vis_backbone(self, vis_backbone_path): + assert osp.exists(vis_backbone_path) + pretrained_state_dict = paddle.load( + vis_backbone_path) + + if isinstance(self.visual, VisionTransformer): + num_extra_tokens = 1 + new_size = int((self.visual.positional_embedding.shape[0] - num_extra_tokens) ** 0.5) + new_pos_embed = interpolate_pos_embed(pretrained_state_dict['visual.positional_embedding'], + new_size, num_extra_tokens=num_extra_tokens) + pretrained_state_dict['visual.positional_embedding'] = new_pos_embed + if self.use_norm: + vis_state_dict = { + k: v for k, v in pretrained_state_dict.items() + if k.startswith("visual") or k.startswith('logit_scale') + } + else: + vis_state_dict = { + k: v for k, v in pretrained_state_dict.items() + if k.startswith("visual") + } + + info = self.set_state_dict(vis_state_dict) + 
print('pretrained visual backbone loaded') + print(info) + + + def encode_image(self, image) -> paddle.Tensor: + self.visual.eval() + x = self.visual(paddle.cast(image,self.dtype)) + #x = self.visual(image.type(self.dtype)) + return x + + def forward(self, x): + x = self.encode_image(x) + if self.text_block is not None: + + x = self.text_block(paddle.unsqueeze(x,axis=1), paddle.cast(self.text_embeddings,x.dtype), + key_padding_mask=self.text_padding_mask, + logit_scale=self.logit_scale) + else: + x = self.fc(x) + return x + + + + +def LGR_vit16(pretrained=False, **kwargs): + cache_root = kwargs["cache_root"] + clip_token_path = os.path.join(cache_root, 'IMNET_LT_text_tokens.pkl') + assert os.path.exists(clip_token_path) + text_tokens = paddle.load(clip_token_path) + sent_idxs = [len(sents) for sents in text_tokens] + + model = LGR( + num_classes= kwargs['class_num'], + embed_dim=512, + image_resolution=224, + vision_layers=12, + vision_width=768, + vision_patch_size=16, + sent_length= kwargs['sent_length'], + attn_heads=1, + use_norm=True, + img_grad=False, + sent_idxs = sent_idxs, + select_sent="rand" + ) + + model.load_pretrained_model( + vis_backbone_path=kwargs['pretrain_model_path'], + img_grad=False + ) + + return model diff --git a/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_pretrain.py b/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_pretrain.py new file mode 100644 index 0000000000..d88ca53229 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/VL_LTR/VL_LTR_pretrain.py @@ -0,0 +1,458 @@ +from typing import Tuple, Union +import paddle.distributed as dist +import numpy as np +import paddle +from paddle.nn import functional as F +from paddle import nn +from .utils import interpolate_pos_embed,MODEL_URLS,load_dygraph_pretrain_from_url +from paddle.nn.initializer import Assign, Normal, Constant + +__all__ = ["CVLP_vit16"] + + + +class Identity(nn.Layer): + def __init__(self): + super(Identity, self).__init__() + + def forward(self, inputs): + return inputs + + +class QuickGELU(nn.Layer): + def forward(self, x): + return x * nn.functional.sigmoid(1.702 * x) + +class MultiHeadAttention(nn.MultiHeadAttention): + def __init__(self, + embed_dim, + num_heads, + output_dim=None): + super(MultiHeadAttention, self).__init__(embed_dim, num_heads) + self.out_proj = nn.Linear(embed_dim, output_dim or embed_dim) + +class Bottleneck(nn.Layer): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1): + super().__init__() + + # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 + self.conv1 = nn.Conv2D(inplanes, planes, 1,bias_attr=False) + self.bn1 = nn.BatchNorm2D(planes) + + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(planes) + + self.avgpool = nn.AvgPool2D(stride) if stride > 1 else nn.Identity() + + self.conv3 = nn.Conv2D(planes, planes * self.expansion, 1, bias_attr=False) + self.bn3 = nn.BatchNorm2D(planes * self.expansion) + + self.relu = nn.ReLU() + self.downsample = None + self.stride = stride + + if stride > 1 or inplanes != planes * Bottleneck.expansion: + # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 + self.downsample = nn.Sequential( + ("-1", nn.AvgPool2D(stride)), + ("0", nn.Conv2D(inplanes, planes * self.expansion, 1, stride=1, bias_attr=False)), + ("1", nn.BatchNorm2D(planes * self.expansion)) + ) + + def forward(self, x: paddle.Tensor): + identity = x + + out = self.relu(self.bn1(self.conv1(x))) + out = self.relu(self.bn2(self.conv2(out))) + out = self.avgpool(out) + out = self.bn3(self.conv3(out)) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + return out + + +class AttentionPool2d(nn.Layer): + def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): + super().__init__() + self.positional_embedding = self.create_parameter( + shape=(spacial_dim ** 2 + 1, embed_dim), + default_initializer=Assign( + paddle.randn((spacial_dim ** 2 + 1, embed_dim)) / + embed_dim ** 0.5 + ) + ) + self.add_parameter("positional_embedding", self.positional_embedding) + + self.attn = MultiHeadAttention(embed_dim, num_heads, output_dim) + + def forward(self, x:paddle.Tensor): + x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * + x.shape[3])).transpose((2, 0, 1)) + x = paddle.concat([x.mean(axis=0, keepdim=True), x], axis=0) + x = x + self.positional_embedding.unsqueeze(1) + x = x.transpose((1, 0, 2)) + x = self.attn(query=x, key=x, value=x) + x = x.transpose((1, 0, 2)) + return x[0] + + +class ModifiedResNet(nn.Layer): + """ + A ResNet class that is similar to torchvision's but contains the following changes: + - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): + super().__init__() + self.output_dim = output_dim + self.input_resolution = input_resolution + + # the 3-layer stem + self.conv1 = nn.Conv2D(3, width // 2, kernel_size=3, stride=2, padding=1, bias_attr=False) + self.bn1 = nn.BatchNorm2D(width // 2) + self.conv2 = nn.Conv2D(width // 2, width // 2, kernel_size=3, padding=1, bias_attr=False) + self.bn2 = nn.BatchNorm2D(width // 2) + self.conv3 = nn.Conv2D(width // 2, width, kernel_size=3, padding=1, bias_attr=False) + self.bn3 = nn.BatchNorm2D(width) + self.avgpool = nn.AvgPool2D(2) + self.relu = nn.ReLU() + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + def stem(x): + for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: + x = self.relu(bn(conv(x))) + x = self.avgpool(x) + return x + x = stem(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.attnpool(x) + + return x + + + +class LayerNorm(nn.LayerNorm): + """Subclass torch's LayerNorm to handle fp16.""" + + def forward(self, x: paddle.Tensor): + orig_type = x.dtype + #ret = super().forward(x.type(torch.float32)) + ret = super().forward(paddle.cast(x,paddle.float32)) + return paddle.cast(ret,orig_type) + + + +class ResidualAttentionBlock(nn.Layer): + def __init__(self, d_model: int, n_head: int, attn_mask: paddle.Tensor = None): + super().__init__() + + self.attn = MultiHeadAttention(d_model, n_head) + self.ln_1 = LayerNorm(d_model) + self.mlp = nn.Sequential( + ("c_fc", nn.Linear(d_model, d_model * 4)), + ("gelu", QuickGELU()), + ("c_proj", nn.Linear(d_model * 4, d_model)) + ) + self.ln_2 = LayerNorm(d_model) + self.attn_mask = attn_mask + + def attention(self, x: paddle.Tensor): + self.attn_mask = self.attn_mask if self.attn_mask is not None else None + return self.attn(x, x, x, attn_mask=self.attn_mask)[0] + + def forward(self, x: paddle.Tensor): + x = x + self.attention(self.ln_1(x)) + x = x + self.mlp(self.ln_2(x)) + return x + + +class Transformer(nn.Layer): + def __init__(self, width: int, layers: int, heads: int, attn_mask: paddle.Tensor = None): + super().__init__() + self.width = width + self.layers = layers + self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) + + def forward(self, x: paddle.Tensor): + return self.resblocks(x) + + +class VisionTransformer(nn.Layer): + def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): + super().__init__() + self.input_resolution = 
input_resolution + self.output_dim = output_dim + self.conv1 = nn.Conv2D(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias_attr=False) + + scale = width ** -0.5 + self.class_embedding = scale * self.create_parameter( + shape=(width,), + default_initializer=Assign( + scale * paddle.randn((width,)) + ) + ) + self.positional_embedding = self.create_parameter( + shape=(width,), + default_initializer=Assign( + scale * + paddle.randn( + ((input_resolution // patch_size) ** 2 + 1, width)) + ) + ) + self.ln_pre = LayerNorm(width) + + self.transformer = Transformer(width, layers, heads) + + self.ln_post = LayerNorm(width) + self.proj = self.create_parameter( + shape=(width,), + default_initializer=Assign( + scale * paddle.randn(((width, output_dim))) + ) + ) + + def forward(self, x: paddle.Tensor): + x = self.conv1(x) # shape = [*, width, grid, grid] + x = paddle.reshape(x,(x.shape[0], x.shape[1], -1)) + x = paddle.transpose(x,(0, 2, 1)) + x = paddle.concat([paddle.cast(self.class_embedding,x.dtype) + paddle.zeros((x.shape[0],1,x.shape[-1]),dtype=x.dtype),x],axis=1) + x = x + paddle.cast(self.positional_embedding,x.dtype) + x = self.ln_pre(x) + x = self.transformer(x) + x = self.ln_post(x[:, 0, :]) + + if self.proj is not None: + x = x @ self.proj + + return x + + +class CVLP(nn.Layer): + def __init__(self, + embed_dim: int, + # vision + image_resolution: int, + vision_layers: Union[Tuple[int, int, int, int], int], + vision_width: int, + vision_patch_size: int, + # text + context_length: int, + vocab_size: int, + transformer_width: int, + transformer_heads: int, + transformer_layers: int, + model_type = "", + pretrained_clip=None, + ): + super().__init__() + + self.context_length = context_length + self.model_type = model_type + + if isinstance(vision_layers, (tuple, list)): + vision_heads = vision_width * 32 // 64 + self.visual = ModifiedResNet( + layers=vision_layers, + output_dim=embed_dim, + heads=vision_heads, + input_resolution=image_resolution, + width=vision_width + ) + else: + vision_heads = vision_width // 64 + self.visual = VisionTransformer( + input_resolution=image_resolution, + patch_size=vision_patch_size, + width=vision_width, + layers=vision_layers, + heads=vision_heads, + output_dim=embed_dim + ) + self.embed_dim = embed_dim + self.transformer = Transformer( + width=transformer_width, + layers=transformer_layers, + heads=transformer_heads, + attn_mask=self.build_attention_mask() + ) + + self.vocab_size = vocab_size + self.token_embedding = nn.Embedding(vocab_size, transformer_width) + self.positional_embedding = self.create_parameter( + shape=(self.context_length, transformer_width), + default_initializer=Assign( + paddle.empty((self.context_length, transformer_width)) + ) + ) + self.add_parameter("positional_embedding", self.positional_embedding) + self.ln_final = LayerNorm(transformer_width) + + self.text_projection = self.create_parameter( + shape=(transformer_width, embed_dim), + default_initializer=Assign( + paddle.empty((transformer_width, embed_dim)) + ) + ) + self.add_parameter("text_projection", self.text_projection) + self.logit_scale = self.create_parameter( + shape=(1,), + default_initializer=Assign(paddle.ones([1])) + ) + self.add_parameter("logit_scale", self.logit_scale) + + + self.initialize_parameters(pretrained_clip) + + def initialize_parameters(self, pretrained_clip): + Normal(std=0.02)(self.token_embedding.weight) + Normal(std=0.01)(self.positional_embedding) + if isinstance(self.visual, ModifiedResNet): + if 
self.visual.attnpool is not None: + std = self.embed_dim ** -0.5 + normal_ = Normal(std=std) + normal_(self.visual.attnpool.attn.q_proj.weight) + normal_(self.visual.attnpool.attn.k_proj.weight) + normal_(self.visual.attnpool.attn.v_proj.weight) + normal_(self.visual.attnpool.attn.out_proj.weight) + + for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: + for name, param in resnet_block.named_parameters(): + if name.endswith("bn3.weight"): + #nn.init.zeros_(param) + Constant(value=0.0)(param) + proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) + attn_std = self.transformer.width ** -0.5 + fc_std = (2 * self.transformer.width) ** -0.5 + for block in self.transformer.resblocks: + + normal_ = Normal(std = attn_std) + normal_(block.attn.q_proj.weight) + normal_(block.attn.k_proj.weight) + normal_(block.attn.v_proj.weight) + + + Normal(std=proj_std)(block.attn.out_proj.weight) + Normal(std=fc_std)(block.mlp.c_fc.weight) + Normal(std=proj_std)(block.mlp.c_proj.weight) + + if self.text_projection is not None: + Normal(std=self.transformer.width ** -0.5)(self.text_projection) + if pretrained_clip is not None: + if self.model_type == "vit": + pretrained_state_dict = load_dygraph_pretrain_from_url(MODEL_URLS["CLIP_vit_base_patch16_224"],False) + else: + pretrained_state_dict = paddle.load(pretrained_clip) + #pretrained_state_dict = model.state_dict() + for key in ["input_resolution", "context_length", "vocab_size"]: + if key in pretrained_state_dict: + del pretrained_state_dict[key] + if isinstance(self.visual, VisionTransformer): + num_extra_tokens = 1 + new_size = int((self.visual.positional_embedding.shape[0] - num_extra_tokens) ** 0.5) + new_pos_embed = interpolate_pos_embed(pretrained_state_dict['pos_embed'], + new_size, num_extra_tokens=num_extra_tokens) + pretrained_state_dict['pos_embed'] = new_pos_embed + + info = self.set_state_dict(pretrained_state_dict) + print('loaded pretrained clip.') + print(info) + + + def build_attention_mask(self): + # lazily create causal attention mask, with full attention between the vision tokens + # pytorch uses additive attention mask; fill with -inf + mask = paddle.empty((self.context_length, self.context_length)) + mask = paddle.full_like(mask,float("-inf")) + mask = paddle.triu(mask,diagonal=1) + return mask + + @property + def dtype(self): + return self.visual.conv1.weight.dtype + + def encode_image(self, image) -> paddle.Tensor: + image = paddle.cast(image,self.dtype) + return self.visual(image) + + def encode_text(self, text) -> paddle.Tensor: + text = paddle.cast(text,paddle.int32) + x = paddle.cast(self.token_embedding(text),self.dtype) + + x = x + paddle.cast(self.positional_embedding,self.dtype) + x = self.transformer(x) + x = paddle.cast(self.ln_final(x),self.dtype) + + # take features from the eot embedding (eot_token is the highest number in each sequence) + _buff = paddle.argmax(text,axis=1) + x = x[paddle.arange(x.shape[0]), _buff] @ self.text_projection + + return x + + def forward(self, x): + image, text = x + image_features = self.encode_image(image) + text_features = self.encode_text(text) + + # normalized features + image_features = image_features / image_features.norm(axis=-1, keepdim=True) + text_features = text_features / text_features.norm(axis=-1, keepdim=True) + + # cosine similarity as logits + logit_scale = paddle.exp(self.logit_scale) + logits_per_image = logit_scale * image_features @ text_features.t() + logits_per_text = 
logits_per_image.transpose((0, 1)) + + # shape = [global_batch_size, global_batch_size] + return [image.detach(),text.detach()], (logits_per_image, logits_per_text) + + + +def CVLP_vit16(pretrained=False, **kwargs): + model_type = "vit" + pretrained_clip = kwargs['pretrained_clip'] if "pretrained_clip" in kwargs.keys() else None + model = CVLP( + embed_dim=512, + image_resolution=224, + vision_layers=12, + vision_width=768, + vision_patch_size=16, + context_length=kwargs['context_length'] + 2, + vocab_size=49408, + transformer_width=512, + transformer_heads=8, + transformer_layers=12, + model_type = model_type, + pretrained_clip = pretrained_clip + ) + + return model \ No newline at end of file diff --git a/ppcls/arch/backbone/model_zoo/VL_LTR/__init__.py b/ppcls/arch/backbone/model_zoo/VL_LTR/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ppcls/arch/backbone/model_zoo/VL_LTR/utils.py b/ppcls/arch/backbone/model_zoo/VL_LTR/utils.py new file mode 100644 index 0000000000..aa039482b0 --- /dev/null +++ b/ppcls/arch/backbone/model_zoo/VL_LTR/utils.py @@ -0,0 +1,139 @@ +import math +import warnings + +import paddle +from paddle import nn +from paddle.nn import functional as F +from paddle.nn.functional import interpolate +from paddle.autograd import PyLayer +import paddle.distributed as dist +import os +from ......ppcls.utils.download import get_weights_path_from_url + + +MODEL_URLS = { + "CLIP_vit_base_patch32_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/CLIP_vit_base_patch32_224.pdparams", + "CLIP_vit_base_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/CLIP_vit_base_patch16_224.pdparams", + "CLIP_vit_large_patch14_336": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/CLIP_vit_large_patch14_336.pdparams", + "CLIP_vit_large_patch14_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/CLIP_vit_large_patch14_224.pdparams", + "BEiTv2_vit_base_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/BEiTv2_vit_base_patch16_224.pdparams", + "BEiTv2_vit_large_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/BEiTv2_vit_large_patch16_224.pdparams", + "CAE_vit_base_patch16_224": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/CAE_vit_base_patch16_224.pdparams", + 'EVA_vit_giant_patch14': + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/EVA_vit_giant_patch14.pdparams", + "MOCOV3_vit_small": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MOCOV3_vit_small.pdparams", + "MOCOV3_vit_base": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MOCOV3_vit_base.pdparams", + "MAE_vit_huge_patch14": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_huge_patch14.pdparams", + "MAE_vit_large_patch16": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_large_patch16.pdparams", + "MAE_vit_base_patch16": + "https://paddle-imagenet-models-name.bj.bcebos.com/dygraph/foundation_models/MAE_vit_base_patch16.pdparams", +} + + +def interpolate_pos_embed(pos_embed_checkpoint: paddle.Tensor, new_patch_size, num_extra_tokens=1): + # interpolate position embedding + if pos_embed_checkpoint.ndim > 2: + pos_embed_checkpoint = paddle.squeeze(pos_embed_checkpoint) + embedding_size = 
pos_embed_checkpoint.shape[1] + # height (== width) for the checkpoint position embedding + orig_size = int((pos_embed_checkpoint.shape[0] - num_extra_tokens) ** 0.5) + # height (== width) for the new position embedding + # class_token and dist_token are kept unchanged + extra_tokens = pos_embed_checkpoint[:num_extra_tokens, :] + # only the position tokens are interpolated + pos_tokens = pos_embed_checkpoint[num_extra_tokens:, :] + pos_tokens = paddle.reshape(pos_tokens,[-1,orig_size, orig_size, embedding_size]) + pos_tokens = paddle.transpose(pos_tokens,(0, 3, 1, 2)) + pos_tokens = interpolate(pos_tokens, size=(new_patch_size, new_patch_size), mode='bicubic', align_corners=False) + + pos_tokens = paddle.transpose(pos_tokens,(0, 2, 3, 1)) + pos_tokens = paddle.flatten(pos_tokens,1,2) + pos_tokens = paddle.squeeze(pos_tokens,axis=0) + + new_pos_embed = paddle.concat((extra_tokens, pos_tokens), axis=0) + return new_pos_embed + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with paddle.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor = paddle.uniform((2 * l - 1, 2 * u - 1)) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor = paddle.erfinv(tensor) + + # Transform to proper mean, std + tensor = paddle.multiply(tensor,std * math.sqrt(2.)) + tensor = paddle.add(tensor,mean) + + # Clamp to ensure it's in the proper range + tensor = paddle.clip(tensor,min=a,max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. 
+ Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def load_dygraph_pretrain(path=None): + if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')): + raise ValueError("Model pretrain path {}.pdparams does not " + "exists.".format(path)) + param_state_dict = paddle.load(path + ".pdparams") + return param_state_dict + + +def load_dygraph_pretrain_from_url(pretrained_url, use_ssld=False): + if use_ssld: + pretrained_url = pretrained_url.replace("_pretrained", + "_ssld_pretrained") + local_weight_path = get_weights_path_from_url(pretrained_url).replace( + ".pdparams", "") + return load_dygraph_pretrain(path=local_weight_path) \ No newline at end of file diff --git a/ppcls/configs/ImageNet/VL_LTR/VL_LTR_finetune.yaml b/ppcls/configs/ImageNet/VL_LTR/VL_LTR_finetune.yaml new file mode 100644 index 0000000000..943ef85b77 --- /dev/null +++ b/ppcls/configs/ImageNet/VL_LTR/VL_LTR_finetune.yaml @@ -0,0 +1,122 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 120 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +# model architecture +Arch: + name: LGR_vit16 + class_num: 1000 + context_length: 75 + sent_length: 64 + cache_root: "cached" + pretrained_clip: "vit" + pretrain_model_path: "output/CVLP_vit16/best_model.pdparams" + +# loss function config for traing/eval process +Loss: + Train: + - LGRTwoBrachLoss: + weight: 1.0 + Eval: + - LGRTwoBrachCELoss: + weight: 1.0 + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 1e-3 + epochs: 120 + warmup_start_lr: 0.000001 + step_each_epoch: 30 + warmup_epoch: 0 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetLTDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/ImageNet_LT_train.txt + is_pretrain: False + context_length: 75 + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + batch_transform_ops: + - MixupOperator: + alpha: 0.2 + + sampler: + name: DistributedBatchSampler + batch_size: 40 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetLTDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/ImageNet_LT_val.txt + is_pretrain: False + context_length: 75 + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 40 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Train: + - LGRTwoBranchMetric: + is_mixup: True + topk: [1, 5] + Eval: + - LGRTwoBranchMetric: + topk: [1, 5] + + diff --git 
a/ppcls/configs/ImageNet/VL_LTR/VL_LTR_pretrain.yaml b/ppcls/configs/ImageNet/VL_LTR/VL_LTR_pretrain.yaml new file mode 100644 index 0000000000..5da4f716f5 --- /dev/null +++ b/ppcls/configs/ImageNet/VL_LTR/VL_LTR_pretrain.yaml @@ -0,0 +1,124 @@ +# global configs +Global: + checkpoints: null + pretrained_model: null + output_dir: ./output/ + device: gpu + save_interval: 1 + eval_during_train: True + eval_interval: 1 + epochs: 80 + print_batch_step: 10 + use_visualdl: False + # used for static mode and model export + image_shape: [3, 224, 224] + save_inference_dir: ./inference + +# model architecture +Arch: + name: CVLP_vit16 + class_num: 1000 + context_length: 75 + pretrained_clip: "vit" + +# loss function config for traing/eval process +Loss: + Train: + - PretrainSentLoss: + weight: 1.0 + loss_type: "smoothCE" + context_length: 75 + pretrained_clip: "vit" + distill_type: "logits" + beta: 0.5 + Eval: + - PretrainSentLoss: + weight: 1.0 + loss_type: "CEloss" + context_length: 75 + pretrained_clip: "vit" + distill_type: "logits" + + +Optimizer: + name: Momentum + momentum: 0.9 + lr: + name: Cosine + learning_rate: 0.00005 + epochs: 80 + warmup_start_lr: 0.000001 + step_each_epoch: 30 + warmup_epoch: 5 + + +# data loader for train and eval +DataLoader: + Train: + dataset: + name: ImageNetLTDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/ImageNet_LT_train.txt + is_pretrain: True + context_length: 75 + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - RandCropImage: + size: 224 + - RandFlipImage: + flip_code: 1 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + + + sampler: + name: DistributedBatchSampler + batch_size: 24 + drop_last: False + shuffle: True + loader: + num_workers: 4 + use_shared_memory: True + + Eval: + dataset: + name: ImageNetLTDataset + image_root: ./dataset/ILSVRC2012/ + cls_label_path: ./dataset/ILSVRC2012/ImageNet_LT_val.txt + is_pretrain: True + context_length: 75 + transform_ops: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + sampler: + name: DistributedBatchSampler + batch_size: 24 + drop_last: False + shuffle: False + loader: + num_workers: 4 + use_shared_memory: True + +Metric: + Train: + - TopkAccVLPretrain: + topk: [1, 5] + Eval: + - TopkAccVLPretrain: + topk: [1, 5] + diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index 5f73e7d832..12ef1dab0c 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -22,6 +22,7 @@ from ppcls.data import dataloader # dataset from ppcls.data.dataloader.imagenet_dataset import ImageNetDataset +from ppcls.data.dataloader.imagenetLT_dataset import ImageNetLTDataset from ppcls.data.dataloader.multilabel_dataset import MultiLabelDataset from ppcls.data.dataloader.common_dataset import create_operators from ppcls.data.dataloader.vehicle_dataset import CompCars, VeriWild diff --git a/ppcls/data/dataloader/__init__.py b/ppcls/data/dataloader/__init__.py index 796f4b4584..7ba4de8821 100644 --- a/ppcls/data/dataloader/__init__.py +++ b/ppcls/data/dataloader/__init__.py @@ -11,3 +11,4 @@ from ppcls.data.dataloader.pk_sampler import PKSampler from ppcls.data.dataloader.person_dataset import Market1501, MSMT17 from ppcls.data.dataloader.face_dataset import AdaFaceDataset, FiveValidationDataset +from 
ppcls.data.dataloader.imagenetLT_dataset import ImageNetLTDataset diff --git a/ppcls/data/dataloader/bpe_simple_vocab_16e6.txt.gz b/ppcls/data/dataloader/bpe_simple_vocab_16e6.txt.gz new file mode 100644 index 0000000000..7b5088a527 Binary files /dev/null and b/ppcls/data/dataloader/bpe_simple_vocab_16e6.txt.gz differ diff --git a/ppcls/data/dataloader/imagenetLT_dataset.py b/ppcls/data/dataloader/imagenetLT_dataset.py new file mode 100644 index 0000000000..1fc6f0236f --- /dev/null +++ b/ppcls/data/dataloader/imagenetLT_dataset.py @@ -0,0 +1,441 @@ +from __future__ import print_function +from typing import List,Dict + +import numpy as np +import os + +from .common_dataset import CommonDataset +import gzip +import html +import os +from functools import lru_cache +import ftfy +import regex as re +from tqdm import tqdm +import paddle +from ppcls.data import preprocess +from ppcls.data.preprocess import transform +from ppcls.utils import logger + +""" +prompt for CLIP +""" +prompt_templates = [ + 'a bad photo of a {}.', + 'a photo of many {}.', + 'a sculpture of a {}.', + 'a photo of the hard to see {}.', + 'a low resolution photo of the {}.', + 'a rendering of a {}.', + 'graffiti of a {}.', + 'a bad photo of the {}.', + 'a cropped photo of the {}.', + 'a tattoo of a {}.', + 'the embroidered {}.', + 'a photo of a hard to see {}.', + 'a bright photo of a {}.', + 'a photo of a clean {}.', + 'a photo of a dirty {}.', + 'a dark photo of the {}.', + 'a drawing of a {}.', + 'a photo of my {}.', + 'the plastic {}.', + 'a photo of the cool {}.', + 'a close-up photo of a {}.', + 'a black and white photo of the {}.', + 'a painting of the {}.', + 'a painting of a {}.', + 'a pixelated photo of the {}.', + 'a sculpture of the {}.', + 'a bright photo of the {}.', + 'a cropped photo of a {}.', + 'a plastic {}.', + 'a photo of the dirty {}.', + 'a jpeg corrupted photo of a {}.', + 'a blurry photo of the {}.', + 'a photo of the {}.', + 'a good photo of the {}.', + 'a rendering of the {}.', + 'a {} in a video game.', + 'a photo of one {}.', + 'a doodle of a {}.', + 'a close-up photo of the {}.', + 'a photo of a {}.', + 'the origami {}.', + 'the {} in a video game.', + 'a sketch of a {}.', + 'a doodle of the {}.', + 'a origami {}.', + 'a low resolution photo of a {}.', + 'the toy {}.', + 'a rendition of the {}.', + 'a photo of the clean {}.', + 'a photo of a large {}.', + 'a rendition of a {}.', + 'a photo of a nice {}.', + 'a photo of a weird {}.', + 'a blurry photo of a {}.', + 'a cartoon {}.', + 'art of a {}.', + 'a sketch of the {}.', + 'a embroidered {}.', + 'a pixelated photo of a {}.', + 'itap of the {}.', + 'a jpeg corrupted photo of the {}.', + 'a good photo of a {}.', + 'a plushie {}.', + 'a photo of the nice {}.', + 'a photo of the small {}.', + 'a photo of the weird {}.', + 'the cartoon {}.', + 'art of the {}.', + 'a drawing of the {}.', + 'a photo of the large {}.', + 'a black and white photo of a {}.', + 'the plushie {}.', + 'a dark photo of a {}.', + 'itap of a {}.', + 'graffiti of the {}.', + 'a toy {}.', + 'itap of my {}.', + 'a photo of a cool {}.', + 'a photo of a small {}.', + 'a tattoo of the {}.', +] +""" +CLIP text encoder and decoder +""" +@lru_cache() +def default_bpe(): + return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") + + +@lru_cache() +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. 
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a signficant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def get_pairs(word): + """Return set of symbol pairs in a word. + Word is represented as tuple of symbols (symbols being variable-length strings). + """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r'\s+', ' ', text) + text = text.strip() + return text + + +class SimpleTokenizer(object): + def __init__(self, bpe_path: str = default_bpe()): + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') + merges = merges[1:49152-256-2+1] + merges = [tuple(merge.split()) for merge in merges] + vocab = list(bytes_to_unicode().values()) + vocab = vocab + [v+'' for v in vocab] + for merge in merges: + vocab.append(''.join(merge)) + vocab.extend(['<|startoftext|>', '<|endoftext|>']) + self.encoder = dict(zip(vocab, range(len(vocab)))) + self.decoder = {v: k for k, v in self.encoder.items()} + self.bpe_ranks = dict(zip(merges, range(len(merges)))) + self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} + self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token[:-1]) + ( token[-1] + '',) + pairs = get_pairs(word) + + if not pairs: + return token+'' + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens + + def decode(self, tokens): + text = ''.join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') + return text + +""" +text_loader 
+""" +class SentPreProcessor(object): + def __init__(self,root:str) -> None: + self.root = root + with open(os.path.join(self.root,"labels.txt"),"r") as f: + data = f.readlines() + _lines = [l.split() for l in data] + self.categories = [{"id": l[1].strip(), "name": l[-1].strip().replace("_", " "), + "wid": l[0].strip()} for l in _lines] + self.categories.sort(key=lambda x: x["wid"]) + self.drop_keys = ['External links', 'References', 'Further reading', 'Bibliography'] + self._tokenizer = SimpleTokenizer() + self.SEP_TOKENS = [267, 269] # [',', '.'] + self.wikis = None + def get_clip_text(self): + if self.wikis is None: + self.wikis = [self._parse_wiki(id) for id in range(len(self.categories))] + naive_text = [self.gen_naive_desc(id) for id in range(len(self.categories))] + wiki_text = [self._get_text(wiki) for wiki in self.wikis] + return [naive_text[i] + wiki_text[i] for i in range(len(self.categories))] + def split_text(self,texts): + pat = re.compile(r'(?"] # 49406 + eot_token = self._tokenizer.encoder["<|endoftext|>"] # 49407 + all_tokens = [[sot_token] + self._tokenizer.encode(text)[:context_length] + [eot_token] for text in + texts] + result = paddle.zeros((len(all_tokens), context_length + 2)) + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length + 2: + raise RuntimeError( + f"Input {texts[i]} is too long for context length {context_length}") + result[i, :len(tokens)] = paddle.to_tensor(tokens,dtype=paddle.float32) + return result + + if isinstance(texts, str): + texts = [texts] + elif isinstance(texts[0], List): + return [_tokenize(text) for text in tqdm(texts)] + return _tokenize(texts) + + def _get_text(self, wiki: Dict): + # use all key part of each wiki text except those in drop_keys + text = wiki["summary"] + "\n" + text += "\n".join([v for k, v in wiki.items() if k not in ["summary"] + self.drop_keys]) + return text + + def _parse_wiki(self, id) -> Dict: + try: + with open(os.path.join(self.root, "wiki", f"desc_{id}.txt")) as rf: + lines = rf.readlines() + except UnicodeDecodeError: + with open(os.path.join(self.root, "wiki", f"desc_{id}.txt"), encoding='gbk') as rf: + lines = rf.readlines() + lines = [d.strip() for d in lines if d.strip() != ''] + ret_dict = {} + key = "summary" + val = "" + for line in lines: + if line[:2] == "==": + ret_dict[key] = val.strip() + key = line.strip('= ') + val = "" + else: + val += line + '\n' + ret_dict[key] = val.strip() + return ret_dict + + + + +""" +dataset +""" +class ImageNetLTDataset(CommonDataset): + """ImageNetDataset + + Args: + image_root (str): image root, path to `ILSVRC2012` + cls_label_path (str): path to annotation file `train_list.txt` or 'val_list.txt` + transform_ops (list, optional): list of transform op(s). Defaults to None. + delimiter (str, optional): delimiter. Defaults to None. + relabel (bool, optional): whether do relabel when original label do not starts from 0 or are discontinuous. Defaults to False. 
+ """ + def __init__(self, + image_root, + cls_label_path, + context_length=75, + is_pretrain=False, + transform_ops=None, + delimiter=None, + relabel=False): + + self.delimiter = delimiter if delimiter is not None else " " + self.relabel = relabel + self.is_pretrain = is_pretrain + self.context_length = context_length + super(ImageNetLTDataset, self).__init__(image_root, cls_label_path, + transform_ops) + def get_sentence_tokens(self, context_length): + print('using clip text tokens splitted by sentence') + cache_root = 'cached' + cache_path = os.path.join(cache_root, 'IMNET_LT_desc_text_sent.pkl') + clip_token_path = os.path.join(cache_root, 'IMNET_LT_text_tokens.pkl') + if os.path.exists(clip_token_path): + text_tokens = paddle.load(clip_token_path) + return text_tokens + + preprocessor = SentPreProcessor(root=self._img_root) + if not os.path.exists(cache_path): + os.makedirs(cache_root, exist_ok=True) + texts = preprocessor.get_clip_text() + texts = preprocessor.split_text(texts) + paddle.save(texts,cache_path) + else: + texts = paddle.load(cache_path) + text_tokens = preprocessor.tokenize(texts, context_length=context_length) + paddle.save(text_tokens,clip_token_path) + return text_tokens + + def __getitem__(self, idx): + try: + with open(self.images[idx], 'rb') as f: + img = f.read() + if self._transform_ops: + img = transform(img, self._transform_ops) + img = img.transpose((2, 0, 1)) + if self.is_pretrain: + sent_idxs = self.end_idxs + text_tokens = self.text_tokens + target = self.labels[idx] + + idx = np.random.randint(sent_idxs[target]) + token = text_tokens[target][idx] + + return ((img,token),target) + return (img, self.labels[idx]) + + except Exception as ex: + logger.error("Exception occured when parse line: {} with msg: {}". + format(self.images[idx], ex)) + rnd_idx = np.random.randint(self.__len__()) + return self.__getitem__(rnd_idx) + + def _load_anno(self, seed=None): + assert os.path.exists( + self._cls_path), f"path {self._cls_path} does not exist." + assert os.path.exists( + self._img_root), f"path {self._img_root} does not exist." + assert os.path.exists( + os.path.join(self._img_root,"wiki")), f"wiki does not exist." + assert os.path.exists( + os.path.join(self._img_root,"labels.txt")), f"labels does not exist." + self.images = [] + self.labels = [] + + with open(self._cls_path) as fd: + lines = fd.readlines() + if self.relabel: + label_set = set() + for line in lines: + line = line.strip().split(self.delimiter) + label_set.add(np.int64(line[1])) + label_map = { + oldlabel: newlabel + for newlabel, oldlabel in enumerate(label_set) + } + + if seed is not None: + np.random.RandomState(seed).shuffle(lines) + for line in lines: + line = line.strip().split(self.delimiter) + self.images.append(os.path.join(self._img_root, line[0])) + if self.relabel: + self.labels.append(label_map[np.int64(line[1])]) + else: + self.labels.append(np.int64(line[1])) + if os.path.exists(self.images[-1]) == False: + self.images.pop() + #assert os.path.exists(self.images[ + # -1]), f"path {self.images[-1]} does not exist." 
+ self.text_tokens = self.get_sentence_tokens(self.context_length) + self.end_idxs = [len(sents) for sents in self.text_tokens] + diff --git a/ppcls/data/preprocess/__init__.py b/ppcls/data/preprocess/__init__.py index 8f9ea02834..a8bbaf3603 100644 --- a/ppcls/data/preprocess/__init__.py +++ b/ppcls/data/preprocess/__init__.py @@ -41,7 +41,7 @@ from ppcls.data.preprocess.ops.operators import RandomRotation from ppcls.data.preprocess.ops.operators import Padv2 -from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator +from ppcls.data.preprocess.batch_ops.batch_operators import MixupOperator, CutmixOperator, OpSampler, FmixOperator,MixupOperatorLT from ppcls.data.preprocess.batch_ops.batch_operators import MixupCutmixHybrid import numpy as np diff --git a/ppcls/data/preprocess/batch_ops/batch_operators.py b/ppcls/data/preprocess/batch_ops/batch_operators.py index c9563e2294..23978c4f84 100644 --- a/ppcls/data/preprocess/batch_ops/batch_operators.py +++ b/ppcls/data/preprocess/batch_ops/batch_operators.py @@ -58,7 +58,49 @@ def _mix_target(self, targets0, targets1, lam): def __call__(self, batch): return batch +####### +class MixupOperatorLT(BatchOperator): + """ Mixup operator + reference: https://arxiv.org/abs/1710.09412 + + """ + + def __init__(self, class_num, alpha: float=1.,is_pretrain=False): + """Build Mixup operator + + Args: + alpha (float, optional): The parameter alpha of mixup. Defaults to 1.. + Raises: + Exception: The value of parameter is illegal. + """ + if alpha <= 0: + raise Exception( + f"Parameter \"alpha\" of Mixup should be greater than 0. \"alpha\": {alpha}." + ) + if not class_num: + msg = "Please set \"Arch.class_num\" in config if use \"MixupOperator\"." + logger.error(Exception(msg)) + raise Exception(msg) + + self._alpha = alpha + self.class_num = class_num + self.is_pretrain = is_pretrain + + def __call__(self, batch): + if self.is_pretrain: + (imgs,text),labels,bs = self._unpack(batch) + else: + imgs, labels, bs = self._unpack(batch) + idx = np.random.permutation(bs) + lam = np.random.beta(self._alpha, self._alpha) + imgs = lam * imgs + (1 - lam) * imgs[idx] + targets = self._mix_target(labels, labels[idx], lam) + if self.is_pretrain: + return list(zip((imgs,text),targets)) + else: + return list(zip(imgs, targets)) +####### class MixupOperator(BatchOperator): """ Mixup operator reference: https://arxiv.org/abs/1710.09412 diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 45eccf71d6..94feef26f3 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -53,8 +53,8 @@ def classification_eval(engine, epoch_id=0): paddle.to_tensor(batch[0]['label']) ] time_info["reader_cost"].update(time.time() - tic) - batch_size = batch[0].shape[0] - batch[0] = paddle.to_tensor(batch[0]) + batch_size = batch[1].shape[0] + #batch[0] = paddle.to_tensor(batch[0]) if not engine.config["Global"].get("use_multilabel", False): batch[1] = batch[1].reshape([-1, 1]).astype("int64") diff --git a/ppcls/engine/train/train.py b/ppcls/engine/train/train.py index a41674da70..b17fd05cf8 100644 --- a/ppcls/engine/train/train.py +++ b/ppcls/engine/train/train.py @@ -34,7 +34,7 @@ def train_epoch(engine, epoch_id, print_batch_step): paddle.to_tensor(batch[0]['data']), paddle.to_tensor(batch[0]['label']) ] - batch_size = batch[0].shape[0] + batch_size = batch[1].shape[0] if not engine.config["Global"].get("use_multilabel", False): batch[1] = 
batch[1].reshape([batch_size, -1]) engine.global_step += 1 @@ -64,7 +64,7 @@ for i in range(len(engine.optimizer)): engine.scaler.minimize(engine.optimizer[i], scaled) else: - loss.backward() + loss.backward(retain_graph=True) if (iter_id + 1) % engine.update_freq == 0: for i in range(len(engine.optimizer)): engine.optimizer[i].step() diff --git a/ppcls/loss/__init__.py b/ppcls/loss/__init__.py index cbe7e266ba..824802fb25 100644 --- a/ppcls/loss/__init__.py +++ b/ppcls/loss/__init__.py @@ -29,6 +29,8 @@ from .distillationloss import DistillationMultiLabelLoss from .distillationloss import DistillationDISTLoss from .distillationloss import DistillationPairLoss +from .vltloss import PretrainSentLoss +from .vltloss import VLT_DistillationLoss,LGRTwoBrachLoss,LGRTwoBrachCELoss from .multilabelloss import MultiLabelLoss from .afdloss import AFDLoss diff --git a/ppcls/loss/vltloss.py b/ppcls/loss/vltloss.py new file mode 100644 index 0000000000..9bf269d11f --- /dev/null +++ b/ppcls/loss/vltloss.py @@ -0,0 +1,177 @@ +import paddle +import paddle.nn as nn +from paddle.nn import functional as F +from ..arch.backbone.model_zoo.VL_LTR.VL_LTR_pretrain import CVLP_vit16 + +def cross_entropy(outputs, teacher_outputs): + logprobs = F.log_softmax(outputs, axis=-1) + soft_targets = F.softmax(teacher_outputs, axis=-1) + distill_loss = -paddle.sum(soft_targets * logprobs, axis=-1) + return paddle.mean(distill_loss) + +def labels2idxs(labels: paddle.Tensor): + #labels = paddle.cast(labels,paddle.int32) + buff = [paddle.cast(labels[i] == labels,dtype=paddle.int32) for i in range(labels.shape[0])] + targets = paddle.stack(buff) + return targets + +def kl_div(outputs1, outputs2, T=1.): + # temperature-scaled KL distillation term: F.kl_div expects log-probabilities as input and probabilities as the target + return F.kl_div( + F.log_softmax(outputs1 / T, axis=1), + F.softmax(outputs2 / T, axis=1), + reduction='sum', + ) * (T * T) / paddle.numel(outputs1) + +class SoftTargetCrossEntropy(nn.Layer): + + def __init__(self): + super(SoftTargetCrossEntropy, self).__init__() + + def forward(self, x: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor: + target = paddle.cast(target,dtype=paddle.float32) + loss = paddle.sum(-target * paddle.nn.functional.log_softmax(x, axis=-1), axis=-1) + return paddle.mean(loss) + +class LabelSmoothingCrossEntropy(nn.Layer): + """ + NLL loss with label smoothing. + """ + + def __init__(self, smoothing=0.1): + """ + Constructor for the LabelSmoothing module. + :param smoothing: label smoothing factor + """ + super(LabelSmoothingCrossEntropy, self).__init__() + assert smoothing < 1.0 + self.smoothing = smoothing + self.confidence = 1.
+
+    def forward(self, x: paddle.Tensor, target: paddle.Tensor):
+        logprobs = F.log_softmax(x, axis=-1)
+        smooth_loss = -paddle.mean(logprobs, axis=-1)
+        if target.ndim == 1:
+            # hard labels: pick out the log-probability of the ground-truth class per row
+            nll_loss = -paddle.take_along_axis(logprobs, paddle.unsqueeze(target, axis=1), axis=-1)
+            nll_loss = paddle.squeeze(nll_loss, axis=1)
+        else:
+            assert target.ndim == 2
+            # soft labels: expected negative log-likelihood under the target distribution
+            nll_loss = -paddle.sum(target * logprobs, axis=-1)
+        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
+        return paddle.mean(loss)
+
+
+class PretrainSentLoss(nn.Layer):
+    def __init__(self, loss_type: str, context_length, pretrained_clip, args=None,
+                 distill_type='none', alpha=0., beta=0., tau=0., set_training_mode=False):
+        super().__init__()
+        self.teacher_model = CVLP_vit16(context_length=context_length, pretrained_clip=pretrained_clip)
+        self.base_criterion = LabelSmoothingCrossEntropy()
+        self.loss_type = loss_type
+        self.alpha = alpha
+        self.beta = beta
+        self.tau = tau
+        assert distill_type in ['none', 'feat', 'logits', 'logits_kl']
+        self.distill_type: str = distill_type
+        if beta > 0:
+            assert self.distill_type.startswith("logits")
+        self.teacher_model.stop_gradient = True
+
+    def forward(self, x, labels: paddle.Tensor):
+        # x is [inputs, outputs]
+        inputs, outputs = x
+        labels = paddle.squeeze(labels)
+        # turn per-sample class ids into an [N, N] same-class indicator matrix
+        labels = labels2idxs(labels)
+        if isinstance(outputs, paddle.Tensor):
+            loss = self.base_criterion(outputs, labels)
+            return loss
+        # assume that the model outputs a tuple of outputs
+        if self.alpha > 0.:
+            assert self.distill_type.startswith("feat")
+            # assume that the model outputs a tuple of [outputs1, outputs2, distill_loss]
+            outputs1, outputs2, distill_loss = outputs
+            distill_loss = paddle.mean(distill_loss)
+        else:
+            # assume that the model outputs a tuple of [outputs1, outputs2]
+            outputs1, outputs2 = outputs
+            distill_loss = 0.
+        if self.loss_type in ["softCE", "smoothCE"]:
+            # normalize the multi-hot targets into a distribution for soft cross entropy
+            labels = labels / paddle.sum(labels, axis=1, keepdim=True)
+        loss1 = self.base_criterion(outputs1, labels)
+        loss2 = self.base_criterion(outputs2, labels)
+        base_loss = (loss1 + loss2) / 2.0
+        loss = (1 - self.alpha) * base_loss + self.alpha * distill_loss
+        if self.beta > 0:
+            _, (teacher_outputs1, teacher_outputs2) = self.teacher_model(inputs)
+            teacher_outputs1, teacher_outputs2 = teacher_outputs1.detach(), teacher_outputs2.detach()
+            if self.distill_type == 'logits_kl':
+                distill_loss1 = kl_div(outputs1, teacher_outputs1, T=self.tau)
+                distill_loss2 = kl_div(outputs2, teacher_outputs2, T=self.tau)
+                distill_loss = (distill_loss1 + distill_loss2) / 2.0
+            else:
+                assert self.distill_type == "logits"
+                distill_loss1 = cross_entropy(outputs1, teacher_outputs1)
+                distill_loss2 = cross_entropy(outputs2, teacher_outputs2)
+                distill_loss = (distill_loss1 + distill_loss2) / 2.0
+            loss = (1 - self.beta) * loss + self.beta * distill_loss
+        return {"pretrain_loss": loss}
+
+
+class LGRTwoBrachLoss(nn.Layer):
+    def __init__(self, teacher_model: nn.Layer = None,
+                 distillation_type: str = 'none', alpha: float = 0.0, tau: float = 0.0):
+        super().__init__()
+        self.loss = VLT_DistillationLoss(teacher_model, distillation_type, alpha, tau)
+
+    def forward(self, x, labels):
+        output1, output2 = x[0], x[1]
+        loss1 = self.loss(output1, labels)
+        loss2 = self.loss(output2, labels)
+        loss = loss1 + loss2
+        return {"two_branch_loss": loss}
+
+
+class LGRTwoBrachCELoss(nn.Layer):
+    def __init__(self, teacher_model: nn.Layer = None,
+                 distillation_type: str = 'none', alpha: float = 0.0, tau: float = 0.0):
+        super().__init__()
+        self.loss = F.cross_entropy
+
+    def forward(self, x, labels):
+        output1, output2 = x[0], x[1]
+        loss1 = self.loss(output1, labels)
+        loss2 = self.loss(output2, labels)
+        loss = loss1.mean() + loss2.mean()
+        return {"two_branch_loss": loss}
+
+
+class VLT_DistillationLoss(nn.Layer):
+    """
+    This module wraps a standard criterion and adds an extra knowledge distillation loss by
+    taking a teacher model prediction and using it as additional supervision.
+    """
+
+    def __init__(self, teacher_model: nn.Layer = None,
+                 distillation_type: str = 'none', alpha: float = 0.0, tau: float = 0.0):
+        super().__init__()
+        self.base_criterion = SoftTargetCrossEntropy()
+        self.teacher_model = teacher_model
+        assert distillation_type in ['none', 'soft', 'hard']
+        self.distillation_type = distillation_type
+        self.alpha = alpha
+        self.tau = tau
+
+    def forward(self, x, labels):
+        outputs = x
+        """
+        Args:
+            inputs: The original inputs that are fed to the teacher model
+            outputs: the outputs of the model to be trained. It is expected to be
+                either a Tensor, or a Tuple[Tensor, Tensor], with the original output
+                in the first position and the distillation predictions as the second output
+            labels: the labels for the base criterion
+        """
+        outputs_kd = None
+        if not isinstance(outputs, paddle.Tensor):
+            # assume that the model outputs a tuple of [outputs, outputs_kd]
+            outputs, outputs_kd = outputs
+        base_loss = self.base_criterion(outputs, labels)
+        # only the 'none' path is used here; 'soft'/'hard' fall back to the base loss as well
+        if self.distillation_type == 'none':
+            return base_loss
+        return base_loss
\ No newline at end of file
diff --git a/ppcls/metric/__init__.py b/ppcls/metric/__init__.py
index 1f49cc2d9c..8cec2da4f0 100644
--- a/ppcls/metric/__init__.py
+++ b/ppcls/metric/__init__.py
@@ -16,7 +16,7 @@ from collections import OrderedDict
 from .avg_metrics import AvgMetrics
 
-from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk
+from .metrics import TopkAcc, mAP, mINP, Recallk, Precisionk, TopkAccVLPretrain, LGRTwoBranchMetric
 from .metrics import DistillationTopkAcc
 from .metrics import GoogLeNetTopkAcc
 from .metrics import HammingDistance, AccuracyScore
diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py
index 134ef279b2..c4baa987b1 100644
--- a/ppcls/metric/metrics.py
+++ b/ppcls/metric/metrics.py
@@ -30,6 +30,70 @@
 from ppcls.utils import logger
 
 
+class LGRTwoBranchMetric(AvgMetrics):
+    def __init__(self, topk, is_mixup=False):
+        super().__init__()
+        self.Topk = topk
+        self._topk = TopkAcc(topk)
+        self.is_mixup = is_mixup
+        self.reset()
+
+    def reset(self):
+        self.avg_meters = {
+            f"top{k}": AverageMeter(f"top{k}")
+            for k in self.Topk
+        }
+
+    def forward(self, x, labels):
+        if self.is_mixup:
+            # mixup targets are soft; recover the dominant class index
+            labels = paddle.argmax(labels, axis=1)
+            labels = paddle.unsqueeze(labels, axis=-1)
+        # fuse the two branch predictions with fixed weights 0.2 and 0.8
+        buff_output = F.softmax(x[0], axis=1) * 0.2 + F.softmax(x[1], axis=1) * (1 - 0.2)
+        metric = self._topk(buff_output, labels)
+        for idx, k in enumerate(metric):
+            self.avg_meters[k].update(metric[k], x[0].shape[0])
+        return metric
+
+
+class TopkAccVLPretrain(AvgMetrics):
+    def __init__(self, topk=(1, 5)):
+        super().__init__()
+        assert isinstance(topk, (int, list, tuple))
+        if isinstance(topk, int):
+            topk = [topk]
+        self.topk = topk
+        self.reset()
+
+    def reset(self):
+        self.avg_meters = {
+            f"top{k}": AverageMeter(f"top{k}")
+            for k in self.topk
+        }
+
+    def forward(self, x, label):
+        if isinstance(x, dict):
+            x = x["logits"]
+
+        # image accuracy: take the image-branch logits
+        x = x[1][0]
+        output_dims = x.shape[-1]
+
+        metric_dict = dict()
+        for idx, k in enumerate(self.topk):
+            if output_dims < k:
+                msg = f"The output dims({output_dims}) is less than k({k}), and the argument {k} of Topk has been removed."
+                logger.warning(msg)
+                self.avg_meters.pop(f"top{k}")
+                continue
+            metric_dict[f"top{k}"] = paddle.metric.accuracy(x, label, k=k)
+            self.avg_meters[f"top{k}"].update(metric_dict[f"top{k}"],
+                                              x.shape[0])
+
+        self.topk = list(filter(lambda k: k <= output_dims, self.topk))
+
+        return metric_dict
+
+
 class TopkAcc(AvgMetrics):
     def __init__(self, topk=(1, 5)):
         super().__init__()
diff --git a/tools/train.py b/tools/train.py
index e7c9d7bcc8..28b743a41c 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -22,11 +22,14 @@
 from ppcls.utils import config
 from ppcls.engine.engine import Engine
+import paddle
 
 if __name__ == "__main__":
     args = config.parse_args()
+    args.config = "./ppcls/configs/ImageNet/VL_LTR/VL_LTR_pretrain.yaml"
     config = config.get_config(
         args.config, overrides=args.override, show=False)
     config.profiler_options = args.profiler_options
+    a = paddle.zeros([1, 1])
     engine = Engine(config, mode="train")
     engine.train()
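
A minimal, standalone sketch of the interpolation that MixupOperatorLT applies to a batch; it is not part of the patch, and the helper name mixup_batch and the array shapes are illustrative only (the operator itself delegates label mixing to BatchOperator._mix_target).

    import numpy as np

    def mixup_batch(imgs, one_hot_labels, alpha=1.0, rng=np.random.default_rng(0)):
        """Mix every sample with a randomly permuted partner from the same batch."""
        bs = imgs.shape[0]
        lam = rng.beta(alpha, alpha)      # mixing coefficient drawn from Beta(alpha, alpha)
        idx = rng.permutation(bs)         # partner index for every sample
        mixed_imgs = lam * imgs + (1 - lam) * imgs[idx]
        mixed_targets = lam * one_hot_labels + (1 - lam) * one_hot_labels[idx]
        return mixed_imgs, mixed_targets

    imgs = np.random.rand(4, 3, 224, 224).astype("float32")
    labels = np.eye(10, dtype="float32")[np.array([1, 3, 3, 7])]
    mixed_imgs, mixed_targets = mixup_batch(imgs, labels)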
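Likewise, a minimal sketch of the temperature-scaled KL term used for the 'logits_kl' distillation branch of PretrainSentLoss; tensor names and the temperature value are illustrative, and the snippet simply mirrors the kl_div helper above (Paddle's F.kl_div takes log-probabilities as input and probabilities as label).

    import paddle
    import paddle.nn.functional as F

    def distill_kl(student_logits, teacher_logits, T=2.0):
        # KL(softmax(teacher / T) || softmax(student / T)), scaled by T^2 and averaged per element
        return F.kl_div(
            F.log_softmax(student_logits / T, axis=1),
            F.softmax(teacher_logits / T, axis=1),
            reduction='sum',
        ) * (T * T) / student_logits.numel()

    student = paddle.randn([8, 100])
    teacher = paddle.randn([8, 100])
    loss = distill_kl(student, teacher)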