# Objects as Points
# --------------- Torch components ---------------
import torch
import torch.nn as nn
# --------------- Model components ---------------
from .ctrnet_encoder import build_encoder
from .ctrnet_decoder import build_decoder
from .ctrnet_neck import build_neck
from .ctrnet_head import build_det_head
from .ctrnet_pred import build_det_pred

# CenterNet
class CenterNet(nn.Module):
    def __init__(self,
                 cfg,
                 device,
                 num_classes = 20,
                 conf_thresh = 0.01,
                 topk = 1000,
                 trainable = False,
                 deploy = False,
                 no_multi_labels = False,
                 nms_class_agnostic = False,
                 ):
        super(CenterNet, self).__init__()
        # ---------------- Basic Parameters ----------------
        self.cfg = cfg
        self.device = device
        self.stride = cfg['out_stride']
        self.num_classes = num_classes
        self.trainable = trainable
        self.conf_thresh = conf_thresh
        self.topk_candidates = topk
        self.deploy = deploy
        self.no_multi_labels = no_multi_labels
        self.nms_class_agnostic = nms_class_agnostic
        self.head_dim = round(128 * cfg['width'])

        # ---------------- Network Parameters ----------------
        ## Encoder
        self.encoder, feat_dims = build_encoder(cfg, pretrained=cfg['bk_pretrained'] and trainable)
        ## Neck
        self.neck = build_neck(cfg, feat_dims[-1], feat_dims[-1])
        self.feat_dim = self.neck.out_dim

        ## Decoder
        self.decoder = build_decoder(cfg, self.feat_dim, self.head_dim)
        ## Head
        self.det_head = nn.Sequential(
            build_det_head(cfg, self.head_dim, self.head_dim),
            build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
        )
        ## Aux Head
        self.aux_det_head = nn.Sequential(
            build_det_head(cfg, self.head_dim, self.head_dim),
            build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
        )

    # Post process
    def post_process(self, cls_pred, box_pred):
        """
        Input:
            cls_pred: torch.Tensor -> [B, M, C]
            box_pred: torch.Tensor -> [B, M, 4]
        Output:
            bboxes: np.array -> [N, 4]
            scores: np.array -> [N,]
            labels: np.array -> [N,]
        """
        # slice out the first (and assumed only) image in the batch
        cls_pred = cls_pred[0]
        box_pred = box_pred[0]
        if self.no_multi_labels:
            # [M,]
            scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
            # Keep top k top scoring indices only.
            num_topk = min(self.topk_candidates, box_pred.size(0))
            # topk candidates
            predicted_prob, topk_idxs = scores.sort(descending=True)
            topk_scores = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]
            # filter out the proposals with low confidence score
            keep_idxs = topk_scores > self.conf_thresh
            scores = topk_scores[keep_idxs]
            topk_idxs = topk_idxs[keep_idxs]
            labels = labels[topk_idxs]
            bboxes = box_pred[topk_idxs]
        else:
            # [M, C] -> [M*C,]
            scores = cls_pred.sigmoid().flatten()
            # Keep top k top scoring indices only.
            num_topk = min(self.topk_candidates, box_pred.size(0))
            # torch.sort is actually faster than .topk (at least on GPUs)
            predicted_prob, topk_idxs = scores.sort(descending=True)
            topk_scores = predicted_prob[:num_topk]
            topk_idxs = topk_idxs[:num_topk]
            # filter out the proposals with low confidence score
            keep_idxs = topk_scores > self.conf_thresh
            scores = topk_scores[keep_idxs]
            topk_idxs = topk_idxs[keep_idxs]
            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
            labels = topk_idxs % self.num_classes
            bboxes = box_pred[anchor_idxs]

        # to cpu & numpy
        scores = scores.cpu().numpy()
        labels = labels.cpu().numpy()
        bboxes = bboxes.cpu().numpy()

        return bboxes, scores, labels

    # Main process
    def forward(self, x):
        # ---------------- Backbone ----------------
        pyramid_feats = self.encoder(x)
        # ---------------- Neck ----------------
        feat = self.neck(pyramid_feats[-1])
        # ---------------- Decoder ----------------
        feat = self.decoder(feat)
        # ---------------- Head ----------------
        outputs = self.det_head(feat)
        if self.trainable:
            outputs['aux_outputs'] = self.aux_det_head(feat)
        # ---------------- Post-process ----------------
        if not self.trainable:
            cls_preds = outputs['pred_cls']
            box_preds = outputs['pred_box']
            if self.deploy:
                scores = cls_preds[0].sigmoid()
                bboxes = box_preds[0]
                # [n_anchors_all, 4 + C]
                outputs = torch.cat([bboxes, scores], dim=-1)
            else:
                # post process
                bboxes, scores, labels = self.post_process(cls_preds, box_preds)
                outputs = {
                    "scores": scores,
                    "labels": labels,
                    "bboxes": bboxes
                }

        return outputs
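

# ----------------------------- Usage sketch -----------------------------
# The block below is a minimal, self-contained illustration of the multi-label
# top-k decoding performed in CenterNet.post_process(); it is not part of the
# model. It runs on random tensors, and the sizes and thresholds (M, C,
# topk_candidates, conf_thresh) are illustrative assumptions only. Note that
# running this file directly still needs the package context for the relative
# imports above (e.g. `python -m <package>.<module>`).
if __name__ == '__main__':
    M, C = 100, 20            # number of anchors / classes (assumed values)
    topk_candidates = 10      # keep at most this many candidates
    conf_thresh = 0.3         # confidence threshold for the demo

    cls_pred = torch.randn(M, C)   # raw classification logits for one image
    box_pred = torch.rand(M, 4)    # decoded boxes for the same anchors

    # [M, C] -> [M*C,]: every (anchor, class) pair becomes one candidate score
    scores = cls_pred.sigmoid().flatten()

    # keep the top-k scoring candidates only
    num_topk = min(topk_candidates, box_pred.size(0))
    predicted_prob, topk_idxs = scores.sort(descending=True)
    topk_scores = predicted_prob[:num_topk]
    topk_idxs = topk_idxs[:num_topk]

    # drop low-confidence candidates
    keep_idxs = topk_scores > conf_thresh
    scores = topk_scores[keep_idxs]
    topk_idxs = topk_idxs[keep_idxs]

    # recover the anchor index and the class label from the flattened index
    anchor_idxs = torch.div(topk_idxs, C, rounding_mode='floor')
    labels = topk_idxs % C
    bboxes = box_pred[anchor_idxs]

    print(bboxes.shape, scores.shape, labels.shape)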