feat: add other effv2 models #8

Open · wants to merge 5 commits into main
12 changes: 8 additions & 4 deletions README.md
@@ -10,10 +10,14 @@ Reproduction of EfficientNet V2 architecture as described in [EfficientNetV2: Sm

| Architecture      | # Parameters | FLOPs        | Top-1 Acc. (%) |
| ----------------- | ------------ | ------------ | -------------- |
-| EfficientNetV2-S  | 24.12M       | 8.64G @ 384  |                |
-| EfficientNetV2-M  | 55.30M       | 24.74G @ 480 |                |
-| EfficientNetV2-L  | 119.36M      | 56.13G @ 384 |                |
-| EfficientNetV2-XL | 208.96M      | 93.41G @ 512 |                |
+| EfficientNetV2-B0 | 7.17M        | 0.79G @ 224  |                |
+| EfficientNetV2-B1 | 8.18M        | 1.09G @ 224  |                |
+| EfficientNetV2-B2 | 10.37M       | 1.25G @ 224  |                |
+| EfficientNetV2-B3 | 14.69M       | 1.83G @ 224  |                |
+| EfficientNetV2-S  | 21.10M       | 2.90G @ 224  |                |
+| EfficientNetV2-M  | 55.30M       | 5.44G @ 224  |                |
+| EfficientNetV2-L  | 119.36M      | 12.32G @ 224 |                |
+| EfficientNetV2-XL | 208.96M      | 18.02G @ 224 |                |

Stay tuned for ImageNet pre-trained weights.

148 changes: 113 additions & 35 deletions effnetv2.py
@@ -9,11 +9,13 @@
import torch
import torch.nn as nn
import math
+from functools import partial

-__all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl']
+__all__ = ['effnetv2_s', 'effnetv2_m', 'effnetv2_l', 'effnetv2_xl',
+           'effnetv2_base', 'effnetv2_b0', 'effnetv2_b1', 'effnetv2_b2', 'effnetv2_b3']


-def _make_divisible(v, divisor, min_value=None):
+def _make_divisible(v, divisor, min_value=None, round_limit=.9):
    """
    This function is taken from the original tf repo.
    It ensures that all layers have a channel number that is divisible by 8
@@ -28,11 +30,18 @@ def _make_divisible(v, divisor, min_value=None):
        min_value = divisor
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
-    if new_v < 0.9 * v:
+    if new_v < round_limit * v:
        new_v += divisor
    return new_v


+def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
+    """Round number of filters based on depth multiplier."""
+    if not multiplier:
+        return channels
+    return _make_divisible(channels * multiplier, divisor, min_value=channel_min, round_limit=round_limit)


# SiLU (Swish) activation function
if hasattr(nn, 'SiLU'):
    SiLU = nn.SiLU
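A quick sanity check of the two helpers above (hand-computed values, shown for illustration only):

```python
from effnetv2 import round_channels  # helper defined in the diff above

# round_channels scales, then snaps to a multiple of `divisor` (8 by default):
print(round_channels(48, multiplier=1.2))    # 48 * 1.2 = 57.6 -> 56; 56 >= 0.9 * 57.6, so no bump
print(round_channels(1280, multiplier=1.1))  # 1408, already a multiple of 8
print(round_channels(1280, multiplier=1.0))  # 1280, a multiplier of 1.0 leaves multiples of 8 unchanged
```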
@@ -119,12 +128,13 @@ def forward(self, x):


class EffNetV2(nn.Module):
-    def __init__(self, cfgs, num_classes=1000, width_mult=1.):
+    def __init__(self, num_classes=1000, width_mult=1., cfgs=None, stem_size=24, num_feature=1792):
        super(EffNetV2, self).__init__()
        # setting of inverted residual blocks
        self.cfgs = cfgs

        # building first layer
-        input_channel = _make_divisible(24 * width_mult, 8)
+        input_channel = _make_divisible(stem_size * width_mult, 8)
        layers = [conv_3x3_bn(3, input_channel, 2)]
        # building inverted residual blocks
        block = MBConv
@@ -135,7 +145,7 @@ def __init__(self, cfgs, num_classes=1000, width_mult=1.):
            input_channel = output_channel
        self.features = nn.Sequential(*layers)
        # building last several layers
-        output_channel = _make_divisible(1792 * width_mult, 8) if width_mult > 1.0 else 1792
+        output_channel = _make_divisible(num_feature * width_mult, 8) if width_mult > 1.0 else num_feature
        self.conv = conv_1x1_bn(input_channel, output_channel)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.classifier = nn.Linear(output_channel, num_classes)
@@ -169,64 +179,132 @@ def effnetv2_s(**kwargs):
"""
Constructs a EfficientNetV2-S model
"""
cfgs = [
settings = {"cfgs": [
# t, c, n, s, SE
[1, 24, 2, 1, 0],
[4, 48, 4, 2, 0],
[4, 64, 4, 2, 0],
[4, 128, 6, 2, 1],
[6, 160, 9, 1, 1],
[6, 272, 15, 2, 1],
]
return EffNetV2(cfgs, **kwargs)

[1, 24, 2, 1, 0],
[4, 48, 4, 2, 0],
[4, 64, 4, 2, 0],
[4, 128, 6, 2, 1],
[6, 160, 9, 1, 1],
[6, 256, 15, 2, 1],
]
}
kwargs.update(settings)
return EffNetV2(**kwargs)

def effnetv2_m(**kwargs):
"""
Constructs a EfficientNetV2-M model
"""
cfgs = [
settings = {"cfgs": [
# t, c, n, s, SE
[1, 24, 3, 1, 0],
[4, 48, 5, 2, 0],
[4, 80, 5, 2, 0],
[4, 160, 7, 2, 1],
[1, 24, 3, 1, 0],
[4, 48, 5, 2, 0],
[4, 80, 5, 2, 0],
[4, 160, 7, 2, 1],
[6, 176, 14, 1, 1],
[6, 304, 18, 2, 1],
[6, 512, 5, 1, 1],
[6, 512, 5, 1, 1],
]
return EffNetV2(cfgs, **kwargs)
}
kwargs.update(settings)
return EffNetV2(**kwargs)


def effnetv2_l(**kwargs):
"""
Constructs a EfficientNetV2-L model
"""
cfgs = [
settings = {"cfgs": [
# t, c, n, s, SE
[1, 32, 4, 1, 0],
[4, 64, 7, 2, 0],
[4, 96, 7, 2, 0],
[1, 32, 4, 1, 0],
[4, 64, 7, 2, 0],
[4, 96, 7, 2, 0],
[4, 192, 10, 2, 1],
[6, 224, 19, 1, 1],
[6, 384, 25, 2, 1],
[6, 640, 7, 1, 1],
[6, 640, 7, 1, 1],
]
return EffNetV2(cfgs, **kwargs)
}
kwargs.update(settings)
return EffNetV2(**kwargs)


def effnetv2_xl(**kwargs):
"""
Constructs a EfficientNetV2-XL model
"""
cfgs = [
settings = {"cfgs": [
# t, c, n, s, SE
[1, 32, 4, 1, 0],
[4, 64, 8, 2, 0],
[4, 96, 8, 2, 0],
[1, 32, 4, 1, 0],
[4, 64, 8, 2, 0],
[4, 96, 8, 2, 0],
[4, 192, 16, 2, 1],
[6, 256, 24, 1, 1],
[6, 512, 32, 2, 1],
[6, 640, 8, 1, 1],
[6, 640, 8, 1, 1],
]
return EffNetV2(cfgs, **kwargs)
}
kwargs.update(settings)
return EffNetV2(**kwargs)


def effnetv2_base(**kwargs):
"""
Constructs a EfficientNetV2-Base model
"""
width_mult = kwargs.pop("width_mult", 1.0)
round_chs_fn = partial(round_channels, multiplier=width_mult, round_limit=0.)
num_feature = round_chs_fn(1280)

depth_multiplier = kwargs.pop("depth_multiplier", 1.0)

settings = {"cfgs": [
# t, c, n, s, SE
[1, 16, 1, 1, 0],
[4, 32, 2, 2, 0],
[4, 48, 2, 2, 0],
[4, 96, 3, 2, 1],
[6, 112, 5, 1, 1],
[6, 192, 8, 2, 1],
],
"stem_size": 32,
"num_feature": num_feature,
"width_mult": width_mult
}
# scale depth
for i in range(len(settings["cfgs"])):
settings["cfgs"][i][2] = int(math.ceil(depth_multiplier*settings["cfgs"][i][2]))

kwargs.update(settings)
return EffNetV2(**kwargs)


effnetv2_b0 = effnetv2_base


def effnetv2_b1(**kwargs):
settings = {
"depth_multiplier": 1.1,
"width_mult": 1.0,
}
kwargs.update(settings)
return effnetv2_base(**kwargs)


def effnetv2_b2(**kwargs):
settings = {
"depth_multiplier": 1.2,
"width_mult": 1.1,
}
kwargs.update(settings)
return effnetv2_base(**kwargs)


def effnetv2_b3(**kwargs):
settings = {
"depth_multiplier": 1.4,
"width_mult": 1.2,
}
kwargs.update(settings)
return effnetv2_base(**kwargs)
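Since the B-variants only differ in these two multipliers, the resulting depths and head width can be checked by hand; a minimal sketch for `effnetv2_b3` (derived from the code above, assuming `effnetv2.py` is importable):

```python
import math

from effnetv2 import round_channels

base_repeats = [1, 2, 2, 3, 5, 8]  # the n column of the effnetv2_base cfgs
depth_multiplier = 1.4             # effnetv2_b3
print([int(math.ceil(depth_multiplier * n)) for n in base_repeats])  # [2, 3, 3, 5, 7, 12]

# Head width as pre-scaled in effnetv2_base (width_mult=1.2 for b3):
print(round_channels(1280, multiplier=1.2, round_limit=0.))  # 1536
```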
45 changes: 45 additions & 0 deletions main.py
@@ -0,0 +1,45 @@
import torch
import argparse
from thop import profile

from effnetv2 import *

# workaround for the duplicate OpenMP runtime crash on macOS
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def get_args_parser():
    parser = argparse.ArgumentParser('EfficientNetV2 profiling and ONNX export', add_help=False)
    parser.add_argument('-bs', '--batch_size', default=8, type=int, help='batch size used for the ONNX export')
    parser.add_argument('-e', '--export', action='store_true', help='also convert the models to ONNX')
    return parser


if __name__ == '__main__':
    parser = argparse.ArgumentParser('profile EfficientNetV2 variants and optionally export them to ONNX',
                                     parents=[get_args_parser()])
    args = parser.parse_args()

    model_names = ['s', 'm', 'l', 'xl', 'b0', 'b1', 'b2', 'b3']
    for m in model_names:
        model_name = "effnetv2_" + m
        # look up the constructor re-exported by `from effnetv2 import *` by name
        model = globals()[model_name]()
        print(model_name)
        x = torch.randn(1, 3, 224, 224)
        flops, params = profile(model, inputs=(x,), verbose=False)
        print("flops = %fM" % (flops / 1e6, ))
        print("param size = %fM" % (params / 1e6, ))

        if args.export:
            print("exporting....")
            model.eval()
            x = torch.randn(args.batch_size, 3, 224, 224)
            # name the file after the model actually being exported
            torch.onnx.export(model, x, model_name + "_bs" + str(args.batch_size) + ".onnx",
                              input_names=['input'],
                              output_names=['output'],
                              verbose=True,
                              opset_version=11,
                              operator_export_type=torch.onnx.OperatorExportTypes.ONNX)
            print("exported!")
59 changes: 59 additions & 0 deletions parse_config.py
@@ -0,0 +1,59 @@

#################### EfficientNet V2 configs ####################
# Block strings follow the encoding used by the original TF repo:
#   r<repeats>_k<kernel>_s<stride>_e<expand_ratio>_i<in_ch>_o<out_ch>_{c1|se<ratio>}
# A trailing 'c1' marks a fused conv block without squeeze-and-excitation;
# 'se0.25' marks an MBConv block with an SE ratio of 0.25.

v2_base_block = [  # The baseline config for v2 models.
    'r1_k3_s1_e1_i32_o16_c1',
    'r2_k3_s2_e4_i16_o32_c1',
    'r2_k3_s2_e4_i32_o48_c1',
    'r3_k3_s2_e4_i48_o96_se0.25',
    'r5_k3_s1_e6_i96_o112_se0.25',
    'r8_k3_s2_e6_i112_o192_se0.25',
]


v2_s_block = [  # about base * (width1.4, depth1.8)
    'r2_k3_s1_e1_i24_o24_c1',
    'r4_k3_s2_e4_i24_o48_c1',
    'r4_k3_s2_e4_i48_o64_c1',
    'r6_k3_s2_e4_i64_o128_se0.25',
    'r9_k3_s1_e6_i128_o160_se0.25',
    'r15_k3_s2_e6_i160_o256_se0.25',
]


v2_m_block = [  # about base * (width1.6, depth2.2)
    'r3_k3_s1_e1_i24_o24_c1',
    'r5_k3_s2_e4_i24_o48_c1',
    'r5_k3_s2_e4_i48_o80_c1',
    'r7_k3_s2_e4_i80_o160_se0.25',
    'r14_k3_s1_e6_i160_o176_se0.25',
    'r18_k3_s2_e6_i176_o304_se0.25',
    'r5_k3_s1_e6_i304_o512_se0.25',
]


v2_l_block = [  # about base * (width2.0, depth3.1)
    'r4_k3_s1_e1_i32_o32_c1',
    'r7_k3_s2_e4_i32_o64_c1',
    'r7_k3_s2_e4_i64_o96_c1',
    'r10_k3_s2_e4_i96_o192_se0.25',
    'r19_k3_s1_e6_i192_o224_se0.25',
    'r25_k3_s2_e6_i224_o384_se0.25',
    'r7_k3_s1_e6_i384_o640_se0.25',
]

v2_xl_block = [  # only for 21k pretraining.
    'r4_k3_s1_e1_i32_o32_c1',
    'r8_k3_s2_e4_i32_o64_c1',
    'r8_k3_s2_e4_i64_o96_c1',
    'r16_k3_s2_e4_i96_o192_se0.25',
    'r24_k3_s1_e6_i192_o256_se0.25',
    'r32_k3_s2_e6_i256_o512_se0.25',
    'r8_k3_s1_e6_i512_o640_se0.25',
]

# Translate each block string into the [t, c, n, s, SE] rows that effnetv2.py
# expects (expand ratio, output channels, repeats, stride, SE flag) and print them.
for blk in [v2_base_block, v2_s_block, v2_m_block, v2_l_block, v2_xl_block]:
    cfgs = []
    for k in blk:
        keys = k.split('_')
        # e -> t, o -> c, r -> n, s -> s; the trailing key sets the SE flag
        cfgs.append([int(keys[3][1:]), int(keys[5][1:]), int(keys[0][1:]),
                     int(keys[2][1:]), 1 if 'se' in keys[-1] else 0])
    print(cfgs)
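As a cross-check, the first list printed (for `v2_base_block`) should reproduce the `cfgs` hard-coded in `effnetv2_base`; the expected output below is worked out by hand from the block strings, not captured from a run:

```python
# 'r1_k3_s1_e1_i32_o16_c1' -> t=1 (e1), c=16 (o16), n=1 (r1), s=1 (s1), SE=0 ('c1')
expected_v2_base = [[1, 16, 1, 1, 0], [4, 32, 2, 2, 0], [4, 48, 2, 2, 0],
                    [4, 96, 3, 2, 1], [6, 112, 5, 1, 1], [6, 192, 8, 2, 1]]
```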