Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def __init__(self, args):
    """Build the detector's sub-modules for the backbone named in ``args``.

    Args:
        args: Configuration object. ``args.backbone`` selects the feature
            extractor (``"vgg16_torch"`` or ``"vgg16_longcw"``), and ``args``
            is also forwarded unchanged to ``ProposalLayer``.

    Raises:
        ValueError: If ``args.backbone`` is not one of the supported names.
    """
    # Explicit two-argument form kept: the enclosing class header is not
    # visible from this snippet.
    super(Model, self).__init__()
    print("using backbone", args.backbone)

    if args.backbone == "vgg16_torch":
        self.feature_extractor = CNN()
    elif args.backbone == "vgg16_longcw":
        self.feature_extractor = VGG16()
        # NOTE(review): relative path, so it depends on the process's
        # working directory -- confirm this is intended.
        self.feature_extractor.load_from_npy_file('../input/pretrained_model/VGG_imagenet.npy')
    else:
        # Previously an unknown name fell through and left
        # ``self.feature_extractor`` undefined; fail here instead.
        raise ValueError(
            "unsupported backbone {!r}; expected 'vgg16_torch' or "
            "'vgg16_longcw'".format(args.backbone)
        )

    self.rpn = RPN()
    self.fasterrcnn = FasterRcnn()
    self.proplayer = ProposalLayer(args=args)
    self.roipool = ROIpooling()
def __init__(self, args):
    """Build the detector's sub-modules for the backbone named in ``args``.

    Args:
        args: Configuration object. ``args.backbone`` selects the feature
            extractor (``"vgg16_torch"`` or ``"vgg16_longcw"``), and ``args``
            is also forwarded unchanged to ``ProposalLayer``.

    Raises:
        ValueError: If ``args.backbone`` is not one of the supported names.
    """
    # Explicit two-argument form kept: the enclosing class header is not
    # visible from this snippet.
    super(Model, self).__init__()
    print("using backbone", args.backbone)

    if args.backbone == "vgg16_torch":
        self.feature_extractor = CNN()
    elif args.backbone == "vgg16_longcw":
        self.feature_extractor = VGG16()
        # NOTE(review): relative path, so it depends on the process's
        # working directory -- confirm this is intended.
        self.feature_extractor.load_from_npy_file('../input/pretrained_model/VGG_imagenet.npy')
    else:
        # Previously an unknown name fell through and left
        # ``self.feature_extractor`` undefined; fail here instead.
        raise ValueError(
            "unsupported backbone {!r}; expected 'vgg16_torch' or "
            "'vgg16_longcw'".format(args.backbone)
        )

    self.rpn = RPN()
    self.fasterrcnn = FasterRcnn()
    self.proplayer = ProposalLayer(args=args)
    self.roipool = ROIpooling()
# Smoke test of the detection pipeline: each stage is run once, in order, and
# the shapes of its outputs are printed.
# `features`, `anchor`, `image_info`, `image`, `gt_boxes` and `args` are not
# bound in this excerpt; they must already exist when this code runs.
print("features : {} ".format(features.size()))
# RPN test: run the RPN on the feature map.
rpn = RPN()
rpn_bbox_pred, rpn_cls_prob = rpn(features)
print("rpn_bbox_pred : {}, rpn_cls_prob : {}".format(rpn_bbox_pred.size(), rpn_cls_prob.size())) # torch.Size([1, 36, 62, 37]) torch.Size([1, 18, 62, 37])
# get_anchors test
all_anchors = get_anchors(features, anchor)
print("all_anchors : {}".format(all_anchors.shape))
# proposal layer test
# NOTE(review): ProposalLayer is built here from the RPN outputs, the anchors
# and im_info, whereas Model.__init__ calls ProposalLayer(args=args) only --
# confirm which constructor signature is current.
proplayer = ProposalLayer(rpn_bbox_pred, rpn_cls_prob, all_anchors, im_info=image_info, args=args)
proposals, scores = proplayer.proposal()
print("proposals : {}, scores : {}".format(proposals.shape, scores.shape))
# Proposals are cast to int for printing only; `proposals` itself is unchanged
# (presumably a NumPy array, given .shape/.astype -- TODO confirm).
print(proposals.astype("int"))
# rpn_target test
rpn_labels, rpn_bbox_targets = rpn_targets(all_anchors, image, gt_boxes, args)
print("rpn_labels : {}, bbox_target : {}".format(rpn_labels.shape, rpn_bbox_targets.shape)) # (20646,) (20646, 4)
# The gt_boxes also have to be added, so the targets are computed first
# (i.e. before ROI pooling below).
# frcnn_targets test
frcnn_labels, rois, frcnn_bbox_targets = frcnn_targets(proposals, gt_boxes, args)
print("frcnn_labels : {}, rois : {}, frcnn_bbox_targets : {}".format(frcnn_labels.shape, rois.shape, frcnn_bbox_targets.shape))
# ROIpooling test: pool features for the `rois` returned by frcnn_targets.
roipool = ROIpooling()
rois_features = roipool(features, rois)