-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathb-384-arc-roundb2.py
45 lines (38 loc) · 1.38 KB
/
b-384-arc-roundb2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Base configs this file inherits from. mmengine's Config loader removes this
# assignment and exposes the merged base configs as `_base_` while executing
# the rest of the file. Paths start with './' and are presumably resolved
# relative to this config file. Judging by their names they hold the 384px
# dataset/dataloader settings, the default runtime and a 20-epoch schedule;
# confirm the contents in ./_base_/.
_base_ =[
'./_base_/dataset384_b16.py',
'./_base_/default_runtime.py',
'./_base_/scheduler20e_arc.py'
]
# Override the training dataset's annotation list inherited from the base
# config, so training uses the "roundb2" split file. The path is presumably
# relative to the dataset's data root; confirm in ./_base_/dataset384_b16.py.
_base_.train_dataloader.dataset.ann_file = "meta/roundb2/train.txt"
# Import the project-local `src` package when the config is loaded, presumably
# so that the custom head/loss types referenced in `model` below are
# registered. allow_failed_imports=False makes a failed import an error
# instead of being silently skipped.
custom_imports = dict(imports=['src'], allow_failed_imports=False)
# ---------------------------------------------------------------------------
# Model: SwinTransformerV2 ("base" arch, 384px input) backbone, followed by
# global average pooling and a custom classification head over 5000 classes.
# ---------------------------------------------------------------------------
model = dict(
    type='ImageClassifier',
    # Checkpoint used to initialise the classifier. Per its file name it is a
    # SwinV2-B model pretrained at 192px with window size 12.
    pretrained="https://download.openmmlab.com/mmclassification/v0/swin-v2/pretrain/swinv2-base-w12_3rdparty_in21k-192px_20220803-f7dc9763.pth",
    backbone=dict(
        type='SwinTransformerV2',
        arch='base',
        img_size=384,
        # Per-stage attention window sizes used for this model ...
        window_size=[24, 24, 24, 12],
        # ... and the per-stage window sizes of the pretrained checkpoint.
        pretrained_window_sizes=[12, 12, 12, 6],
        drop_path_rate=0.2,
    ),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        # Project-specific head type, presumably registered by the `src`
        # package listed in `custom_imports`; confirm there.
        type='ArcFaceClsHeadAdaptiveMargin',
        num_classes=5000,
        # Presumably the final-stage feature width of the "base" backbone
        # arch; must agree with what the neck outputs.
        in_channels=1024,
        number_sub_center=3,
        # NOTE(review): the training dataloader's ann_file is overridden to
        # "meta/roundb2/train.txt" (relative form), while this path carries a
        # "./data/ACCV_workshop/" prefix. Presumably both name the same file;
        # keep them in sync when switching splits.
        ann_file="./data/ACCV_workshop/meta/roundb2/train.txt",
        # num_classes is repeated here; presumably it should track
        # head.num_classes. Alternative loss:
        # dict(type='CrossEntropyLoss', loss_weight=1.0).
        loss=dict(type='SoftmaxEQLLoss', num_classes=5000),
        init_cfg=[
            dict(type='TruncNormal', layer='Linear', std=0.02, bias=0.),
            dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
        ],
    ),
)
if __name__ == "__main__":
    # Smoke test: build the classifier described by `model` and run a single
    # forward pass on random data, printing the output shape.
    import importlib

    import torch
    from mmcls.models import build_classifier

    # `custom_imports` is only honoured by the config loader. When this file
    # is executed directly, import the listed modules ourselves so the custom
    # head/loss types are (presumably) registered before building.
    for module_name in custom_imports["imports"]:
        importlib.import_module(module_name)

    # The backbone config does not override `in_channels`, so
    # SwinTransformerV2 expects its default 3-channel input. The original
    # 2-channel tensor would fail in patch embedding. Spatial size matches
    # img_size=384.
    x = torch.rand((1, 3, 384, 384))
    cla = build_classifier(model)
    # eval() disables stochastic depth (drop_path_rate=0.2) for a deterministic
    # check; no_grad() avoids storing activations we never backprop through.
    cla.eval()
    with torch.no_grad():
        # Assumes the mmcls 1.x API, where the default forward mode returns the
        # head's raw tensor output; confirm for the installed version.
        y = cla(x)
    print(y.size())