问题
# Sample data for the question: the outer key 'x' becomes the single DataFrame
# column; the inner keys are (frame, id) tuples -- frames 0-19, with the same
# nine ids in every frame (the second element presumably identifies a point;
# confirm). pd.DataFrame(data_dict) turns these tuple keys into a two-level
# row index, so df.loc[frame] selects the points of one frame.
data_dict = {'x': {(0, 0): 3760.448435678077,
(0, 12): 4851.68102541007,
(0, 2226): 5297.61518907981,
(0, 2479): 4812.134249142693,
(0, 2724): 4756.5295525777465,
(0, 3724): 3760.448435678077,
(0, 4598): 4763.265306122449,
(0, 4599): 5155.102040816327,
(0, 4600): 5191.836734693878,
(1, 0): 3822.238314568112,
(1, 12): 4856.1910324326145,
(1, 2226): 5304.678983022428,
(1, 2479): 4815.435125468252,
(1, 2724): 4761.889691080804,
(1, 3724): 3768.2889580569245,
(1, 4598): 4768.908833716798,
(1, 4599): 5159.900248610219,
(1, 4600): 5198.053973405109,
(2, 0): 3880.710643551325,
(2, 12): 4860.815600807341,
(2, 2226): 5311.463202354179,
(2, 2479): 4818.773880246848,
(2, 2724): 4767.173347353905,
(2, 3724): 3775.726731574509,
(2, 4598): 4774.4490391107765,
(2, 4599): 5164.871993247027,
(2, 4600): 5203.992167369707,
(3, 0): 3936.0658098882136,
(3, 12): 4865.552525110625,
(3, 2226): 5317.975399527148,
(3, 2479): 4822.152977310737,
(3, 2724): 4772.381182203056,
(3, 3724): 3782.7732491325282,
(3, 4598): 4779.888739700325,
(3, 4599): 5170.010331730589,
(3, 4600): 5209.661736027094,
(4, 0): 3988.491290089178,
(4, 12): 4870.399599918841,
(4, 2226): 5324.223126993423,
(4, 2479): 4825.574880492175,
(4, 2724): 4777.513856434266,
(4, 3724): 3789.4400036326792,
(4, 4598): 4785.230752881375,
(4, 4599): 5175.308321064745,
(4, 4600): 5215.073098816687,
(5, 0): 4038.1625164006414,
(5, 12): 4875.354619808369,
(5, 2226): 5330.2139372050915,
(5, 2479): 4829.04205362342,
(5, 2724): 4782.572030853543,
(5, 3724): 3795.7384879766646,
(5, 4598): 4790.477896049872,
(5, 4599): 5180.7590182533295,
(5, 4600): 5220.2366751779045,
(6, 0): 4085.2436834766995,
(6, 12): 4880.415379355583,
(6, 2226): 5335.955382614236,
(6, 2479): 4832.55696053673,
(6, 2724): 4787.5563662668965,
(6, 3724): 3801.6801950661807,
(6, 4598): 4795.632986601749,
(6, 4599): 5186.355480300186,
(6, 4600): 5225.16288455017,
(7, 0): 4129.888499451394,
(7, 12): 4885.5796731368655,
(7, 2226): 5341.4550156729465,
(7, 2479): 4836.122065064363,
(7, 2724): 4792.4675234803335,
(7, 3724): 3807.2766178029274,
(7, 4598): 4800.698841932945,
(7, 4599): 5192.090764209151,
(7, 4600): 5229.8621463729005,
(8, 0): 4172.2408853249335,
(8, 12): 4890.845295728588,
(8, 2226): 5346.720388833307,
(8, 2479): 4839.739831038576,
(8, 2724): 4797.306163299865,
(8, 3724): 3812.539249088603,
(8, 4598): 4805.678279439399,
(8, 4599): 5197.9579269840615,
(8, 4600): 5234.344880085516,
(9, 0): 4212.43562629731,
(9, 12): 4896.210041707129,
(9, 2226): 5351.759054547402,
(9, 2479): 4843.412722291625,
(9, 2724): 4802.072946531498,
(9, 3724): 3817.479581824906,
(9, 4598): 4810.574116517045,
(9, 4599): 5203.950025628757,
(9, 4600): 5238.621505127434,
(10, 0): 4250.598978423163,
(10, 12): 4901.671705648866,
(10, 2226): 5356.578565267323,
(10, 2479): 4847.1432026557695,
(10, 2724): 4806.7685339812415,
(10, 3724): 3822.1091089135375,
(10, 4598): 4815.389170561825,
(10, 4599): 5210.060117147079,
(10, 4600): 5242.702440938076,
(11, 0): 4286.849233720921,
(11, 12): 4907.228082130176,
(11, 2226): 5361.186473445152,
(11, 2479): 4850.933735963267,
(11, 2724): 4811.393586455103,
(11, 3724): 3826.4393232561943,
(11, 4598): 4820.126258969674,
(11, 4599): 5216.281258542863,
(11, 4600): 5246.5981069568625,
(12, 0): 4321.297246645838,
(12, 12): 4912.876965727434,
(12, 2226): 5365.590331532978,
(12, 2479): 4854.786786046375,
(12, 2724): 4815.948764759092,
(12, 3724): 3830.481717754576,
(12, 4598): 4824.788199136532,
(12, 4599): 5222.606506819949,
(12, 4600): 5250.318922623211,
(13, 0): 4354.046924629284,
(13, 12): 4918.6161510170205,
(13, 2226): 5369.797691982883,
(13, 2479): 4858.70481673735,
(13, 2724): 4820.434729699218,
(13, 3724): 3834.247785310383,
(13, 4598): 4829.377808458337,
(13, 4599): 5229.028918982174,
(13, 4600): 5253.875307376542,
(14, 0): 4385.195685194348,
(14, 12): 4924.443432575308,
(14, 2226): 5373.816107246958,
(14, 2479): 4862.690291868448,
(14, 2724): 4824.852142081489,
(14, 3724): 3837.7490188253105,
(14, 4598): 4833.897904331024,
(14, 4599): 5235.541552033379,
(14, 4600): 5257.277680656276,
(15, 0): 4414.834881979362,
(15, 12): 4930.356604978678,
(15, 2226): 5377.653129777288,
(15, 2479): 4866.74567527193,
(15, 2724): 4829.201662711913,
(15, 3724): 3840.9969112010617,
(15, 4598): 4838.351304150532,
(15, 4599): 5242.137462977402,
(15, 4600): 5260.53646190183,
(16, 0): 4443.050201835423,
(16, 12): 4936.353462803505,
(16, 2226): 5381.316312025957,
(16, 2479): 4870.873430780051,
(16, 2724): 4833.483952396497,
(16, 3724): 3844.002955339333,
(16, 4598): 4842.740825312798,
(16, 4599): 5248.809708818081,
(16, 4600): 5263.662070552626,
(17, 0): 4469.92203501027,
(17, 12): 4942.4318006261665,
(17, 2226): 5384.813206445053,
(17, 2479): 4875.07602222507,
(17, 2724): 4837.699671941253,
(17, 3724): 3846.7786441418243,
(17, 4598): 4847.069285213763,
(17, 4599): 5255.551346559254,
(17, 4600): 5266.664926048083,
(18, 0): 4495.525820288381,
(18, 12): 4948.589413023038,
(18, 2226): 5388.151365486662,
(18, 2479): 4879.3559134392435,
(18, 2724): 4841.849482152186,
(18, 3724): 3849.3354705102342,
(18, 4598): 4851.339501249362,
(18, 4599): 5262.355433204761,
(18, 4600): 5269.555447827619,
(19, 0): 4518.893324127626,
(19, 12): 4954.824094570498,
(19, 2226): 5391.338341602872,
(19, 2479): 4883.71556825483,
(19, 2724): 4845.934043835307,
(19, 3724): 3851.6849273462612,
(19, 4598): 4855.554290815534,
(19, 4599): 5269.21502575844,
(19, 4600): 5272.344055330656}}
With the data above I want to make an animated swarm plot with matplotlib and moviepy. However, with the following code every frame adds new points while the old ones are preserved.
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
from matplotlib import pyplot as plt
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage
fps = 10  # frames per second; draw_swarmplot maps time t to frame int(t * fps)
df = pd.DataFrame(data_dict)  # column 'x'; two-level row index (frame, id) from the tuple keys
fig, ax = plt.subplots(1, 1)  # module-level figure/axes that swarm_plot draws on
def swarm_plot(x):
    """Scatter ``x`` with KDE-scaled vertical jitter on the module-level ``ax``.

    Returns the module-level ``fig``. Because ``ax`` is never cleared here (the
    ``ax.clear()`` line below is commented out), every call adds its points on
    top of the points drawn by earlier calls -- the accumulation the question
    describes.
    """
    kde = gaussian_kde(x)
    density = kde(x) # estimate the local density at each datapoint
    # ax.clear()
    jitter = np.random.rand(*x.shape) - .5
    # scale the jitter by the KDE estimate and add it to the centre y-coordinate (1)
    y = 1 + (density * jitter * 1000 * 2)
    ax.scatter(x, y, s = 30, c = 'g')
    # plt.axis('off')
    return fig
def draw_swarmplot(t):
    """Return the rendered frame for time ``t`` (seconds) as an RGB image array."""
    f = int(t * fps)
    # This only binds new *local* names fig/ax: the new figure is never drawn
    # on (swarm_plot uses the module-level ax and returns the module-level fig)
    # and is never closed, so it does not reset anything.
    fig, ax = plt.subplots(1, 1)
    dff = df.loc[f]
    return mplfig_to_npimage(swarm_plot(dff['x']))
# VideoClip calls the lambda with the time t (in seconds) of each frame.
anim = VideoClip(lambda x: draw_swarmplot(x), duration=2)
# NOTE(review): to_videofile is a deprecated alias of write_videofile in moviepy 1.x.
anim.to_videofile('swarmplot.mp4', fps=fps)
As a result, all points accumulate in the animation. I believe this is because the matplotlib `fig` and `ax` objects are used incorrectly. However, in the `draw_swarmplot` function I reset the `fig` and `ax` objects on each iteration. Nevertheless, I still need to initialise `fig` and `ax` outside both functions to avoid an error about the `ax` object. Therefore, my question is: how should `fig` and `ax` be referenced, and what am I missing that keeps my code from working as intended?
回答1:
The scoping of your `fig` and `ax` variables is governed by the rules described in the "Variable Scope" and "Crossing Boundaries" sections of the "Variables and Scope" documentation. Specifically relevant:
When we use the assignment operator (=) inside a function, its default behaviour is to create a new local variable – unless a variable with the same name is already defined in the local scope.
Note that the caveat "unless a variable with the same name is already defined" is in fact limited to local variables. As is clarified further in the example,
a = 0  # module-level (global) a
def my_function():
    # Assigning to `a` makes it a new local variable of my_function;
    # the global `a` above is not modified.
    a = 3
    print(a)
my_function()  # prints 3 (the local a)
print(a)  # prints 0 (the global a is unchanged)
which will output
3
0
This is because
By default, the assignment statement creates variables in the local scope. So the assignment inside the function does not modify the global variable [...]
If you want to modify a global variable from within a function, use the keyword `global`, as the answer from @iliar says.
However this is not advised -
Note that it is usually very bad practice to access global variables from inside functions, and even worse practice to modify them. This makes it difficult to arrange our program into logically encapsulated parts which do not affect each other in unexpected ways. If a function needs to access some external value, we should pass the value into the function as a parameter. [...]
Two alternatives would be
- Implement this as a class
- Pass `fig` and `ax` into `draw_swarmplot()`.
The former
class SwarmPlot:
    """Build the swarm-plot animation and write it to ``swarmplot.mp4``.

    Uses the module-level ``df`` and ``fps`` from the question's script.
    Constructing an instance renders every frame and writes the video file.
    """

    def __init__(self):
        self.fig, self.ax = plt.subplots(1, 1)
        anim = VideoClip(lambda x: self.draw_swarmplot(x, self.fig, self.ax), duration=2)
        # write_videofile is the current name; to_videofile is a deprecated alias.
        anim.write_videofile('swarmplot.mp4', fps=fps)

    def swarm_plot(self, x):
        """Scatter ``x`` on ``self.ax`` with KDE-scaled vertical jitter and return ``self.fig``."""
        kde = gaussian_kde(x)
        density = kde(x)  # estimate the local density at each datapoint
        jitter = np.random.rand(*x.shape) - .5
        # scale the jitter by the KDE estimate and add it to the baseline y = 1
        y = 1 + (density * jitter * 1000 * 2)
        self.ax.scatter(x, y, s=30, c='g')
        return self.fig

    def draw_swarmplot(self, t, fig, ax):
        """Render the frame for time ``t`` on a fresh figure and return it as an RGB array.

        ``fig`` and ``ax`` are accepted for compatibility with the callback in
        ``__init__`` but are not used; drawing always goes to ``self.fig`` /
        ``self.ax``.
        """
        # Replacing self.fig/self.ax gives each frame an empty axes. Close the
        # previous figure first: pyplot keeps every figure made by plt.subplots
        # open until it is closed, so otherwise one figure leaks per frame and
        # matplotlib warns once rcParams['figure.max_open_warning'] is exceeded.
        plt.close(self.fig)
        self.fig, self.ax = plt.subplots(1, 1)
        f = int(t * fps)  # frame number = first level of df's row index
        dff = df.loc[f]
        return mplfig_to_npimage(self.swarm_plot(dff['x']))


S = SwarmPlot()
The latter
def draw_swarmplot(t, fig, ax):
    """Render the swarm plot for time ``t`` (seconds) and return it as an RGB array.

    ``swarm_plot`` always draws on the module-level ``ax`` and returns the
    module-level ``fig``. The ``ax`` passed here must therefore be that same
    module-level axes, which is what the lambda below passes.
    """
    # Rebinding ``fig, ax = plt.subplots(1, 1)`` here would only create new
    # *local* names (leaking one figure per frame), while swarm_plot kept
    # drawing on the module-level axes, so the points would still accumulate.
    # Clearing the shared axes removes the previous frame's points instead.
    ax.clear()
    f = int(t * fps)  # frame number = first level of df's row index
    dff = df.loc[f]
    return mplfig_to_npimage(swarm_plot(dff['x']))

anim = VideoClip(lambda x: draw_swarmplot(x, fig, ax), duration=2)
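For a simple case such as this I might be partial to the latter, but in more complex cases the former might be preferable. (Note: `swarm_plot` always draws on the module-level `ax`, so in the latter version rebinding `fig, ax` inside `draw_swarmplot` has no effect on what gets drawn. The latter only stops the accumulation if the axes passed in, i.e. the module-level `ax`, is cleared between frames, e.g. with `ax.clear()`.) Both appear to correctly generate the desired output: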
Of course, all this could be avoided if you didn't overwrite the figure and axis instances in each iteration, but instead used one of the clearing functions:
- `plt.cla()` to clear the current axis
- `plt.clf()` to clear the current figure
- `fig.clear()` to clear the figure `fig` (equivalent to `plt.clf()` if `fig` is the current figure)
- `ax.clear()` to clear the axis `ax` (equivalent to `plt.cla()` if `ax` is the current axis)
`ax.clear()` or `plt.cla()` may be the most appropriate in this case and would be used as follows:
# One figure and axes for the whole animation; swarm_plot clears and redraws it per frame.
fig, ax = plt.subplots(nrows=1, ncols=1)
def swarm_plot(x):
    """Redraw the module-level ``ax`` with ``x`` as a jittered scatter; return ``fig``.

    The axes is cleared before plotting, so only the current call's points remain.
    """
    # Density of each value under a Gaussian KDE fitted to the same values.
    local_density = gaussian_kde(x)(x)
    # Uniform offsets in [-0.5, 0.5), scaled by local density so that crowded
    # regions spread further vertically around the baseline y = 1.
    offsets = np.random.rand(*x.shape) - .5
    heights = 1 + (local_density * offsets * 1000 * 2)
    ax.clear()
    ax.scatter(x, heights, s=30, c='g')
    return fig
def draw_swarmplot(t):
    """Return the frame for time ``t`` (seconds) rendered as an RGB image array."""
    frame_points = df.loc[int(t * fps)]
    return mplfig_to_npimage(swarm_plot(frame_points['x']))
Which will also produce the output shown above.
回答2:
def draw_swarmplot(t):
    # Quoted from the question (its final line
    # `return mplfig_to_npimage(swarm_plot(dff['x']))` is omitted here).
    f = int(t * fps)
    # Without a `global` declaration this creates new *local* fig/ax; the
    # module-level objects that swarm_plot draws on are left untouched.
    fig, ax = plt.subplots(1, 1)
    dff = df.loc[f]
should be
def draw_swarmplot(t):
    """Render the frame for time ``t`` on a fresh figure and return it as an RGB array.

    ``fig``/``ax`` are declared global so that the new figure is the one
    ``swarm_plot`` draws on (it uses the module-level ``ax`` and returns the
    module-level ``fig``).
    """
    global fig, ax
    f = int(t * fps)
    # Close the previous frame's figure before replacing it: figures created by
    # plt.subplots stay open in pyplot until closed, so otherwise one figure
    # leaks per frame (matplotlib warns past rcParams['figure.max_open_warning']).
    plt.close(fig)
    fig, ax = plt.subplots(1, 1)
    dff = df.loc[f]
    return mplfig_to_npimage(swarm_plot(dff['x']))
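Otherwise it initializes new objects `fig` and `ax` that are local to the `draw_swarmplot` function. In order to assign to global variables you need to declare them as `global`. (Since this creates a new figure for every frame, close the previous one with `plt.close(fig)` before calling `plt.subplots`; otherwise one open figure is leaked per frame.)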
回答3:
The problem with your code is that you create a new figure at each frame with `fig, ax = plt.subplots(1, 1)`, since `draw_swarmplot(t)` is called to create each frame.
To solve this, create the figure only once, outside the function. To keep the points from accumulating, use `ax.clear()` to clear the axis each time a new frame is made.
Since the code is not very long, I grouped everything into one `make_frame(t)` function. I think this makes the code easier to understand, but you can of course split it back into two functions. I also added a few lines in case you want fixed axis limits instead of different ones at each frame. Full code:
import numpy as np
import pandas as pd
from scipy.stats import gaussian_kde
from matplotlib import pyplot as plt
from moviepy.editor import VideoClip
from moviepy.video.io.bindings import mplfig_to_npimage
fps = 10
df = pd.DataFrame(data_dict)
fig, ax = plt.subplots()

# if you want to have fixed axis limits, use these
# (take min/max of the 'x' column directly: calling float() on the
# one-element Series returned by df.min() is deprecated in pandas)
x_min = float(df['x'].min())
x_max = float(df['x'].max())
# for y values, set the limits by eye inspection of the video,
# since y values are randomly drawn at the creation of each frame
y_min = 0
y_max = 2


def make_frame(t):
    """Return the RGB frame for time ``t`` (seconds) of the animation."""
    # select the series of this frame (first level of df's row index)
    i = int(t * fps)
    x = df.loc[i]['x']

    # prepare data to plot
    kde = gaussian_kde(x)
    density = kde(x)  # estimate the local density at each datapoint
    jitter = np.random.rand(*x.shape) - .5
    # scale the jitter by the KDE estimate and add it to the baseline y = 1
    y = 1 + (density * jitter * 1000 * 2)

    # plot: clear first so the previous frame's points do not accumulate
    ax.clear()
    ax.scatter(x, y, s=30, c='g')
    # comment the next two lines if you don't want fixed axis limits
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    return mplfig_to_npimage(fig)


anim = VideoClip(make_frame, duration=2)
# write_videofile is the current name; to_videofile is a deprecated alias
anim.write_videofile('swarmplot.mp4', fps=fps)
# uncomment to display in jupyter notebook
#anim.ipython_display(fps=fps, loop=True, autoplay=True)
来源:https://stackoverflow.com/questions/58787960/how-to-correctly-refer-to-fig-and-ax-with-moviepy-animation