Non trivial slice assignments and tensor manipulation #935

@vladimirmujagic

Hello,

I am trying to port Retinaface (face and landmark detection in PyTorch) to Rust, and I was wondering whether you support operations similar to the following:

def decode(loc, priors, variances):
    """Decode locations from predictions using priors to undo
    the encoding we did for offset regression at train time.
    Args:
        loc (tensor): location predictions for loc layers,
            Shape: [num_priors,4]
        priors (tensor): Prior boxes in center-offset form.
            Shape: [num_priors,4].
        variances: (list[float]) Variances of priorboxes
    Return:
        decoded bounding box predictions
    """

    boxes = torch.cat((
        priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
        priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
    boxes[:, :2] -= boxes[:, 2:] / 2
    boxes[:, 2:] += boxes[:, :2]
    return boxes

I couldn't find equivalent functionality in your library for in-place slice operations such as boxes[:, :2] -= boxes[:, 2:] / 2.

This is my current implementation, which compiles but has not yet been tested for correctness and is not optimized:

use ndarray::{concatenate, s, Array2, Axis, ShapeError};

pub fn decode(
    loc: Array2<f32>,
    priors: Array2<f32>,
    variances: &[f32],
) -> Result<Array2<f32>, ShapeError> {
    // Split the priors into centers (columns 0..2) and sizes (columns 2..4).
    let priors_to = priors.slice(s![.., ..2]).to_owned();
    let priors_from = priors.slice(s![.., 2..]).to_owned();

    // Scale the predicted offsets by the variances.
    let loc_to = loc.slice(s![.., ..2]).mapv(|v: f32| v * variances[0]);
    let loc_from = loc.slice(s![.., 2..]).mapv(|v: f32| (v * variances[1]).exp());

    // priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:]
    let a = priors_to + loc_to * &priors_from;
    // priors[:, 2:] * exp(loc[:, 2:] * variances[1])
    let b = priors_from * loc_from;

    // Join along the existing column axis, like torch.cat(..., 1).
    // (In ndarray 0.15 and later, stack! adds a new axis instead.)
    let mut boxes = concatenate(Axis(1), &[a.view(), b.view()])?;

    // boxes[:, :2] -= boxes[:, 2:] / 2
    for i in 0..boxes.nrows() {
        for j in 0..2 {
            let half = boxes[[i, j + 2]] / 2.0;
            boxes[[i, j]] -= half;
        }
    }

    // boxes[:, 2:] += boxes[:, :2], reading the already updated first two columns
    for i in 0..boxes.nrows() {
        for j in 0..2 {
            let corner = boxes[[i, j]];
            boxes[[i, j + 2]] += corner;
        }
    }

    Ok(boxes)
}

Also, is it possible to concatenate multiple tensors, like this:

    landms = torch.cat((priors[:, :2] + pre[:, :2] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 2:4] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 4:6] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 6:8] * variances[0] * priors[:, 2:],
                        priors[:, :2] + pre[:, 8:10] * variances[0] * priors[:, 2:],
                        ), dim=1)

Or do I have to go two columns at a time to produce the same result?
