"""Sample and plot a 2-D Gaussian mixture whose components sit on a ring.

Running this file as a script draws ``num_samples`` points from the mixture
and saves a scatter plot to ``output/<epoch>.png``.
"""
import os

import numpy as np

# Mixture layout: ``num_mixtures`` components evenly spaced on a circle of
# the given radius, all with equal weight and the same diagonal covariance.
num_mixtures = 8
radius = 2.0
# NOTE: despite its name, this value is used as the per-axis *variance*: it
# becomes the diagonal of each covariance matrix in ``gmm_sample``. The
# per-axis standard deviation is therefore sqrt(0.02) ~= 0.141.
std = 0.02
# Evenly spaced angles in [0, 2*pi). This gives exactly the same values as
# ``np.linspace(0, 2 * np.pi, num_mixtures + 1)[:num_mixtures]``.
thetas = np.linspace(0, 2 * np.pi, num_mixtures, endpoint=False)
# sin for x and cos for y places the first component at (0, radius).
xs, ys = radius * np.sin(thetas), radius * np.cos(thetas)
mix_coeffs = tuple([1 / num_mixtures] * num_mixtures)  # uniform weights
mean = tuple(zip(xs, ys))
cov = tuple([(std, std)] * num_mixtures)

# Figure/axes placeholders; rebound by the script entry point below.
ax = None
fig = None
# Only used to name the saved image file.
epoch = 0
num_samples = 1000
OUTPUT_DIR = "output"


def gmm_sample(num_samples, mix_coeffs, mean, cov, rng=None):
    """Draw samples from a Gaussian mixture with diagonal covariances.

    Args:
        num_samples: Total number of points to draw.
        mix_coeffs: One weight per component. These are passed to
            ``multinomial`` to decide how many points each component gets.
        mean: Per-component mean vectors (indexable as ``[i, :]`` once
            converted to an array). Only the first ``len(mix_coeffs)`` rows
            are used.
        cov: Per-component diagonal *variances*, in the same layout as
            ``mean``. Each row becomes ``np.diag(row)``.
        rng: Optional ``numpy.random.Generator`` / ``RandomState``. Defaults
            to the legacy global ``np.random`` state, so omitting it
            reproduces the previous random stream.

    Returns:
        Array of shape ``(num_samples, len(mean[0]))``. Rows are grouped by
        component (all of component 0 first, then component 1, ...). They
        are not shuffled.

    Raises:
        IndexError: if ``mean`` or ``cov`` has fewer rows than
            ``mix_coeffs`` has entries.
    """
    if rng is None:
        rng = np.random
    # Convert once instead of on every loop iteration.
    means = np.asarray(mean, dtype=float)
    variances = np.asarray(cov, dtype=float)
    counts = rng.multinomial(num_samples, mix_coeffs)
    samples = np.zeros(shape=[num_samples, means.shape[1]])
    i_start = 0
    for i, count in enumerate(counts):
        i_end = i_start + count
        samples[i_start:i_end, :] = rng.multivariate_normal(
            mean=means[i, :], cov=np.diag(variances[i, :]), size=count)
        i_start = i_end
    return samples


def disp_scatter(x, fig=None, ax=None):
    """Scatter-plot the first two columns of ``x`` as red '+' markers.

    Args:
        x: 2-D array of points; column 0 is plotted on the x axis and
            column 1 on the y axis.
        fig: Figure that owns ``ax``. It is ignored when ``ax`` is None,
            and inferred from ``ax.figure`` when only ``ax`` is given.
        ax: Axes to draw on. If None, a new figure and axes are created.

    Returns:
        The ``(fig, ax)`` pair that was drawn on.
    """
    if ax is None:
        # Imported lazily so the sampling code does not require matplotlib.
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()
    elif fig is None:
        fig = ax.figure
    ax.scatter(x[:, 0], x[:, 1], s=10, marker='+', color='r', alpha=0.8,
               label='real data')
    ax.legend()
    return fig, ax


if __name__ == "__main__":
    x = gmm_sample(num_samples, mix_coeffs, mean, cov)
    fig, ax = disp_scatter(x, fig=None, ax=None)
    fig.tight_layout()
    # Build the path portably: a hard-coded backslash is an invalid escape
    # and only acts as a separator on Windows. Also create the directory so
    # savefig does not fail when it is missing.
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    fig.savefig(os.path.join(OUTPUT_DIR, "{}.png".format(epoch)))
# These rebind num_mixtures only after every value derived from it above
# (thetas, mix_coeffs, mean, cov) has already been computed. They therefore
# do not change the sampled data or the saved figure. The second assignment
# also immediately overwrites the first.
# NOTE(review): presumably leftover alternative settings from the source
# article (8 components on a ring vs. a single Gaussian). To use them, change
# the num_mixtures definition at the top of the file instead; confirm intent.
num_mixtures = 8
num_mixtures = 1
# Source (来源): https://www.cnblogs.com/gaona666/p/12446784.html