|
@@ -48,13 +48,18 @@ class ResNet(nn.Module):
|
|
|
# Pretrained
|
|
# Pretrained
|
|
|
assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
|
|
assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
|
|
|
if pretrained_weights is not None:
|
|
if pretrained_weights is not None:
|
|
|
- if name in ('resnet18', 'resnet34'):
|
|
|
|
|
- pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
|
|
|
|
|
- else:
|
|
|
|
|
- if pretrained_weights == "imagenet1k_v1":
|
|
|
|
|
- pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
|
|
|
|
|
- else:
|
|
|
|
|
- pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
|
|
|
|
|
|
|
+ if name == "resnet18":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet18_Weights.IMAGENET1K_V1
|
|
|
|
|
+ elif name == "resnet34":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet34_Weights.IMAGENET1K_V1
|
|
|
|
|
+ elif name == "resnet50" and pretrained_weights == "imagenet1k_v1":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V1
|
|
|
|
|
+ elif name == "resnet50" and pretrained_weights == "imagenet1k_v2":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V2
|
|
|
|
|
+ elif name == "resnet101" and pretrained_weights == "imagenet1k_v1":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V1
|
|
|
|
|
+ elif name == "resnet101" and pretrained_weights == "imagenet1k_v2":
|
|
|
|
|
+ pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V2
|
|
|
else:
|
|
else:
|
|
|
pretrained_weights = None
|
|
pretrained_weights = None
|
|
|
print('- Backbone pretrained weight: ', pretrained_weights)
|
|
print('- Backbone pretrained weight: ', pretrained_weights)
|