source code for VL-LTR #2791
base: release/2.5
Conversation
initial PR for VL-LTR
Thanks for your contribution!
There are many issues with this code; please revise it carefully.
- Code style problems: format the code with pre-commit before submitting.
- Remove unrelated changes;
- Remove unrelated comments and code;
- Where the logic is identical or similar, reuse or extend the existing code to avoid redundancy;
- When changing existing code, keep it backward compatible.
How is this different from the existing PaddleClas/ppcls/utils/logger.py? Why add a logger file here?
This appears to be used only for downloading pretrained weights; PaddleClas/ppcls/utils/download.py already has similar methods that can be reused.
What is this file for?
#k = self.wk(self.norm1k(kx)).reshape(Bk, Nk, self.num_heads, C //
#    self.num_heads).permute(0, 2, 1, 3)
Remove the unused comments.
x = paddle.expand(qx, (qx.shape[0], kx.shape[0], qx.shape[-1]))
#x = qx.expand(qx.shape[0], kx.shape[0], qx.shape[-1])
x = paddle.concat((x, v), axis=-1)
#x = torch.cat((x, v), dim=-1) # [Bq, Bk, 2*C]
Similar unused comments should be removed as well.
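For reference, the Paddle calls kept in this snippet mirror the commented-out PyTorch ones; a minimal check of the expand/concat behavior (the tensor names and shapes below are made up for illustration, not taken from the PR):

```python
import paddle

# Toy shapes purely for illustration (Bq queries, Bk keys, C channels).
Bq, Bk, C = 4, 6, 8
qx = paddle.randn([Bq, 1, C])   # broadcastable query features
v = paddle.randn([Bq, Bk, C])   # per-(query, key) value features

# paddle.expand(x, shape) plays the role of torch's Tensor.expand(...)
x = paddle.expand(qx, [Bq, Bk, C])
# paddle.concat(..., axis=-1) plays the role of torch.cat(..., dim=-1)
x = paddle.concat((x, v), axis=-1)
print(x.shape)  # [4, 6, 16], i.e. [Bq, Bk, 2*C]
```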
@@ -58,7 +58,49 @@ def _mix_target(self, targets0, targets1, lam):
    def __call__(self, batch):
        return batch


#######
class MixupOperatorLT(BatchOperator):
Could this be built on top of the existing MixupOperator class so its code is reused? The logic looks essentially the same.
Merge this with the existing mixup class.
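As a rough illustration of the suggestion (not the actual PaddleClas API): `MixupOperator` and `_mix_target` come from the quoted diff, while `_unpack` and the `class_weights` argument are assumptions for this sketch. The long-tail variant could subclass the existing operator instead of duplicating it:

```python
import numpy as np

class MixupOperatorLT(MixupOperator):
    """Sketch: long-tail mixup that extends the existing MixupOperator
    rather than re-implementing it from scratch."""

    def __init__(self, alpha=0.2, class_weights=None):
        super().__init__(alpha)
        self.alpha = alpha                  # kept locally to avoid relying on parent internals
        self.class_weights = class_weights  # hypothetical long-tail re-weighting knob

    def __call__(self, batch):
        imgs, labels, bs = self._unpack(batch)      # assumed BatchOperator helper
        idx = np.random.permutation(bs)             # long-tail-aware sampling would go here
        lam = np.random.beta(self.alpha, self.alpha)
        imgs = lam * imgs + (1 - lam) * imgs[idx]
        targets = self._mix_target(labels, labels[idx], lam)
        return list(zip(imgs, targets))
```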
batch_size = batch[1].shape[0]
#batch[0] = paddle.to_tensor(batch[0])
Why was this changed?
@@ -64,7 +64,7 @@ def train_epoch(engine, epoch_id, print_batch_step):
            for i in range(len(engine.optimizer)):
                engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            loss.backward()
            loss.backward(retain_graph=True)
Does training this model really require retaining the backward computation graph?
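For background on the question: `retain_graph=True` is only needed when `backward()` is called more than once over the same computation graph, e.g. two losses built from a shared intermediate. A minimal Paddle illustration, unrelated to the PR's actual model:

```python
import paddle

x = paddle.randn([4, 8])
x.stop_gradient = False
feat = (x * 2).sum(axis=1)   # shared intermediate used by both losses

loss1 = feat.mean()
loss2 = (feat * feat).mean()

loss1.backward(retain_graph=True)  # keep the graph alive for the second backward
loss2.backward()                   # graph can be freed as usual here
print(x.grad.shape)                # gradients from both losses accumulate in x.grad
```

If the model only calls `backward()` once per iteration, the `retain_graph=True` change should be unnecessary.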
tools/train.py (Outdated)
Remove the unrelated changes.
Please resubmit this PR against the dev_hackathon4 branch.
def load_dygraph_pretrain(path=None):
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {}.pdparams does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    return param_state_dict
def load_dygraph_pretrain_from_url(pretrained_url, use_ssld=False):
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url).replace(
        ".pdparams", "")
    return load_dygraph_pretrain(path=local_weight_path)
Remove the test code.
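For context on how the two loaders quoted above fit together, a hedged usage sketch (the path and URL below are placeholders for illustration, not actual PaddleClas weights):

```python
import paddle

# Local checkpoint: the ".pdparams" suffix is appended inside load_dygraph_pretrain.
state_dict = load_dygraph_pretrain(path="./pretrained/vit_base_patch16_224")

# Remote checkpoint: "_pretrained" becomes "_ssld_pretrained" when use_ssld=True.
state_dict = load_dygraph_pretrain_from_url(
    "https://example.com/vit_base_patch16_224_pretrained.pdparams",  # placeholder URL
    use_ssld=False)

model = paddle.nn.Linear(768, 1000)   # stand-in for the real backbone
model.set_state_dict(state_dict)      # load the weights into the model
```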
    return tensor


def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
Move this into the model definition code file.
    return new_pos_embed


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
Move this into the model definition code file.
}


def interpolate_pos_embed(pos_embed_checkpoint: paddle.Tensor, new_patch_size, num_extra_tokens=1):
Move this into the model definition code file.
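For reference, position-embedding interpolation of this kind typically reshapes the grid tokens, resizes them bicubically, and flattens back. A generic sketch (not the PR's actual implementation; the function name here is hypothetical):

```python
import paddle
import paddle.nn.functional as F

def interpolate_pos_embed_sketch(pos_embed, new_patch_size, num_extra_tokens=1):
    """Generic sketch: resize a [1, N, C] ViT position embedding to a new grid size."""
    extra = pos_embed[:, :num_extra_tokens]        # class/dist tokens stay untouched
    grid = pos_embed[:, num_extra_tokens:]         # [1, H*W, C]
    old_size = int(grid.shape[1] ** 0.5)
    C = grid.shape[-1]
    grid = grid.reshape([1, old_size, old_size, C]).transpose([0, 3, 1, 2])
    grid = F.interpolate(grid, size=(new_patch_size, new_patch_size),
                         mode='bicubic', align_corners=False)
    grid = grid.transpose([0, 2, 3, 1]).reshape([1, new_patch_size * new_patch_size, C])
    return paddle.concat([extra, grid], axis=1)
```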
from ......ppcls.utils.download import get_weights_path_from_url


MODEL_URLS = {
Move this into the model definition code file.
This file can be deleted.
This file can be deleted.
Do not put this in the repo.
initial PR for VL-LTR