|
|
@@ -1322,16 +1322,18 @@ class RTDetrTrainer(object):
|
|
|
targets = self.box_xyxy_to_cxcywh(targets)
|
|
|
|
|
|
# Inference
|
|
|
- with torch.cuda.amp.autocast(enabled=self.args.fp16):
|
|
|
- outputs = model(images, targets)
|
|
|
- # Compute loss
|
|
|
+ with torch.autocast(device_type=str(self.device), cache_enabled=True):
|
|
|
+ outputs = model(images, targets)
|
|
|
+
|
|
|
+ # Compute loss
|
|
|
+ with torch.autocast(device_type=str(self.device), enabled=False):
|
|
|
loss_dict = self.criterion(outputs, targets)
|
|
|
- losses = sum(loss_dict.values())
|
|
|
- # Grad Accumulate
|
|
|
- if self.grad_accumulate > 1:
|
|
|
- losses /= self.grad_accumulate
|
|
|
+ losses = sum(loss_dict.values())
|
|
|
+ # Grad Accumulate
|
|
|
+ if self.grad_accumulate > 1:
|
|
|
+ losses /= self.grad_accumulate
|
|
|
|
|
|
- loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
|
|
|
+ loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
|
|
|
|
|
|
# Backward
|
|
|
self.scaler.scale(losses).backward()
|