|
|
@@ -70,15 +70,15 @@ class Criterion(object):
|
|
|
targets=targets)
|
|
|
# List[B, M, C] -> [B, M, C] -> [BM, C]
|
|
|
batch_size = outputs['pred_obj'].shape[0]
|
|
|
- pred_obj = outputs['pred_obj'].view(-1)
|
|
|
- pred_cls = outputs['pred_cls'].view(-1, self.num_classes)
|
|
|
- pred_txty = outputs['pred_txty'].view(-1, 2)
|
|
|
- pred_twth = outputs['pred_twth'].view(-1, 2)
|
|
|
+ pred_obj = outputs['pred_obj'].view(-1) # [BM,]
|
|
|
+ pred_cls = outputs['pred_cls'].view(-1, self.num_classes) # [BM, C]
|
|
|
+ pred_txty = outputs['pred_txty'].view(-1, 2) # [BM, 2]
|
|
|
+ pred_twth = outputs['pred_twth'].view(-1, 2) # [BM, 2]
|
|
|
|
|
|
- gt_objectness = gt_objectness.view(-1).to(device).float()
|
|
|
- gt_labels = gt_labels.view(-1).to(device).long()
|
|
|
- gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()
|
|
|
- gt_box_weight = gt_box_weight.view(-1).to(device).float()
|
|
|
+ gt_objectness = gt_objectness.view(-1).to(device).float() # [BM,]
|
|
|
+ gt_labels = gt_labels.view(-1).to(device).long() # [BM,]
|
|
|
+ gt_bboxes = gt_bboxes.view(-1, 4).to(device).float() # [BM, 4]
|
|
|
+ gt_box_weight = gt_box_weight.view(-1).to(device).float() # [BM,]
|
|
|
|
|
|
pos_masks = (gt_objectness > 0)
|
|
|
|