Selaa lähdekoodia

use qfl cls loss

yjh0410 2 vuotta sitten
vanhempi
sitoutus
9831826ab0
2 muutettua tiedostoa jossa 2 lisäystä ja 2 poistoa
  1. 1 1
      config/model_config/rtcdet_config.py
  2. 1 1
      train_multi_gpus.sh

+ 1 - 1
config/model_config/rtcdet_config.py

@@ -217,7 +217,7 @@ rtcdet_cfg = {
         'matcher_hpy': {'center_sampling_radius': 2.5,
                         'topk_candidate': 10},
         # ---------------- Loss config ----------------
-        'cls_loss': 'bce',
+        'cls_loss': 'qfl',
         'loss_cls_weight': 1.0,
         'loss_dfl_weight': 1.0,
         'loss_box_weight': 5.0,

+ 1 - 1
train_multi_gpus.sh

@@ -4,7 +4,7 @@ python -m torch.distributed.run --nproc_per_node=8 train.py \
                                                     -dist \
                                                     -d coco \
                                                     --root /data/datasets/ \
-                                                    -m rtcdet_t \
+                                                    -m rtcdet_s \
                                                     -bs 128 \
                                                     -size 640 \
                                                     --wp_epoch 3 \