|
|
@@ -6,7 +6,7 @@ from utils.box_ops import get_ious
|
|
|
from utils.misc import sigmoid_focal_loss
|
|
|
from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
|
|
|
|
|
|
-from .matcher import FcosMatcher, SimOtaMatcher
|
|
|
+from .matcher import FcosMatcher, AlignedOTAMatcher
|
|
|
|
|
|
|
|
|
class SetCriterion(nn.Module):
|
|
|
@@ -33,9 +33,9 @@ class SetCriterion(nn.Module):
|
|
|
elif cfg.matcher == 'simota':
|
|
|
self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
|
|
|
'loss_reg': cfg.loss_reg_weight}
|
|
|
- self.matcher = SimOtaMatcher(cfg.num_classes,
|
|
|
- self.matcher_cfg['soft_center_radius'],
|
|
|
- self.matcher_cfg['topk_candidates'])
|
|
|
+ self.matcher = AlignedOTAMatcher(cfg.num_classes,
|
|
|
+ self.matcher_cfg['soft_center_radius'],
|
|
|
+ self.matcher_cfg['topk_candidates'])
|
|
|
else:
|
|
|
raise NotImplementedError("Unknown matcher: {}.".format(cfg.matcher))
|
|
|
|