|
@@ -8,12 +8,12 @@ except:
|
|
|
|
|
|
|
|
# IN1K pretrained weight
|
|
# IN1K pretrained weight
|
|
|
pretrained_urls = {
|
|
pretrained_urls = {
|
|
|
- 't': None,
|
|
|
|
|
|
|
+ 't': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elannet_t_in1k_6.6.pth",
|
|
|
'l': None,
|
|
'l': None,
|
|
|
'x': None,
|
|
'x': None,
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-# ELANNet
|
|
|
|
|
|
|
+# ELANNet-Tiny
|
|
|
class Yolov7TBackbone(nn.Module):
|
|
class Yolov7TBackbone(nn.Module):
|
|
|
def __init__(self, cfg):
|
|
def __init__(self, cfg):
|
|
|
super(Yolov7TBackbone, self).__init__()
|
|
super(Yolov7TBackbone, self).__init__()
|
|
@@ -33,9 +33,14 @@ class Yolov7TBackbone(nn.Module):
|
|
|
self.layer_4 = self.make_block(self.feat_dims[2], self.feat_dims[3], expansion=0.5)
|
|
self.layer_4 = self.make_block(self.feat_dims[2], self.feat_dims[3], expansion=0.5)
|
|
|
self.layer_5 = self.make_block(self.feat_dims[3], self.feat_dims[4], expansion=0.5)
|
|
self.layer_5 = self.make_block(self.feat_dims[3], self.feat_dims[4], expansion=0.5)
|
|
|
|
|
|
|
|
|
|
+ # Initialize all layers
|
|
|
# Initialize all layers
|
|
# Initialize all layers
|
|
|
self.init_weights()
|
|
self.init_weights()
|
|
|
|
|
|
|
|
|
|
+ # Load imagenet pretrained weight
|
|
|
|
|
+ if cfg.use_pretrained:
|
|
|
|
|
+ self.load_pretrained()
|
|
|
|
|
+
|
|
|
def init_weights(self):
|
|
def init_weights(self):
|
|
|
"""Initialize the parameters."""
|
|
"""Initialize the parameters."""
|
|
|
for m in self.modules():
|
|
for m in self.modules():
|
|
@@ -44,6 +49,31 @@ class Yolov7TBackbone(nn.Module):
|
|
|
# reset the Conv2d initialization parameters
|
|
# reset the Conv2d initialization parameters
|
|
|
m.reset_parameters()
|
|
m.reset_parameters()
|
|
|
|
|
|
|
|
|
|
+ def load_pretrained(self):
|
|
|
|
|
+ url = pretrained_urls[self.model_scale]
|
|
|
|
|
+ if url is not None:
|
|
|
|
|
+ print('Loading backbone pretrained weight from : {}'.format(url))
|
|
|
|
|
+ # checkpoint state dict
|
|
|
|
|
+ checkpoint = torch.hub.load_state_dict_from_url(
|
|
|
|
|
+ url=url, map_location="cpu", check_hash=True)
|
|
|
|
|
+ checkpoint_state_dict = checkpoint.pop("model")
|
|
|
|
|
+ # model state dict
|
|
|
|
|
+ model_state_dict = self.state_dict()
|
|
|
|
|
+ # check
|
|
|
|
|
+ for k in list(checkpoint_state_dict.keys()):
|
|
|
|
|
+ if k in model_state_dict:
|
|
|
|
|
+ shape_model = tuple(model_state_dict[k].shape)
|
|
|
|
|
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
|
|
|
|
+ if shape_model != shape_checkpoint:
|
|
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
|
|
+ else:
|
|
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
|
|
+ print('Unused key: ', k)
|
|
|
|
|
+ # load the weight
|
|
|
|
|
+ self.load_state_dict(checkpoint_state_dict)
|
|
|
|
|
+ else:
|
|
|
|
|
+ print('No pretrained weight for model scale: {}.'.format(self.model_scale))
|
|
|
|
|
+
|
|
|
def make_stem(self, in_dim, out_dim):
|
|
def make_stem(self, in_dim, out_dim):
|
|
|
stem = BasicConv(in_dim, out_dim, kernel_size=6, padding=2, stride=2,
|
|
stem = BasicConv(in_dim, out_dim, kernel_size=6, padding=2, stride=2,
|
|
|
act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise)
|
|
act_type=self.bk_act, norm_type=self.bk_norm, depthwise=self.bk_depthwise)
|
|
@@ -70,7 +100,7 @@ class Yolov7TBackbone(nn.Module):
|
|
|
|
|
|
|
|
return outputs
|
|
return outputs
|
|
|
|
|
|
|
|
-
|
|
|
|
|
|
|
+# ELANNet-Large
|
|
|
class Yolov7LBackbone(nn.Module):
|
|
class Yolov7LBackbone(nn.Module):
|
|
|
def __init__(self, cfg):
|
|
def __init__(self, cfg):
|
|
|
super(Yolov7LBackbone, self).__init__()
|
|
super(Yolov7LBackbone, self).__init__()
|
|
@@ -182,9 +212,9 @@ if __name__ == '__main__':
|
|
|
self.bk_act = 'silu'
|
|
self.bk_act = 'silu'
|
|
|
self.bk_norm = 'BN'
|
|
self.bk_norm = 'BN'
|
|
|
self.bk_depthwise = False
|
|
self.bk_depthwise = False
|
|
|
|
|
+ self.use_pretrained = True
|
|
|
self.width = 0.5
|
|
self.width = 0.5
|
|
|
- self.depth = 0.34
|
|
|
|
|
- self.scale = "l"
|
|
|
|
|
|
|
+ self.scale = "t"
|
|
|
|
|
|
|
|
cfg = BaseConfig()
|
|
cfg = BaseConfig()
|
|
|
model = Yolov7TBackbone(cfg)
|
|
model = Yolov7TBackbone(cfg)
|