The behavior of the numpy rollaxis function confuses me. The documentation says:
Roll the specified axis backwards, until it lies in a given position.
a = np.arange(1*2*3*4*5).reshape(1,2,3,4,5)
np.rollaxis(a,axis,start)
'axis' is the index (starting from 0) of the axis to be moved. In my example the axis at position 0 has length 1.
'start' is the index (again starting from 0) of the axis that the selected axis will be placed before.
So, if start=2, the axis at position 2 has length 3, therefore the selected axis will end up just before the 3.
Examples:
>>> np.rollaxis(a,0,2).shape # the 1 will be before the 3.
(2, 1, 3, 4, 5)
>>> np.rollaxis(a,0,3).shape # the 1 will be before the 4.
(2, 3, 1, 4, 5)
>>> np.rollaxis(a,1,2).shape # the 2 will be before the 3.
(1, 2, 3, 4, 5)
>>> np.rollaxis(a,1,3).shape # the 2 will be before the 4.
(1, 3, 2, 4, 5)
So, after the roll, the length that sat at index axis in the original shape is placed just before the length that sat at index start in the original shape.
If you think of rollaxis like this it is very simple and makes perfect sense, though it's strange that they chose to design it this way.
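The 'placed before' reading can be checked mechanically. Here is a minimal sketch, assuming the same 5-D array as above (all axis lengths are distinct, so a length identifies its axis). The helper name lands_before is made up for illustration and is not part of NumPy.

```python
import numpy as np

a = np.arange(1*2*3*4*5).reshape(1, 2, 3, 4, 5)

def lands_before(a, axis, start):
    # Hypothetical helper: True if, after the roll, the length a.shape[axis]
    # sits immediately before the length a.shape[start].
    # Only meaningful when all axis lengths are distinct and start < a.ndim.
    rolled = np.rollaxis(a, axis, start).shape
    return rolled.index(a.shape[axis]) + 1 == rolled.index(a.shape[start])

# Check every pair with axis != start (start == axis is the no-op case).
checks = [lands_before(a, axis, start)
          for axis in range(a.ndim)
          for start in range(a.ndim)
          if axis != start]
print(all(checks))
```

The check relies on distinct axis lengths; with repeated lengths, shape.index() would no longer identify the axis.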
So, what happens when axis and start are the same? Well, you obviously can't put a number before itself, so the number doesn't move and the instruction becomes a no-op.
Examples:
>>> np.rollaxis(a,1,1).shape # the 2 can't be moved to before the 2.
(1, 2, 3, 4, 5)
>>> np.rollaxis(a,2, 2).shape # the 3 can't be moved to before the 3.
(1, 2, 3, 4, 5)
How about moving the axis to the end? Well, there's no number after the end, but you can specify start as after the end.
Examples:
>>> np.rollaxis(a,1,5).shape # the 2 will be moved to the end.
(1, 3, 4, 5, 2)
>>> np.rollaxis(a,2,5).shape # the 3 will be moved to the end.
(1, 2, 4, 5, 3)
>>> np.rollaxis(a,4,5).shape # the 5 is already at the end.
(1, 2, 3, 4, 5)
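The no-op and move-to-the-end cases can be swept over every axis. This is a small sketch on the same 5-D array; it also covers start == axis + 1, which leaves the shape unchanged just like the np.rollaxis(a,1,2) example earlier.

```python
import numpy as np

a = np.arange(1*2*3*4*5).reshape(1, 2, 3, 4, 5)

# start == axis and start == axis + 1 both leave the shape unchanged.
noop = all(np.rollaxis(a, axis, start).shape == a.shape
           for axis in range(a.ndim)
           for start in (axis, axis + 1))
print(noop)

# start == a.ndim ("after the end") moves the selected axis to the last slot.
end_shapes = [np.rollaxis(a, axis, a.ndim).shape for axis in range(a.ndim)]
for shape in end_shapes:
    print(shape)
```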
Much of the confusion results from our human intuition - how we think about moving an axis. We could specify a number of roll steps (back or forth 2 steps), or a location in the final shape tuple, or location relative to the original shape.
I think the key to understanding rollaxis is to focus on the slots in the original shape. The most general statement that I can come up with is:
Roll a.shape[axis] to the position before a.shape[start]
'before' in this context means the same as in list insert(). So it is possible to insert before the end.
The basic action of rollaxis is:
axes = list(range(0, n))
axes.remove(axis)
axes.insert(start, axis)
return a.transpose(axes)
If axis<start, then start-=1 to account for the remove action.
Negative values get +=n, so for a 4-dimensional a (n=4), rollaxis(a,-2,-3) is the same as np.rollaxis(a,2,1), e.g. a.shape[-3]==a.shape[1]. List insert also allows a negative insert position, but rollaxis doesn't make use of that feature.
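Putting the pieces described above together gives something like the following. It is a sketch of the described logic, not NumPy's actual source: the name rollaxis_sketch is made up, and the handling of out-of-range values is left out entirely.

```python
import numpy as np

def rollaxis_sketch(a, axis, start=0):
    # Hypothetical re-implementation of the steps described in the text.
    n = a.ndim
    # Negative values are shifted by n, as with ordinary negative indexing.
    if axis < 0:
        axis += n
    if start < 0:
        start += n
    # Removing `axis` shifts every later slot down by one.
    if axis < start:
        start -= 1
    axes = list(range(n))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)

a = np.arange(3*4*5*6).reshape(3, 4, 5, 6)

print(rollaxis_sketch(a, -2, -3).shape)  # (3, 5, 4, 6)
print(np.rollaxis(a, 2, 1).shape)        # (3, 5, 4, 6)
```

On this 4-D array the sketch agrees with np.rollaxis for every axis in range(-4, 4) and every start in range(-4, 5).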
So the keys are understanding that remove/insert pair of actions, and understanding transpose(x).
I suspect rollaxis is intended to be a more intuitive version of transpose. Whether it achieves that or not is another question.
You suggest either omitting the start-=1 or applying it across the board.
Omitting it doesn't change your 2 examples. It only affects the rollaxis(a,1,4) case, and axes.insert(4,1) is the same as axes.insert(3,1) when axes is [0,2,3]. The 1 is still placed at the end. Changing that test a bit (a has shape (3, 4, 5, 6) here):
np.rollaxis(a,1,3).shape
# (3, 5, 4, 6) # a.shape[1](4) placed before a.shape[3](6)
without the -=1
# transpose axes == [0, 2, 3, 1]
# (3, 5, 6, 4) # the 4 is placed at the end, after 6
If instead -=1 applies always
np.rollaxis(a,3,1).shape
# (3, 6, 4, 5)
becomes
# (6, 3, 4, 5)
Now the 6 is before the 3, which was the original a.shape[0]. After the roll, 3 is a.shape[1]. But that's a different roll specification.
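The three rules can be put side by side in a small sketch. The function roll_variant and its adjust argument are made up for this comparison and are not part of NumPy; only non-negative, in-range inputs are considered.

```python
import numpy as np

def roll_variant(a, axis, start, adjust="numpy"):
    # Hypothetical helper for comparing the three rules.
    # adjust="numpy":  start -= 1 only when axis < start
    # adjust="never":  never decrement start
    # adjust="always": always decrement start
    if adjust == "always" or (adjust == "numpy" and axis < start):
        start -= 1
    axes = list(range(a.ndim))
    axes.remove(axis)
    axes.insert(start, axis)
    return a.transpose(axes)

a = np.arange(3*4*5*6).reshape(3, 4, 5, 6)

print(roll_variant(a, 1, 3, "numpy").shape)   # (3, 5, 4, 6)
print(roll_variant(a, 1, 3, "never").shape)   # (3, 5, 6, 4)
print(roll_variant(a, 3, 1, "numpy").shape)   # (3, 6, 4, 5)
print(roll_variant(a, 3, 1, "always").shape)  # (6, 3, 4, 5)
```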
It comes down to how start is defined. Is it a position in the original order, or a position in the returned order?
If you prefer to think of start as an index position in the final shape, wouldn't it be simpler to drop the 'before' part and just say 'move axis to dest slot'?
myroll(a, axis=3, dest=0) => np.transpose(a,[3,0,1,2])
myroll(a, axis=1, dest=3) => np.transpose(a,[0,2,3,1])
Simply dropping the -=1 test might do the trick (omitting the handling of negative numbers and boundaries):
def myroll(a, axis, dest):
    x = list(range(a.ndim))
    x.remove(axis)
    x.insert(dest, axis)
    return a.transpose(x)
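A quick check of myroll against the two transposes listed above, assuming a 4-D test array of shape (3, 4, 5, 6) and only non-negative, in-range arguments:

```python
import numpy as np

def myroll(a, axis, dest):
    x = list(range(a.ndim))
    x.remove(axis)
    x.insert(dest, axis)
    return a.transpose(x)

a = np.arange(3*4*5*6).reshape(3, 4, 5, 6)

# The moved axis always ends up at index dest of the result.
lands_at_dest = all(myroll(a, axis, dest).shape[dest] == a.shape[axis]
                    for axis in range(a.ndim)
                    for dest in range(a.ndim))
print(lands_at_dest)

print(myroll(a, axis=3, dest=0).shape)  # (6, 3, 4, 5)
print(myroll(a, axis=1, dest=3).shape)  # (3, 5, 6, 4)
```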
NumPy v1.11 and newer includes a new function, moveaxis, that I recommend using instead of rollaxis (disclaimer: I wrote it!). The source axis always ends up at the destination, without any funny off-by-one issues depending on whether start is greater or less than end:
import numpy as np
x = np.zeros((1, 2, 3, 4, 5))
for i in range(5):
    print(np.moveaxis(x, 3, i).shape)
Results in:
(4, 1, 2, 3, 5)
(1, 4, 2, 3, 5)
(1, 2, 4, 3, 5)
(1, 2, 3, 4, 5)
(1, 2, 3, 5, 4)
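For comparison, here is a sketch that runs the same loop through both functions. The two agree while the destination is at or before the source axis and diverge once it is past it, which is where the off-by-one shows up.

```python
import numpy as np

x = np.zeros((1, 2, 3, 4, 5))

# Each row: (i, moveaxis shape, rollaxis shape)
rows = [(i, np.moveaxis(x, 3, i).shape, np.rollaxis(x, 3, i).shape)
        for i in range(5)]
for i, moved, rolled in rows:
    print(i, moved, rolled, moved == rolled)

# rollaxis needs start=5 to get what moveaxis gives for destination 4.
print(np.rollaxis(x, 3, 5).shape)  # (1, 2, 3, 5, 4)
```

For i=4, moveaxis returns (1, 2, 3, 5, 4) while rollaxis leaves the shape at (1, 2, 3, 4, 5).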