detr_config.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
# End-to-End Object Detection with Transformers (DETR)
  2. def build_detr_config(args):
  3. if args.model == 'detr_r50':
  4. return Detr_R50_Config()
  5. else:
  6. raise NotImplementedError("No config for model: {}".format(args.model))
  7. class DetrBaseConfig(object):
  8. def __init__(self):
  9. # --------- Backbone ---------
  10. self.backbone = "resnet50"
  11. self.bk_norm = "FrozeBN"
  12. self.res5_dilation = False
  13. self.use_pretrained = True
  14. self.freeze_at = 1
  15. self.max_stride = 32
  16. self.out_stride = 32
  17. # --------- Transformer ---------
  18. self.transformer = "detr_transformer"
  19. self.hidden_dim = 256
  20. self.num_heads = 8
  21. self.feedforward_dim = 2048
  22. self.num_enc_layers = 6
  23. self.num_dec_layers = 6
  24. self.dropout = 0.1
  25. self.tr_act = 'relu'
  26. self.pre_norm = False
  27. # --------- Post-process ---------
  28. self.train_topk = 300
  29. self.train_conf_thresh = 0.05
  30. self.test_topk = 300
  31. self.test_conf_thresh = 0.3
  32. # --------- Label Assignment ---------
  33. self.matcher_hpy = {'cost_class': 1.0,
  34. 'cost_bbox': 5.0,
  35. 'cost_giou': 2.0,
  36. }
  37. # --------- Loss weight ---------
  38. self.loss_cls = 1.0
  39. self.loss_box = 5.0
  40. self.loss_giou = 2.0
  41. # --------- Optimizer ---------
  42. self.optimizer = 'adamw'
  43. self.batch_size_base = 16
  44. self.per_image_lr = 0.0001 / 16
  45. self.bk_lr_ratio = 0.1
  46. self.momentum = None
  47. self.weight_decay = 1e-4
  48. self.clip_max_norm = 0.1
  49. # --------- LR Scheduler ---------
  50. self.lr_scheduler = 'step'
  51. self.warmup = 'linear'
  52. self.warmup_iters = 100
  53. self.warmup_factor = 0.00066667
  54. # --------- Train epoch ---------
  55. self.max_epoch = 500
  56. self.lr_epoch = [400]
  57. self.eval_epoch = 2
  58. # --------- Data process ---------
  59. ## input size
  60. self.train_min_size = [800] # short edge of image
  61. self.train_min_size2 = [400, 500, 600]
  62. self.train_max_size = 1333
  63. self.test_min_size = [800]
  64. self.test_max_size = 1333
  65. self.random_crop_size = [320, 600]
  66. ## Pixel mean & std
  67. self.pixel_mean = [0.485, 0.456, 0.406]
  68. self.pixel_std = [0.229, 0.224, 0.225]
  69. ## Transforms
  70. self.box_format = 'xywh'
  71. self.normalize_coords = True
  72. self.detr_style = True
  73. self.trans_config = None
  74. def print_config(self):
  75. config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
  76. for k, v in config_dict.items():
  77. print("{} : {}".format(k, v))
  78. class Detr_R50_Config(DetrBaseConfig):
  79. def __init__(self) -> None:
  80. super().__init__()
  81. # --------- Backbone ---------
  82. self.backbone = "resnet50"
  83. self.bk_norm = "FrozeBN"
  84. self.res5_dilation = False
  85. self.use_pretrained = True