What does `tf.strided_slice()` do?

隐瞒了意图╮ 2020-12-29 05:24

I am wondering what the `tf.strided_slice()` operator actually does.
The documentation says:

To a first order, this operation extracts a slice of siz…

5 answers
  •  伪装坚强ぢ
    2020-12-29 06:20

    The mistake in your reasoning is that you add the lists `strides` and `begin` element by element, which would make the function much less useful. Instead, the op increments the `begin` index one dimension at a time, starting from the last dimension.
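    To make the point concrete, here is a minimal sketch assuming NumPy, positive strides and no masks (`begin_mask`, `end_mask` and friends are ignored). Rather than a single index `begin + strides`, the selected elements are every combination of the per-dimension ranges `range(begin[d], end[d], strides[d])`, with the last dimension varying fastest. `selected_indices` is a hypothetical helper name, and the input array is an arbitrary example.

```python
import itertools

import numpy as np


def selected_indices(begin, end, strides):
    """Hypothetical helper: every index visited, last dimension varying fastest."""
    per_dim = [range(b, e, s) for b, e, s in zip(begin, end, strides)]
    # itertools.product iterates like nested for-loops, so the last
    # dimension advances first and earlier dimensions act as carries.
    return list(itertools.product(*per_dim))


begin, end, strides = [1, 0, 0], [2, 1, 3], [1, 1, 1]
t = np.arange(3 * 2 * 3).reshape(3, 2, 3)  # example input (assumed)

idx = selected_indices(begin, end, strides)
print(idx)  # [(1, 0, 0), (1, 0, 1), (1, 0, 2)]

# Gathering those elements and reshaping matches ordinary
# per-dimension slicing, here t[1:2, 0:1, 0:3].
shape = [len(range(b, e, s)) for b, e, s in zip(begin, end, strides)]
gathered = np.array([t[i] for i in idx]).reshape(shape)
sliced = t[tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))]
assert np.array_equal(gathered, sliced)
```

    Without masks, each dimension therefore behaves like the Python slice `begin[d]:end[d]:strides[d]`, which is why the NumPy comparison holds.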

    Let's work through the first example step by step: `begin = [1, 0, 0]`, `end = [2, 1, 3]`, and all strides are 1. Work your way backwards, starting from the last dimension.

    Start with element [1,0,0]. Increase only the last dimension by its stride, giving [1,0,1]. Keep doing this until you reach the limit: [1,0,2], then [1,0,3], which equals end[2] and ends the loop (end is exclusive, so [1,0,3] itself is not selected). In the next iteration, increment the second-to-last dimension and reset the last one, giving [1,1,0]. Here the second-to-last dimension equals end[1], so move to the first dimension (third to last) and reset the rest, giving [2,0,0]. Now the first dimension is at its limit too, so quit the loop. The selected elements are therefore [1,0,0], [1,0,1] and [1,0,2].
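    The procedure just described works like an odometer: bump the last dimension by its stride, and when it reaches `end`, reset it to `begin` and carry into the dimension to its left. A minimal iterative sketch of that loop, assuming positive strides and no masks; `odometer_indices` is a hypothetical name.

```python
def odometer_indices(begin, end, strides):
    """Hypothetical helper: yield each selected index, assuming positive strides."""
    # Nothing is selected if any dimension's range is empty.
    if any(b >= e for b, e in zip(begin, end)):
        return
    active = list(begin)
    while True:
        yield list(active)
        dim = len(active) - 1
        # Increment the last dimension; on reaching its end, reset it
        # to begin and carry the increment into the dimension on its left.
        while dim >= 0:
            active[dim] += strides[dim]
            if active[dim] < end[dim]:
                break
            active[dim] = begin[dim]
            dim -= 1
        if dim < 0:
            # The first dimension hit its limit: quit the loop.
            return


steps = list(odometer_indices([1, 0, 0], [2, 1, 3], [1, 1, 1]))
print(steps)  # [[1, 0, 0], [1, 0, 1], [1, 0, 2]]
```

    Note that [1,0,3], [1,1,0] and [2,0,0] are reached internally but never yielded, because each one hits an exclusive `end`.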

    The following code is a recursive implementation of what I described above:

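    import numpy as np
    from copy import copy

    # Example globals (assumed): input, slice spec and preallocated output
    old_matrix = np.arange(3 * 2 * 3).reshape(3, 2, 3)
    begin, end, stride = [1, 0, 0], [2, 1, 3], [1, 1, 1]
    out_shape = [len(range(b, e, s)) for b, e, s in zip(begin, end, stride)]
    new_matrix = np.empty(out_shape, dtype=old_matrix.dtype)

    def iterate(active, dim):
        if dim == len(begin):
            # last dimension incremented, work on the new matrix.
            # `active` and `begin` are lists, so compute the output index
            # element-wise, dividing by the stride
            out = tuple((a - b) // s for a, b, s in zip(active, begin, stride))
            new_matrix[out] = old_matrix[tuple(active)]
        else:
            # walk this dimension, recursing into the next (inner) one
            for i in range(begin[dim], end[dim], stride[dim]):
                new_active = copy(active)
                new_active[dim] = i
                iterate(new_active, dim + 1)

    iterate(begin, 0)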
    
