I'm trying to plot a simple discrete distribution using matplotlib:
You may use one of the following NumPy functions.
numpy.piecewise
numpy.piecewise lets you define a function whose value depends on a set of conditions. Here there are three conditions, [x<0, x>=1, (x>=0) & (x<1)], and you define one function to use for each of them.
import matplotlib.pyplot as plt
import numpy as np
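# Constant functions; "+ x * 0" makes each one return an array shaped like its input
l1 = lambda x: 0.3 + x * 0
l2 = lambda x: 0.2 + x * 0
l3 = lambda x: 0.5 + x * 0
# Each function is applied wherever its matching condition is True
mapDiscProb = lambda x: np.piecewise(x, [x<0, x>=1, (x>=0) & (x<1)], [l1, l2, l3])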
x = np.linspace(-1, 2)
y = mapDiscProb(x)
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
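As a side note, np.piecewise also accepts plain scalars in its function list (each scalar is treated as a constant function), and if the list has one more entry than there are conditions, that last entry is used wherever no condition is True. The sketch below (plotting omitted; the names y_scalars and y_default are just for illustration) checks that both shortcuts give the same values as the lambda version above:

```python
import numpy as np

x = np.linspace(-1, 2)

# Scalars in the function list are treated as constant functions.
y_scalars = np.piecewise(x, [x < 0, x >= 1, (x >= 0) & (x < 1)], [0.3, 0.2, 0.5])

# One extra entry acts as the default, used where neither condition
# holds (here that is 0 <= x < 1).
y_default = np.piecewise(x, [x < 0, x >= 1], [0.3, 0.2, 0.5])

print(np.array_equal(y_scalars, y_default))  # True
```

This keeps x as a float array on purpose: np.piecewise returns an array with the same dtype as its input, so an integer x would truncate 0.3, 0.5 and 0.2.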
numpy.vectorize
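numpy.vectorize wraps a function that is meant to be called with scalars so that it is evaluated for each element of an array. This lets you use if/else statements as you would expect. Note that it is provided for convenience rather than speed; internally it is essentially a Python loop.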
import matplotlib.pyplot as plt
import numpy as np
def mapDiscProb(x):
    if x < 0:
        return 0.3
    elif x >= 1:
        return 0.2
    else:
        return 0.5
x = np.linspace(-1, 2)
y = np.vectorize(mapDiscProb)(x)
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
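One pitfall worth knowing: np.vectorize infers the output dtype from the result of the first element unless you pass otypes. The sketch below uses a hypothetical variant of mapDiscProb, called step here, that returns the integer 0 for x < 0, to show how the other values can be silently truncated and how otypes=[float] avoids it:

```python
import numpy as np

def step(x):
    # Hypothetical variant of mapDiscProb that returns the int 0 for x < 0
    if x < 0:
        return 0
    elif x >= 1:
        return 0.2
    else:
        return 0.5

x = np.array([-1.0, 0.5, 1.5])

# Without otypes, the dtype comes from the first call, step(-1.0) -> 0 (an int),
# so 0.5 and 0.2 are truncated to 0.
y_inferred = np.vectorize(step)(x)

# Declaring the output type up front keeps the float values.
y_float = np.vectorize(step, otypes=[float])(x)

print(y_inferred.dtype, y_inferred.tolist())
print(y_float.tolist())
```

The original mapDiscProb only returns floats, so it is not affected, but passing otypes costs nothing and makes the intent explicit.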
numpy.select
(Credit to PaulH for this idea.) numpy.select picks values from a list of choices based on a list of conditions: for each element, the value belonging to the first condition that is True is used. For piecewise constant functions this is an easy tool, because it does not require building any additional functions; it is a one-liner.
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(-1, 2)
# The first true condition wins, so x<1 here means 0 <= x < 1; x>=1 also covers x == 1
y = np.select([x<0, x<1, x>=1], [0.3, 0.5, 0.2])
fig, ax = plt.subplots()
ax.plot(x, y, clip_on = False)
plt.show()
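To see why the conditions should cover every possible value of x, here is a small check on a few hand-picked points (chosen for illustration, including x = 1 exactly). np.select returns its default, 0 unless you pass another value, wherever no condition holds:

```python
import numpy as np

x = np.array([-0.5, 0.5, 1.0, 1.5])

# x > 1 leaves x == 1 uncovered, so that element falls back to the default 0.
y_gap = np.select([x < 0, x < 1, x > 1], [0.3, 0.5, 0.2])

# x >= 1 closes the gap; x < 1 still only applies to 0 <= x < 1,
# because select uses the first condition that is True for each element.
y_full = np.select([x < 0, x < 1, x >= 1], [0.3, 0.5, 0.2])

print(y_gap.tolist())   # [0.3, 0.5, 0.0, 0.2]
print(y_full.tolist())  # [0.3, 0.5, 0.2, 0.2]
```

With np.linspace(-1, 2) the value 1 is never sampled exactly, so the gap does not show up in the plot, but it would with other sample points.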
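Output in all cases (image not included): a step-shaped line at y = 0.3 for x < 0, y = 0.5 for 0 <= x < 1 and y = 0.2 for x >= 1, with near-vertical segments joining the steps.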
If you don't want these vertical connecting lines to appear, you can draw one separate line per condition.
import matplotlib.pyplot as plt
import numpy as np
l1 = lambda x: 0.3 + x * 0
l2 = lambda x: 0.2 + x * 0
l3 = lambda x: 0.5 + x * 0
x = np.linspace(-1, 2)
func = [l1,l2,l3]
cond = [x<0, x>=1, (x>=0) & (x<1)]
fig, ax = plt.subplots()
for f, c in zip(func, cond):
    xi = x[c]  # only the x values covered by this condition
    ax.plot(xi, f(xi), color="C0")  # same color for every piece
plt.show()
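The loop above draws every sample exactly once only if the conditions partition x, i.e. each value satisfies exactly one condition. A quick NumPy sketch of that check (no plotting; hits is a name chosen for illustration):

```python
import numpy as np

x = np.linspace(-1, 2)
cond = [x < 0, x >= 1, (x >= 0) & (x < 1)]

# Stack the boolean masks and count how many conditions hold for each x.
hits = np.sum(cond, axis=0)

# 1 everywhere means no x is skipped (0) or drawn twice (2 or more).
print(np.all(hits == 1))  # True
```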
Alternatively, using numpy.select, you may modify the x array so that it is guaranteed to include the values 0 and 1, which lie on the boundary between conditions. If you choose conditions that explicitly exclude those values, [x<0, (x>0) & (x<1), x>1] (note the absence of any equals sign), those points satisfy no condition and receive the default value, which is set to nan. Matplotlib does not draw NaN values, so a gap appears at each boundary.
import matplotlib.pyplot as plt
import numpy as np
x = np.sort(np.append(np.linspace(-1, 2),[0,1]))
y = np.select([x<0, (x>0) & (x<1), x>1], [0.3, 0.5, 0.2], np.nan)
fig, ax = plt.subplots()
ax.plot(x, y)
plt.show()
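To confirm that the gaps come from exactly the two boundary points, you can look up where y is NaN (again without plotting; gap_x is a name chosen for illustration):

```python
import numpy as np

# Same construction as above: 0 and 1 are added to the sample points.
x = np.sort(np.append(np.linspace(-1, 2), [0, 1]))
y = np.select([x < 0, (x > 0) & (x < 1), x > 1], [0.3, 0.5, 0.2], np.nan)

# x values where no condition held, so the default NaN was used.
gap_x = x[np.isnan(y)]
print(gap_x)  # [0. 1.]
```

This relies on np.linspace(-1, 2) not already containing 0 or 1 exactly; with other sample points np.append could introduce duplicates, which would still be NaN and still produce a gap.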