|
|
@@ -15,12 +15,13 @@ class RTRDet(nn.Module):
|
|
|
aux_loss :bool = False,
|
|
|
deploy :bool = False):
|
|
|
super(RTRDet, self).__init__()
|
|
|
- assert cfg['max_stride'] == 16 or cfg['max_stride'] == 32
|
|
|
+ assert cfg['out_stride'] == 16 or cfg['out_stride'] == 32
|
|
|
# ------------------ Basic parameters ------------------
|
|
|
self.cfg = cfg
|
|
|
self.device = device
|
|
|
+ self.out_stride = cfg['out_stride']
|
|
|
self.max_stride = cfg['max_stride']
|
|
|
- self.num_levels = 2 if cfg['max_stride'] == 16 else 1
|
|
|
+ self.num_levels = 2 if cfg['out_stride'] == 16 else 1
|
|
|
self.num_topk = cfg['num_topk']
|
|
|
self.num_classes = num_classes
|
|
|
self.d_model = round(cfg['d_model'] * cfg['width'])
|