|
|
@@ -62,6 +62,10 @@ def build_model(args,
|
|
|
model_state_dict = model.state_dict()
|
|
|
# check
|
|
|
for k in list(checkpoint_state_dict.keys()):
|
|
|
+ if 'reduce_layer_3' in k:
|
|
|
+ k_new = k.split('.')
|
|
|
+ k_new[1] = 'downsample_layer_1'
|
|
|
+ k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
if k in model_state_dict:
|
|
|
shape_model = tuple(model_state_dict[k].shape)
|
|
|
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
|