|
|
@@ -61,32 +61,18 @@ def build_model(args,
|
|
|
# model state dict
|
|
|
model_state_dict = model.state_dict()
|
|
|
# check
|
|
|
- new_checkpoint_state_dict = {}
|
|
|
-
|
|
|
for k in list(checkpoint_state_dict.keys()):
|
|
|
- v = checkpoint_state_dict[k]
|
|
|
- if 'reduce_layer_3' in k:
|
|
|
- k_new = k.split('.')
|
|
|
- k_new[1] = 'downsample_layer_1'
|
|
|
- k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
- elif 'reduce_layer_4' in k:
|
|
|
- k_new = k.split('.')
|
|
|
- k_new[1] = 'downsample_layer_2'
|
|
|
- k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
- new_checkpoint_state_dict[k] = v
|
|
|
-
|
|
|
- for k in list(new_checkpoint_state_dict.keys()):
|
|
|
if k in model_state_dict:
|
|
|
shape_model = tuple(model_state_dict[k].shape)
|
|
|
- shape_checkpoint = tuple(new_checkpoint_state_dict[k].shape)
|
|
|
+ shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|
|
|
if shape_model != shape_checkpoint:
|
|
|
- new_checkpoint_state_dict.pop(k)
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
print(k)
|
|
|
else:
|
|
|
- new_checkpoint_state_dict.pop(k)
|
|
|
+ checkpoint_state_dict.pop(k)
|
|
|
print(k)
|
|
|
|
|
|
- model.load_state_dict(new_checkpoint_state_dict, strict=False)
|
|
|
+ model.load_state_dict(checkpoint_state_dict, strict=False)
|
|
|
|
|
|
# keep training
|
|
|
if args.resume is not None:
|
|
|
@@ -94,7 +80,21 @@ def build_model(args,
|
|
|
checkpoint = torch.load(args.resume, map_location='cpu')
|
|
|
# checkpoint state dict
|
|
|
checkpoint_state_dict = checkpoint.pop("model")
|
|
|
- model.load_state_dict(checkpoint_state_dict)
|
|
|
+ # check
|
|
|
+ new_checkpoint_state_dict = {}
|
|
|
+
|
|
|
+ for k in list(checkpoint_state_dict.keys()):
|
|
|
+ v = checkpoint_state_dict[k]
|
|
|
+ if 'reduce_layer_3' in k:
|
|
|
+ k_new = k.split('.')
|
|
|
+ k_new[1] = 'downsample_layer_1'
|
|
|
+ k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
+ elif 'reduce_layer_4' in k:
|
|
|
+ k_new = k.split('.')
|
|
|
+ k_new[1] = 'downsample_layer_2'
|
|
|
+ k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
+ new_checkpoint_state_dict[k] = v
|
|
|
+ model.load_state_dict(new_checkpoint_state_dict)
|
|
|
|
|
|
return model, criterion
|
|
|
|