yjh0410 1 year ago
parent
commit
8affb369da
1 changed files with 10 additions and 9 deletions
  1. 10 9
      yolo/models/rtdetr/basic_modules/backbone.py

+ 10 - 9
yolo/models/rtdetr/basic_modules/backbone.py

@@ -1,20 +1,21 @@
 import torch.nn as nn
 import torchvision
 from torchvision.models._utils import IntermediateLayerGetter
-from torchvision.models.resnet import (ResNet18_Weights,
-                                       ResNet34_Weights,
-                                       ResNet50_Weights,
-                                       ResNet101_Weights)
-from .norm import FrozenBatchNorm2d
+from torchvision.models import resnet
+
+try:
+    from .norm import FrozenBatchNorm2d
+except:
+    from  norm import FrozenBatchNorm2d
 
 
 # IN1K pretrained weights
 pretrained_urls = {
     # ResNet series
-    'resnet18':  ResNet18_Weights,
-    'resnet34':  ResNet34_Weights,
-    'resnet50':  ResNet50_Weights,
-    'resnet101': ResNet101_Weights,
+    'resnet18':  resnet.ResNet18_Weights,
+    'resnet34':  resnet.ResNet34_Weights,
+    'resnet50':  resnet.ResNet50_Weights,
+    'resnet101': resnet.ResNet101_Weights,
 
 }