|
|
@@ -1,20 +1,21 @@
|
|
|
import torch.nn as nn
|
|
|
import torchvision
|
|
|
from torchvision.models._utils import IntermediateLayerGetter
|
|
|
-from torchvision.models.resnet import (ResNet18_Weights,
|
|
|
- ResNet34_Weights,
|
|
|
- ResNet50_Weights,
|
|
|
- ResNet101_Weights)
|
|
|
-from .norm import FrozenBatchNorm2d
|
|
|
+from torchvision.models import resnet
|
|
|
+
|
|
|
+try:
|
|
|
+ from .norm import FrozenBatchNorm2d
|
|
|
+except:
|
|
|
+ from norm import FrozenBatchNorm2d
|
|
|
|
|
|
|
|
|
# IN1K pretrained weights
|
|
|
pretrained_urls = {
|
|
|
# ResNet series
|
|
|
- 'resnet18': ResNet18_Weights,
|
|
|
- 'resnet34': ResNet34_Weights,
|
|
|
- 'resnet50': ResNet50_Weights,
|
|
|
- 'resnet101': ResNet101_Weights,
|
|
|
+ 'resnet18': resnet.ResNet18_Weights,
|
|
|
+ 'resnet34': resnet.ResNet34_Weights,
|
|
|
+ 'resnet50': resnet.ResNet50_Weights,
|
|
|
+ 'resnet101': resnet.ResNet101_Weights,
|
|
|
|
|
|
}
|
|
|
|