|
|
@@ -6,6 +6,12 @@ try:
|
|
|
except:
|
|
|
from yolov7_af_basic import BasicConv, MDown, ELANLayer
|
|
|
|
|
|
+# IN1K pretrained weight
|
|
|
+pretrained_urls = {
|
|
|
+ 't': None,
|
|
|
+ 'l': None,
|
|
|
+ 'x': None,
|
|
|
+}
|
|
|
|
|
|
# ELANNet
|
|
|
class Yolov7TBackbone(nn.Module):
|
|
|
@@ -87,6 +93,10 @@ class Yolov7LBackbone(nn.Module):
|
|
|
# Initialize all layers
|
|
|
self.init_weights()
|
|
|
|
|
|
+ # Load imagenet pretrained weight
|
|
|
+ if cfg.use_pretrained:
|
|
|
+ self.load_pretrained()
|
|
|
+
|
|
|
def init_weights(self):
|
|
|
"""Initialize the parameters."""
|
|
|
for m in self.modules():
|
|
|
@@ -95,6 +105,31 @@ class Yolov7LBackbone(nn.Module):
|
|
|
# reset the Conv2d initialization parameters
|
|
|
m.reset_parameters()
|
|
|
|
|
|
+ def load_pretrained(self):
|
|
|
+ url = pretrained_urls[self.model_scale]
|
|
|
+ if url is not None:
|
|
|
+ print('Loading backbone pretrained weight from : {}'.format(url))
|
|
|
+ # checkpoint state dict
|
|
|
+ checkpoint = torch.hub.load_state_dict_from_url(
|
|
|
+ url=url, map_location="cpu", check_hash=True)
|
|
|
+ checkpoint_state_dict = checkpoint.pop("model")
|
|
|
+ # model state dict
|
|
|
+ model_state_dict = self.state_dict()
|
|
|
+ # check
|
|
|
+ for k in list(checkpoint_state_dict.keys()):
|
|
|
+ if k in model_state_dict:
|
|
|
+ shape_model = tuple(model_state_dict[k].shape)
|
|
|
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
|
|
+ if shape_model != shape_checkpoint:
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
+ else:
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
+ print('Unused key: ', k)
|
|
|
+ # load the weight
|
|
|
+ self.load_state_dict(checkpoint_state_dict)
|
|
|
+ else:
|
|
|
+ print('No pretrained weight for model scale: {}.'.format(self.model_scale))
|
|
|
+
|
|
|
def make_stem(self, in_dim, out_dim):
|
|
|
stem = nn.Sequential(
|
|
|
BasicConv(in_dim, out_dim//2, kernel_size=3, padding=1, stride=1,
|
|
|
@@ -147,12 +182,12 @@ if __name__ == '__main__':
|
|
|
self.bk_act = 'silu'
|
|
|
self.bk_norm = 'BN'
|
|
|
self.bk_depthwise = False
|
|
|
- self.width = 1.0
|
|
|
+ self.width = 0.5
|
|
|
self.depth = 0.34
|
|
|
self.scale = "l"
|
|
|
|
|
|
cfg = BaseConfig()
|
|
|
- model = Yolov7LBackbone(cfg)
|
|
|
+ model = Yolov7TBackbone(cfg)
|
|
|
x = torch.randn(1, 3, 640, 640)
|
|
|
t0 = time.time()
|
|
|
outputs = model(x)
|