|
|
@@ -144,22 +144,13 @@ class GElan(nn.Module):
|
|
|
all_cls_preds = outputs['pred_cls']
|
|
|
all_box_preds = outputs['pred_box']
|
|
|
|
|
|
- if self.deploy:
|
|
|
- cls_preds = torch.cat(all_cls_preds, dim=1)[0]
|
|
|
- box_preds = torch.cat(all_box_preds, dim=1)[0]
|
|
|
- scores = cls_preds.sigmoid()
|
|
|
- bboxes = box_preds
|
|
|
- # [n_anchors_all, 4 + C]
|
|
|
- outputs = torch.cat([bboxes, scores], dim=-1)
|
|
|
-
|
|
|
- else:
|
|
|
- # post process
|
|
|
- bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
|
|
|
- outputs = {
|
|
|
- "scores": scores,
|
|
|
- "labels": labels,
|
|
|
- "bboxes": bboxes
|
|
|
- }
|
|
|
+ # post process
|
|
|
+ bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
|
|
|
+ outputs = {
|
|
|
+ "scores": scores,
|
|
|
+ "labels": labels,
|
|
|
+ "bboxes": bboxes
|
|
|
+ }
|
|
|
|
|
|
return outputs
|
|
|
|