Why is this matplotlib code giving me a weird exception? I'm going for two rows of plots. The top row is supposed to show true vs. pred and the bottom row is supposed to show p
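If `len(X)` is greater than 1, `axes` will be a 2D NumPy array of `AxesSubplot` instances. So when you loop over `axes`, each iteration gives you a slice along one dimension of that array (a 1D array of axes), not a single `AxesSubplot`. Calling `.plot()` on that slice fails, because a NumPy array has no `plot` method.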
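A quick way to see this is to inspect what `plt.subplots` returns and what a plain `for` loop hands you. This is a minimal sketch; the 2×3 grid size and the non-interactive Agg backend are arbitrary choices for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window opens
import matplotlib.pyplot as plt
import numpy as np

# A 2 x 3 grid: two rows, three columns (sizes chosen only for illustration)
fig, axes = plt.subplots(2, 3)
is_array = isinstance(axes, np.ndarray)
print(is_array, axes.shape)  # a NumPy array of shape (2, 3)

# Iterating over a 2D array yields its rows, not individual Axes
rows = [row for row in axes]
for row in rows:
    print(type(row), row.shape)  # each row is a 1D ndarray holding 3 Axes

# A row has no .plot method, which is why ax.plot(...) in such a loop fails
has_plot = hasattr(rows[0], "plot")
print(has_plot)  # False
plt.close(fig)
```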
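To get around this, you could iterate over `axes.flat`, which yields the individual axes one at a time:

```python
for ax, _x in zip(axes.flat, X):
    ax.plot(_x, y, 'b.')  # ax is now a single Axes object
```

Note that `zip` stops at the shorter of its inputs, and `axes.flat` walks the grid row by row, so with two rows of axes this only pairs the top-row axes with `X`.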
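The sketch below checks which axes that loop actually touches. The three input arrays in `X` and the `_x ** 2` values being plotted are hypothetical placeholders, not the question's data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this check
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical inputs: three 1D arrays, one per column
X = [np.arange(5), np.arange(5) * 2, np.arange(5) * 3]

fig, axes = plt.subplots(2, len(X))

# axes.flat walks the 2D array in row-major order, yielding single Axes objects
paired = []
for ax, _x in zip(axes.flat, X):
    ax.plot(_x, _x ** 2, "b.")  # works: ax is one Axes
    paired.append(ax)

# zip stops after len(X) items, so only the top row (axes[0, :]) was used
print(all(a is b for a, b in zip(paired, axes[0])))  # True
plt.close(fig)
```

If you need every subplot, including the bottom row, iterate over `axes.flat` on its own or index the rows explicitly.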
Also, if you want all of these plots on one figure, you don't need to call `plt.subplots` twice: each call creates a new figure.
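You can confirm the two-figure behaviour by counting open figures with `plt.get_fignums()`. A minimal sketch, with an arbitrary 1×3 / 2×3 layout:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this check
import matplotlib.pyplot as plt

plt.close("all")  # start from a clean state

# Two separate calls: each one creates a brand-new figure
fig1, top = plt.subplots(1, 3)
fig2, bottom = plt.subplots(1, 3)
n_separate = len(plt.get_fignums())
print(n_separate)  # 2

plt.close("all")

# One call with nrows=2: both rows live on the same figure
fig, axes = plt.subplots(2, 3)
n_single = len(plt.get_fignums())
print(n_single)  # 1
plt.close("all")
```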
It may be easier to index the `axes` array directly, like this:
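```python
import matplotlib.pyplot as plt

# X (a sequence of input arrays), y (true values) and func come from your code
yy = func(*X)  # predictions

# squeeze=False keeps axes 2D even when len(X) == 1
fig, axes = plt.subplots(2, len(X), squeeze=False)
for i, _x in enumerate(X):
    axes[0, i].plot(_x, y, 'b.')           # top row: true values
    axes[0, i].plot(_x, yy, 'r.')          # top row: predictions
    axes[1, i].plot(_x, yy / y - 1, 'r.')  # bottom row: relative error
plt.show()
```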
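For a version you can run end to end, the sketch below fills in the pieces that come from the question's own code. The model `func`, the two input arrays in `X`, and the 5% noise on `y` are all hypothetical stand-ins, and the titles, legend and zero line are optional extras:

```python
import matplotlib
matplotlib.use("Agg")  # use an interactive backend and plt.show() to view it
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the question's data and model
rng = np.random.default_rng(0)
X = [rng.uniform(1, 10, 50), rng.uniform(1, 10, 50)]  # two input features

def func(x0, x1):
    """Hypothetical model: a simple linear combination of the inputs."""
    return 2.0 * x0 + 0.5 * x1

y = func(*X) * (1 + rng.normal(0, 0.05, 50))  # "true" values with 5% noise
yy = func(*X)  # predictions

fig, axes = plt.subplots(2, len(X), squeeze=False, sharex="col")
for i, _x in enumerate(X):
    axes[0, i].plot(_x, y, "b.", label="true")
    axes[0, i].plot(_x, yy, "r.", label="pred")
    axes[0, i].set_title(f"X[{i}]")
    axes[1, i].plot(_x, yy / y - 1, "r.")  # relative error of the prediction
    axes[1, i].axhline(0, color="k", lw=0.5)  # zero-error reference line

axes[0, 0].legend()
axes[1, 0].set_ylabel("pred / true - 1")
fig.tight_layout()
# fig.savefig("true_vs_pred.png")  # hypothetical file name
```

Passing `squeeze=False` means the `axes[row, col]` indexing also works when `X` has a single element; without it, `axes` would be a 1D array of shape `(2,)` and `axes[0, i]` would raise an `IndexError`.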