yjh0410 1 жил өмнө
parent
commit
5cddc31ca3

+ 12 - 7
yolo/models/rtdetr/basic_modules/backbone.py

@@ -48,13 +48,18 @@ class ResNet(nn.Module):
         # Pretrained
         assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
         if pretrained_weights is not None:
-            if name in ('resnet18', 'resnet34'):
-                pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
-            else:
-                if pretrained_weights == "imagenet1k_v1":
-                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
-                else:
-                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
+            if   name == "resnet18":
+                pretrained_weights = resnet.ResNet18_Weights.IMAGENET1K_V1
+            elif name == "resnet34":
+                pretrained_weights = resnet.ResNet34_Weights.IMAGENET1K_V1
+            elif name == "resnet50" and pretrained_weights == "imagenet1k_v1":
+                pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V1
+            elif name == "resnet50" and pretrained_weights == "imagenet1k_v2":
+                pretrained_weights = resnet.ResNet50_Weights.IMAGENET1K_V2
+            elif name == "resnet101" and pretrained_weights == "imagenet1k_v1":
+                pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V1
+            elif name == "resnet101" and pretrained_weights == "imagenet1k_v2":
+                pretrained_weights = resnet.ResNet101_Weights.IMAGENET1K_V2
         else:
             pretrained_weights = None
         print('- Backbone pretrained weight: ', pretrained_weights)