Question
For convenience, I am moving my retention heatmap from Seaborn to Plotly so that I can add shapes to it later. The plotly.graph_objects module seems suitable for this, but I don't understand how to pass DataFrame data to it.
import pandas as pd
import numpy as np
import seaborn as sns
import plotly as ply
import matplotlib.pyplot as plt
import plotly.graph_objects as go
df = pd.DataFrame(index=['01.2020', '02.2020', '03.2020', '04.2020', '05.2020', '06.2020'],
                  data={0: [1, 1, 1, 1, 1, 1],
                        1: [0.58, 0.88, 0.27, 0.28, 0.68, 0.90],
                        2: [0.56, 0.58, 0.1, 0.77, 0.68, None],
                        3: [0.78, 0.33, 0.4, 0.79, None, None],
                        4: [0.58, 0.16, 0.89, None, None, None],
                        5: [0.25, 0.14, None, None, None, None],
                        6: [0.69, None, None, None, None, None]})
sns.set(style='white')
plt.figure(figsize=(12, 8))
plt.title('Cohorts: User Retention')
sns.heatmap(df,annot=True, fmt='.0%');
How can I do this in Plotly?
Answer 1:
There is already an existing answer that may help, but it is somewhat outdated: many of the methods it uses are now deprecated. As long as you are fine with changing your scale from 0-1 to 0-100, you could use plotly.figure_factory.create_annotated_heatmap, but as far as I know the figure_factory functions are going to be deprecated soon. The downside is that you have to write the annotations (the text) manually, as follows.
import pandas as pd
import numpy as np
import plotly.graph_objects as go
df = pd.DataFrame(index=['01.2020', '02.2020', '03.2020', '04.2020', '05.2020', '06.2020'],
                  data={0: [1, 1, 1, 1, 1, 1],
                        1: [0.58, 0.88, 0.27, 0.28, 0.68, 0.90],
                        2: [0.56, 0.58, 0.1, 0.77, 0.68, None],
                        3: [0.78, 0.33, 0.4, 0.79, None, None],
                        4: [0.58, 0.16, 0.89, None, None, None],
                        5: [0.25, 0.14, None, None, None, None],
                        6: [0.69, None, None, None, None, None]})
z = df.values
x = df.columns
y = df.index
# Build one text annotation per cell; missing (NaN) cells get an empty label
annotations = []
for n, row in enumerate(z):
    for m, val in enumerate(row):
        annotations.append(
            dict(text="{0:.0%}".format(val) if not np.isnan(val) else '',
                 x=x[m],
                 y=y[n],
                 xref='x1',
                 yref='y1',
                 showarrow=False))
layout = dict(title='Cohorts: User Retention',
              title_x=0.5,
              annotations=annotations,
              yaxis=dict(showgrid=False,
                         tickmode='array',
                         tickvals=np.arange(1, len(y) + 1),
                         ticktext=y),
              xaxis=dict(showgrid=False),
              width=700,
              height=700,
              autosize=False)
trace = go.Heatmap(x=x, y=y, z=z)
fig = go.Figure(data=trace, layout=layout)
fig.show()
Source: https://stackoverflow.com/questions/62470758/retention-heatmap-in-plotly