Question
I want to plot multiple seaborn distplots in the same window, where each plot has the same x and y grid. My attempt, shown below, does not work.
# function to plot the density curve of the 200 Median Stn. MC-losses
def make_density(stat_list, color, layer_num):
    num_subplots = len(stat_list)
    ncols = 3
    nrows = (num_subplots + ncols - 1) // ncols
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(ncols * 6, nrows * 5))
    for i in range(len(stat_list)):
        # Plot formatting
        plt.title('Layer ' + layer_num)
        plt.xlabel('Median Stn. MC-Loss')
        plt.ylabel('Density')
        plt.xlim(-0.2, 0.05)
        plt.ylim(0, 85)
        min_ylim, max_ylim = plt.ylim()
        # Draw the density plot.
        sns.distplot(stat_list, hist=True, kde=True,
                     kde_kws={'linewidth': 2}, color=color)
# `stat_list` is a list of 6 lists.
# I want to draw the histogram and density plot of
# each of these 6 lists contained in `stat_list` in a single window,
# with each row containing the histograms and densities of 3 plots,
# so in my example there would be 2 rows of 3 columns of plots (2 x 3 = 6).
stat_list = [[0.3,0.5,0.7,0.3,0.5],[0.2,0.1,0.9,0.7,0.4],[0.9,0.8,0.7,0.6,0.5],
             [0.2,0.6,0.75,0.87,0.91],[0.2,0.3,0.8,0.9,0.3],[0.2,0.3,0.8,0.87,0.92]]
How can I modify my function to draw multiple distplots in the same window, where the x and y grid of each displayed plot is identical?
Thank you,
PS: I would also like all 6 distplots to be the same color, preferably green.
Answer 1:
- The easiest method is to load the data into pandas and then use seaborn.displot.
- .displot replaces .distplot in seaborn version 0.11.0.
- Technically, what you would have wanted before is a FacetGrid mapped with distplot.
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
# data
stat_list = [[0.3,0.5,0.7,0.3,0.5], [0.2,0.1,0.9,0.7,0.4], [0.9,0.8,0.7,0.6,0.5], [0.2,0.6,0.75,0.87,0.91], [0.2,0.3,0.8,0.9,0.3], [0.2,0.3,0.8,0.87,0.92]]
# load the data into pandas and then transpose it for the correct column data
df = pd.DataFrame(stat_list).T
# name the columns; specify a layer number
df.columns = ['A', 'B', 'C', 'D', 'E', 'F']
# now stack the data into a long (tidy) format
dfl = df.stack().reset_index(level=1).rename(columns={'level_1': 'Layer', 0: 'Median Stn. MC-Loss'})
# plot a displot
g = sns.displot(data=dfl, x='Median Stn. MC-Loss', col='Layer', col_wrap=3, kde=True, color='green')
g.set_axis_labels(y_var='Density')
g.set(xlim=(0, 1.0), ylim=(0, 3.0))
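The stack / reset_index / rename chain above reshapes the wide frame (one column per layer) into long form. As a hedged alternative, pandas' DataFrame.melt performs the same reshape in one call, and the resulting frame can be passed to sns.displot in the same way as dfl. The row order differs from the stacked version (grouped by column rather than by row), which does not matter for plotting.

```python
import pandas as pd

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

# One column per layer, one row per observation
df = pd.DataFrame(stat_list).T
df.columns = ['A', 'B', 'C', 'D', 'E', 'F']

# melt turns every column into (Layer, value) rows: the long (tidy) format
dfl = df.melt(var_name='Layer', value_name='Median Stn. MC-Loss')
```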
sns.FacetGrid and sns.distplot
- .distplot is deprecated
p = sns.FacetGrid(data=dfl, col='Layer', col_wrap=3, height=5)
p.map(sns.distplot, 'Median Stn. MC-Loss', bins=5, kde=True, color='green')
p.set(xlim=(0, 1.0))
Answer 2:
A general solution is available as a free library of seventeen matplotlib graphics utilities, plus a user guide, here: https://www.mlbridgeresearch.com/products/free-article-2. I got tired of interrupting my research to write utility software, so I've accumulated libraries that address common needs. The code is well documented, and it works well. The example calls the library's histogram_grid(), which returns the plot grid as a matplotlib Figure. Because histograms generally do not share the same ranges, histogram_grid() does not by default give exactly what you asked for (identical x and y grids), so the adjustments are made to the returned Figure afterward.
import pandas as pd
import matplotlib.pyplot as plt
from statistics_utilities import histogram_grid
stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]
df = pd.DataFrame(stat_list).transpose()
# histogram_grid() accepts only a DataFrame and requires named columns.
df.columns = ['x1', 'x2', 'x3', 'x4', 'x5', 'x6']
# If kde is True, the plot is a density plot no matter how hist_type is set.
hist_type = 'density'
variable_names = df.columns
bins = 3
fig = histogram_grid(df, bins=bins, hist_type=hist_type, kde=True, legend=False,
                     title='test title', variable_names=variable_names,
                     n_gridcolumns=3, height=6, width=10)
fig.subplots_adjust(wspace=.2, left=0.035, right=.95, bottom=.13)
# Adjust the axes on the 2 x 3 grid plot:
# turn off x-axis ticks/labels in the top row and y-axis
# ticks/labels in every column except the 1st.
axes_list = fig.axes # get a list of Axes in Figure
ax_index = 0
modify_xaxes_indexes = [0, 1, 2]
modify_yaxes_indexes = [1, 2, 4, 5]
for ax in axes_list:
    ax.set_xlabel(None)
    ax.set_ylabel(None)
    # normally, the xlim() would be calculated but I can see that
    # .1 <= x <= .92 and similarly the densities are 0 <= y <= 3.
    ax.set_xlim(.05, .95)
    ax.set_ylim(0, 3)
    if ax_index in modify_xaxes_indexes:
        ax.tick_params(
            axis='x',           # changes apply to the x-axis
            which='both',       # both major and minor ticks are affected
            bottom=False,       # ticks along the bottom edge are off
            top=False,          # ticks along the top edge are off
            labelbottom=False)  # labels along the bottom edge are off
    if ax_index in modify_yaxes_indexes:
        ax.tick_params(
            axis='y',           # changes apply to the y-axis
            which='both',       # both major and minor ticks are affected
            left=False,         # ticks along the left edge are off
            right=False,        # ticks along the right edge are off
            labelleft=False)    # labels along the left edge are off
    ax_index += 1
plt.show()
plt.close()
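For readers who prefer not to depend on the third-party statistics_utilities package, a hedged plain-matplotlib sketch covers the same adjustments. It computes the common x range from the data (the step the comment in the loop says would normally be calculated), uses identical bin edges for every panel, and relies on plt.subplots(sharex=True, sharey=True) plus Axes.label_outer to hide inner tick labels instead of calling tick_params per axes index. The 0.05 padding, the 5 bins and the x1..x6 titles are illustrative choices, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy as np

stat_list = [[0.3, 0.5, 0.7, 0.3, 0.5], [0.2, 0.1, 0.9, 0.7, 0.4], [0.9, 0.8, 0.7, 0.6, 0.5],
             [0.2, 0.6, 0.75, 0.87, 0.91], [0.2, 0.3, 0.8, 0.9, 0.3], [0.2, 0.3, 0.8, 0.87, 0.92]]

ncols = 3
nrows = -(-len(stat_list) // ncols)  # ceiling division -> 2 rows for 6 samples

# Common x range computed from all samples, padded slightly on both sides
all_values = np.concatenate([np.asarray(v, dtype=float) for v in stat_list])
pad = 0.05
xlim = (all_values.min() - pad, all_values.max() + pad)
bin_edges = np.linspace(xlim[0], xlim[1], 6)  # the same 5 bins for every panel

fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True, squeeze=False,
                         figsize=(10, 6))
for i, (values, ax) in enumerate(zip(stat_list, axes.flat)):
    # density=True scales each histogram so its area integrates to 1
    ax.hist(values, bins=bin_edges, density=True, color="green")
    ax.set_title(f"x{i + 1}")

axes[0, 0].set_xlim(xlim)  # shared x, so this applies to every panel
for ax in axes.flat:
    ax.label_outer()  # keep tick labels only on the outer edge of the grid
```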
Source: https://stackoverflow.com/questions/63924125/how-to-plot-multiple-seaborn-distplot-in-a-single-figure