|
|
@@ -68,6 +68,7 @@ def build_model(args,
|
|
|
k_ = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
|
|
|
checkpoint_state_dict[k_] = checkpoint_state_dict[k]
|
|
|
checkpoint_state_dict.pop(k)
|
|
|
+ k = k_
|
|
|
|
|
|
if k in model_state_dict:
|
|
|
shape_model = tuple(model_state_dict[k].shape)
|