Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def forward(self, feats):
    """Run ``self.forward_single`` on every entry of ``feats``.

    Args:
        feats: per-level inputs. Presumably these are multi-level feature
            maps; this is not established in this excerpt, so confirm it
            against the caller.

    Returns:
        Exactly what ``multi_apply`` returns for ``self.forward_single``
        applied over ``feats``. ``multi_apply`` is defined outside this
        excerpt. It presumably regroups the per-level outputs into one
        sequence per output; verify against its definition.
    """
    per_level_fn = self.forward_single
    outputs = multi_apply(per_level_fn, feats)
    return outputs
def forward(self, feats):
    """Run ``self.forward_single`` per level with bbox and mask scales.

    ``multi_apply`` is handed ``feats`` together with ``self.scales_bbox``
    and ``self.scales_mask``. Presumably it pairs the i-th element of each
    (one scale of each kind per level), but that helper is defined outside
    this excerpt, so confirm it there.

    Args:
        feats: per-level inputs. Their shape and type are not established
            here.

    Returns:
        The value returned by ``multi_apply`` for these arguments,
        unchanged.
    """
    per_level_fn = self.forward_single
    bbox_scales, mask_scales = self.scales_bbox, self.scales_mask
    return multi_apply(per_level_fn, feats, bbox_scales, mask_scales)
def forward(self, feats):
    """Apply ``self.forward_single`` across ``feats`` through ``multi_apply``.

    Args:
        feats: per-level inputs. Presumably these are feature maps; TODO
            confirm with the caller.

    Returns:
        The unmodified result of ``multi_apply(self.forward_single, feats)``.
    """
    single = self.forward_single
    result = multi_apply(single, feats)
    return result
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
all_bbox_preds = torch.cat([
b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds
], -2)
all_bbox_targets = torch.cat(bbox_targets_list,
-2).view(num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list,
-2).view(num_images, -1, 4)
# check NaN and Inf
assert torch.isfinite(all_cls_scores).all().item(), \
'classification scores become infinite or NaN!'
assert torch.isfinite(all_bbox_preds).all().item(), \
'bbox predications become infinite or NaN!'
losses_cls, losses_bbox = multi_apply(
self.loss_single,
all_cls_scores,
all_bbox_preds,
all_labels,
all_label_weights,
all_bbox_targets,
all_bbox_weights,
num_total_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
def forward(self, feats):
    """Delegate every entry of ``feats`` to ``self.forward_single``.

    The per-entry calls and the regrouping of their results are done by
    ``multi_apply``, which is defined outside this excerpt. Its output is
    returned as-is.

    Args:
        feats: per-level inputs. Their structure is not established here;
            verify it against the caller.

    Returns:
        Whatever ``multi_apply`` produces for these arguments.
    """
    level_fn = self.forward_single
    grouped = multi_apply(level_fn, feats)
    return grouped
gt_bboxes,
img_metas,
self.target_means,
self.target_stds,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
if cls_reg_targets is None:
return None
(labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
num_total_pos, num_total_neg) = cls_reg_targets
num_total_samples = (
num_total_pos + num_total_neg if self.sampling else num_total_pos)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
num_total_samples=num_total_samples,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
gt_bboxes,
img_metas,
cfg.refine,
gt_bboxes_ignore_list=gt_bboxes_ignore,
gt_labels_list=gt_labels,
label_channels=label_channels,
sampling=self.sampling)
(labels_list, label_weights_list, bbox_gt_list_refine,
candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
num_total_neg_refine) = cls_reg_targets_refine
num_total_samples_refine = (
num_total_pos_refine +
num_total_neg_refine if self.sampling else num_total_pos_refine)
# compute loss
losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
self.loss_single,
cls_scores,
pts_coordinate_preds_init,
pts_coordinate_preds_refine,
labels_list,
label_weights_list,
bbox_gt_list_init,
bbox_weights_list_init,
bbox_gt_list_refine,
bbox_weights_list_refine,
self.point_strides,
num_total_samples_init=num_total_samples_init,
num_total_samples_refine=num_total_samples_refine)
loss_dict_all = {
'loss_cls': losses_cls,
'loss_pts_init': losses_pts_init,
s.permute(0, 2, 3, 1).reshape(
num_images, -1, self.cls_out_channels) for s in cls_scores
], 1)
all_labels = torch.cat(labels_list, -1).view(num_images, -1)
all_label_weights = torch.cat(label_weights_list,
-1).view(num_images, -1)
all_bbox_preds = torch.cat([
b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
for b in bbox_preds
], -2)
all_bbox_targets = torch.cat(bbox_targets_list,
-2).view(num_images, -1, 4)
all_bbox_weights = torch.cat(bbox_weights_list,
-2).view(num_images, -1, 4)
losses_cls, losses_bbox = multi_apply(
self.loss_single,
all_cls_scores,
all_bbox_preds,
all_labels,
all_label_weights,
all_bbox_targets,
all_bbox_weights,
num_total_samples=num_total_pos,
cfg=cfg)
return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
cfg,
gt_bboxes_ignore=None):
cls_reg_targets = self.fsaf_target(
cls_scores,
bbox_preds,
gt_bboxes,
gt_labels,
img_metas,
cfg,
gt_bboxes_ignore_list=gt_bboxes_ignore)
(cls_targets_list, reg_targets_list) = cls_reg_targets
level_list = [i for i in range(len(self.feat_strides))]
loss_cls, loss_reg, norm_cls, norm_reg = multi_apply(
self.loss_single,
cls_scores,
bbox_preds,
cls_targets_list,
reg_targets_list,
level_list,
img_metas=img_metas,
cfg=cfg,
gt_bboxes_ignore=None)
loss_cls = sum(loss_cls)/sum(norm_cls)
loss_reg = sum(loss_reg)/sum(norm_reg)
return dict(loss_cls=loss_cls, loss_bbox=loss_reg)