Update yolo_layer.py to work like multibox_loss_gmm.py

17 views Asked by At

I would like yolo_layer.py to work the way multibox_loss_gmm.py does. It needs a lot of major changes, but I'm lost. Please update yolo_layer.py so that it includes the GMM (Gaussian mixture model) loss. This is my yolo_layer.py:

import torch
import torch.nn as nn
import numpy as np
from utils.utils import bboxes_iou


class YOLOLayer(nn.Module):
    """
    Detection layer corresponding to yolo_layer.c of darknet.

    A 1x1 convolution predicts, for each anchor of this layer's mask and for
    every grid cell, ``5 + n_classes`` values: box offsets (x, y, w, h), an
    objectness score and one score per class. Called without labels it
    returns decoded boxes; called with labels it returns the YOLOv3 losses.
    """

    def __init__(self, config_model, layer_no, in_ch, ignore_thre=0.7):
        """
        Args:
            config_model (dict): needs ``'ANCHORS'`` (sequence of (w, h)
                pairs, presumably in input-image pixels since they are
                divided by the stride), ``'ANCH_MASK'`` (one list of anchor
                indices per output layer) and ``'N_CLASSES'``.
            layer_no (int): output-layer index (0, 1 or 2); selects the
                stride and this layer's anchor mask.
            in_ch (int): number of channels of the incoming feature map.
            ignore_thre (float): predictions whose best IoU with a ground
                truth box exceeds this value get no objectness penalty.
        """
        super(YOLOLayer, self).__init__()
        # Input pixels per grid cell for each of the three output layers.
        strides = [32, 16, 8]
        self.anchors = config_model['ANCHORS']
        self.anch_mask = config_model['ANCH_MASK'][layer_no]
        self.n_anchors = len(self.anch_mask)
        self.n_classes = config_model['N_CLASSES']
        self.ignore_thre = ignore_thre
        # Summed (not averaged) losses; `size_average=False` is deprecated.
        self.l2_loss = nn.MSELoss(reduction='sum')
        self.bce_loss = nn.BCELoss(reduction='sum')
        self.stride = strides[layer_no]
        # Anchor sizes expressed in grid-cell units.
        self.all_anchors_grid = [(w / self.stride, h / self.stride)
                                 for w, h in self.anchors]
        self.masked_anchors = [self.all_anchors_grid[i]
                               for i in self.anch_mask]
        # Reference boxes (0, 0, w, h) for every anchor, so that the IoU with
        # a truth box laid out the same way compares shapes only.
        self.ref_anchors = np.zeros((len(self.all_anchors_grid), 4))
        self.ref_anchors[:, 2:] = np.array(self.all_anchors_grid)
        self.ref_anchors = torch.FloatTensor(self.ref_anchors)
        # Global anchor index -> slot in this layer's mask (-1 when the
        # anchor belongs to another layer). Replaces the hard-coded `% 3`,
        # so masks of any length / ordering work.
        self._anchor_slot = np.full(len(self.all_anchors_grid), -1,
                                    dtype=np.int64)
        self._anchor_slot[list(self.anch_mask)] = np.arange(self.n_anchors)
        self.conv = nn.Conv2d(in_channels=in_ch,
                              out_channels=self.n_anchors * (self.n_classes + 5),
                              kernel_size=1, stride=1, padding=0)

    def forward(self, xin, labels=None):
        """
        Args:
            xin (Tensor): feature map (batch, in_ch, fsize, fsize); the grid
                is assumed square (only ``shape[2]`` is read).
            labels (Tensor, optional): ground truth (batch, max_labels, 5)
                with rows (class, x, y, w, h); coordinates are presumably
                normalised to [0, 1] (they are multiplied by ``fsize``).
                Valid rows are assumed to come first, followed by all-zero
                padding rows.

        Returns:
            Without labels: detached tensor (batch, n_anchors*fsize*fsize,
            5 + n_classes) holding (x, y, w, h) in input pixels, objectness
            and class scores.
            With labels: ``(loss, loss_xy, loss_wh, loss_obj, loss_cls,
            loss_l2)``.
        """
        output = self.conv(xin)

        batchsize = output.shape[0]
        fsize = output.shape[2]
        n_ch = 5 + self.n_classes  # x, y, w, h, objectness + class scores
        dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor

        output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
        # -> [batch, anchor, grid_y, grid_x, channels_per_anchor].
        # contiguous() is required: clone() keeps the permuted strides, and
        # the .view() calls below would then fail.
        output = output.permute(0, 1, 3, 4, 2).contiguous()

        # logistic activation for xy, obj, cls
        output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
            output[..., np.r_[:2, 4:n_ch]])

        # decode predictions (grid units): cell index + sigmoid offset for
        # the centre, anchor size * exp(raw) for width / height
        grid = torch.arange(fsize, dtype=output.dtype, device=output.device)
        anchors = torch.tensor(self.masked_anchors, dtype=output.dtype,
                               device=output.device)

        pred = output.clone()
        pred[..., 0] += grid.view(1, 1, 1, fsize)
        pred[..., 1] += grid.view(1, 1, fsize, 1)
        pred[..., 2] = torch.exp(pred[..., 2]) * anchors[:, 0].view(1, -1, 1, 1)
        pred[..., 3] = torch.exp(pred[..., 3]) * anchors[:, 1].view(1, -1, 1, 1)

        if labels is None:  # inference: boxes in input-image pixels
            pred[..., :4] *= self.stride
            return pred.view(batchsize, -1, n_ch).data

        pred = pred[..., :4].data  # [batch, anchor, grid_y, grid_x, 4]

        # target assignment
        tgt_mask = torch.zeros(batchsize, self.n_anchors,
                               fsize, fsize, 4 + self.n_classes).type(dtype)
        obj_mask = torch.ones(batchsize, self.n_anchors,
                              fsize, fsize).type(dtype)
        tgt_scale = torch.zeros(batchsize, self.n_anchors,
                                fsize, fsize, 2).type(dtype)

        target = torch.zeros(batchsize, self.n_anchors,
                             fsize, fsize, n_ch).type(dtype)

        labels = labels.cpu().data
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        truth_x_all = labels[:, :, 1] * fsize
        truth_y_all = labels[:, :, 2] * fsize
        truth_w_all = labels[:, :, 3] * fsize
        truth_h_all = labels[:, :, 4] * fsize
        truth_i_all = truth_x_all.to(torch.int16).numpy()
        truth_j_all = truth_y_all.to(torch.int16).numpy()
        # This layer's anchor sizes on the CPU, built once (not per object).
        masked_anchors_cpu = torch.tensor(self.masked_anchors,
                                          dtype=torch.float32)

        for b in range(batchsize):
            n = int(nlabel[b])
            if n == 0:
                continue
            truth_box = torch.zeros(n, 4).type(dtype)
            truth_box[:n, 2] = truth_w_all[b, :n]
            truth_box[:n, 3] = truth_h_all[b, :n]
            truth_i = truth_i_all[b, :n]
            truth_j = truth_j_all[b, :n]

            # best anchor (over all layers) for each truth, by shape-only IoU
            anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors)
            best_n_all = np.asarray(anchor_ious_all).argmax(axis=1)
            best_n = self._anchor_slot[best_n_all]
            best_n_mask = best_n >= 0  # best anchor belongs to this layer

            truth_box[:n, 0] = truth_x_all[b, :n]
            truth_box[:n, 1] = truth_y_all[b, :n]

            # predictions already overlapping a truth box are ignored for
            # the objectness loss
            pred_ious = bboxes_iou(
                pred[b].view(-1, 4), truth_box, xyxy=False)
            pred_best_iou, _ = pred_ious.max(dim=1)
            pred_best_iou = (pred_best_iou > self.ignore_thre)
            pred_best_iou = pred_best_iou.view(pred[b].shape[:3])
            # `1 - bool_tensor` raises in PyTorch >= 1.2; invert with `~`.
            obj_mask[b] = (~pred_best_iou).type(dtype)

            if not best_n_mask.any():
                continue

            for ti in np.flatnonzero(best_n_mask).tolist():
                i, j = int(truth_i[ti]), int(truth_j[ti])
                a = int(best_n[ti])
                obj_mask[b, a, j, i] = 1
                tgt_mask[b, a, j, i, :] = 1
                # sub-cell offset of the centre
                target[b, a, j, i, 0] = truth_x_all[b, ti] - \
                    truth_x_all[b, ti].to(torch.int16).to(torch.float)
                target[b, a, j, i, 1] = truth_y_all[b, ti] - \
                    truth_y_all[b, ti].to(torch.int16).to(torch.float)
                # log-ratio of the truth size to the assigned anchor
                target[b, a, j, i, 2] = torch.log(
                    truth_w_all[b, ti] / masked_anchors_cpu[a, 0] + 1e-16)
                target[b, a, j, i, 3] = torch.log(
                    truth_h_all[b, ti] / masked_anchors_cpu[a, 1] + 1e-16)
                target[b, a, j, i, 4] = 1
                target[b, a, j, i, 5 + int(labels[b, ti, 0])] = 1
                # larger weight for smaller boxes: sqrt(2 - w*h / fsize^2)
                tgt_scale[b, a, j, i, :] = torch.sqrt(
                    2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize)

        # loss calculation: zero out everything that must not contribute
        output[..., 4] *= obj_mask
        output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
        output[..., 2:4] *= tgt_scale

        target[..., 4] *= obj_mask
        target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
        target[..., 2:4] *= tgt_scale

        bceloss = nn.BCELoss(weight=tgt_scale * tgt_scale,
                             reduction='sum')  # weighted BCE loss
        loss_xy = bceloss(output[..., :2], target[..., :2])
        loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2
        loss_obj = self.bce_loss(output[..., 4], target[..., 4])
        loss_cls = self.bce_loss(output[..., 5:], target[..., 5:])
        loss_l2 = self.l2_loss(output, target)

        loss = loss_xy + loss_wh + loss_obj + loss_cls

        return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2

Here is the code for multibox_loss_gmm.py. It implements a Gaussian density function and an NLL_loss function, and uses them to build a Gaussian-mixture localization loss:

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from data import coco as cfg
from ..box_utils import match, log_sum_exp
import math

def Gaussian(y, mu, var):
    """Evaluate a Gaussian-shaped density of ``y`` around ``mu`` element-wise.

    ``var`` divides ``y - mu`` directly, i.e. it acts as the spread
    (standard-deviation-like) parameter. A constant 0.3 is added to it only
    in the normalising denominator (which keeps that factor away from zero
    for non-negative ``var``); the exponent is not stabilised.
    """
    eps = 0.3
    z = (y - mu) / var
    kernel = torch.exp(-(z ** 2) / 2)
    return kernel / math.sqrt(2 * math.pi) / (var + eps)

def NLL_loss(bbox_gt, bbox_pred, bbox_var):
    """Return the element-wise Gaussian likelihood of ``bbox_gt``.

    Despite its name this does not take a negative log: the raw
    ``bbox_var`` is squashed into (0, 1) with a sigmoid and handed to
    ``Gaussian`` together with the predicted mean ``bbox_pred``. Callers
    mix these likelihoods and apply ``-log`` themselves.
    """
    squashed_var = torch.sigmoid(bbox_var)
    return Gaussian(bbox_gt, bbox_pred, squashed_var)

class MultiBoxLoss_GMM(nn.Module):
    """SSD weighted loss with Gaussian-mixture (GMM) localization and
    confidence heads.

    Compute Targets:
        1) Produce confidence target indices by matching ground truth boxes
           with (default) 'priorboxes' whose jaccard index exceeds the
           ``overlap_thresh`` parameter (typically 0.5).
        2) Produce localization targets by 'encoding' variance into offsets
           of ground truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative
           examples that comes with using a large number of default
           bounding boxes (``neg_pos`` negatives per positive, typically 3).

    Localization loss: the 4 offsets of every positive prior are scored by a
    mixture of Gaussians whose weights are the softmax (across components)
    of ``loc_pi``; the loss is ``-log(likelihood + 1e-9) / 2`` summed over
    positives.

    Confidence loss: one logit sample ``mu + sqrt(sigmoid(var)) * N(0, 1)``
    is drawn per component. 'Type-1' weights per-component cross entropies
    by the softmax of ``conf_pi``; any other ``cls_type`` ('Type-2') takes
    ``-log`` of the weighted mixture of per-component softmaxes.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching,
                 bkg_label, neg_mining, neg_pos, neg_overlap, encode_target,
                 use_gpu=True, cls_type='Type-1'):
        super(MultiBoxLoss_GMM, self).__init__()
        self.use_gpu = use_gpu
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap
        self.variance = cfg['variance']
        self.cls_type = cls_type

    @staticmethod
    def _mixture_weights(pis):
        """Softmax per-component weight logits across the components.

        ``pis`` holds one tensor per component, all with the same number of
        elements. Returns a list of flattened weight tensors in the same
        order; element-wise they sum to 1 over the components.
        """
        stacked = torch.stack([p.reshape(-1) for p in pis])
        return list(torch.softmax(stacked, dim=0).unbind(0))

    def _conf_loss(self, confs, weights, targets):
        """Per-prior confidence loss of the mixture of class-logit samples.

        Args:
            confs: one (M, num_classes) logit tensor per component.
            weights: one tensor of M mixture weights per component.
            targets: M class indices.

        Returns:
            Tensor of shape (M, 1).
        """
        targets = targets.view(-1, 1)
        if self.cls_type == 'Type-1':
            # weighted sum of per-component cross entropies
            return sum(
                (log_sum_exp(c) - c.gather(1, targets)) * w.reshape(-1, 1)
                for c, w in zip(confs, weights))
        # Type-2: -log of the weighted mixture of per-component softmaxes
        epsi = 10 ** -9
        mixture = sum(
            F.softmax(c, dim=1) * w.reshape(-1, 1)
            for c, w in zip(confs, weights))
        return -torch.log(mixture + epsi).gather(1, targets)

    def forward(self, predictions, targets):
        """Compute the GMM localization and confidence losses.

        Args:
            predictions (sequence): ``priors`` followed by the mixture heads:
                ``loc_mu_k, loc_var_k, loc_pi_k`` for every component k, then
                ``conf_mu_k, conf_var_k, conf_pi_k`` for every k (the
                reference network emits 4 components, i.e. 25 tensors).
            targets (list of Tensor): one tensor per image whose rows are
                ground-truth boxes, the last column being the class label.

        Returns:
            tuple: ``(loss_l, loss_c)``, each divided by the number of
            positive priors in the batch (clamped to at least 1 so a batch
            without positives yields 0 instead of NaN).

        Raises:
            ValueError: if the heads are not a positive multiple of 6.
        """
        predictions = tuple(predictions)
        priors, heads = predictions[0], predictions[1:]
        if not heads or len(heads) % 6:
            raise ValueError(
                'expected priors followed by 6 tensors per mixture '
                'component, got %d tensors' % len(predictions))
        n_comp = len(heads) // 6
        loc_heads, conf_heads = heads[:3 * n_comp], heads[3 * n_comp:]
        loc_mu, loc_var, loc_pi = loc_heads[0::3], loc_heads[1::3], loc_heads[2::3]
        conf_mu, conf_var, conf_pi = conf_heads[0::3], conf_heads[1::3], conf_heads[2::3]

        num = loc_mu[0].size(0)
        priors = priors[:loc_mu[0].size(1), :]
        num_priors = priors.size(0)
        num_classes = self.num_classes

        # match priors (default boxes) and ground truth boxes; `match` writes
        # its results into loc_t / conf_t for image `idx`
        loc_t = torch.Tensor(num, num_priors, 4)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :-1].data
            labels = targets[idx][:, -1].data
            match(self.threshold, truths, priors.data, self.variance,
                  labels, loc_t, conf_t, idx)
        if self.use_gpu:
            loc_t = loc_t.cuda()
            conf_t = conf_t.cuda()

        pos = conf_t > 0
        num_pos = pos.sum(dim=1, keepdim=True)

        # ---- localization loss (positive priors only) ----
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_mu[0])
        loc_t = loc_t[pos_idx].view(-1, 4)
        loc_w = self._mixture_weights([p[pos_idx] for p in loc_pi])
        likelihood = sum(
            w.view(-1, 4) * NLL_loss(loc_t, mu[pos_idx].view(-1, 4),
                                     var[pos_idx].view(-1, 4))
            for w, mu, var in zip(loc_w, loc_mu, loc_var))
        epsi = 10 ** -9
        balance = 2.0  # balance parameter
        loss_l = (-torch.log(likelihood + epsi) / balance).sum()

        # ---- confidence loss ----
        conf_w = [w.view(pi.size(0), -1)
                  for w, pi in zip(self._mixture_weights(conf_pi), conf_pi)]
        # One reparameterised logit sample per component. randn_like keeps
        # the noise on the logits' device/dtype (plain torch.randn only
        # worked when the default tensor type had been switched to CUDA).
        samples = [mu + torch.sqrt(torch.sigmoid(var)) * torch.randn_like(var)
                   for mu, var in zip(conf_mu, conf_var)]

        # hard negative mining: only used to rank priors, so no graph needed
        with torch.no_grad():
            mining = self._conf_loss(
                [s.view(-1, num_classes) for s in samples],
                conf_w, conf_t).view(num, -1)
            mining[pos] = 0  # filter out positives; rank negatives only
            _, loss_idx = mining.sort(1, descending=True)
            _, idx_rank = loss_idx.sort(1)
            num_neg = torch.clamp(self.negpos_ratio * num_pos,
                                  max=pos.size(1) - 1)
            neg = idx_rank < num_neg.expand_as(idx_rank)

        # confidence loss including positive and mined negative examples
        keep = pos | neg
        keep_idx = keep.unsqueeze(2).expand_as(conf_mu[0])
        loss_c = self._conf_loss(
            [s[keep_idx].view(-1, num_classes) for s in samples],
            [w[keep] for w in conf_w],
            conf_t[keep]).sum()

        N = num_pos.sum().clamp(min=1)
        return loss_l / N, loss_c / N

0

There are 0 answers