Visualizing time series data of graph nodes in plotly

北城余情 提交于 2021-02-08 04:38:23

问题


I have a graph created with NetworkX and plotted using Plotly.

Code:

import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import plotly.graph_objects as go

from pprint import pprint
from collections import OrderedDict


def get_edge_trace(G):
    """Build the Plotly line trace that draws every edge of ``G``.

    Each node of ``G`` must carry a ``'pos'`` attribute holding an ``(x, y)``
    pair. For every edge the two endpoint coordinates are emitted followed by
    ``None``, so that consecutive edges are not joined into one polyline.

    Returns:
        A ``go.Scatter`` trace in ``'lines'`` mode with hover disabled.
    """
    xs, ys = [], []
    for src, dst in G.edges():
        src_x, src_y = G.nodes[src]['pos']
        dst_x, dst_y = G.nodes[dst]['pos']
        # The trailing None separates this segment from the next one.
        xs.extend((src_x, dst_x, None))
        ys.extend((src_y, dst_y, None))

    return go.Scatter(
        x=xs,
        y=ys,
        mode='lines',
        hoverinfo='none',
        line=dict(width=0.5, color='#888'),
    )


def get_node_trace(G):
    """Build the Plotly marker trace that draws every node of ``G``.

    Each node of ``G`` must carry a ``'pos'`` attribute holding an ``(x, y)``
    pair; nodes are emitted in the iteration order of ``G.nodes()``.

    Returns:
        A ``go.Scatter`` trace in ``'markers'`` mode. ``marker.color`` is left
        as an empty list, so the caller is expected to fill it in if the
        colour bar should reflect real values.
    """
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = G.nodes[node]['pos']
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x, y=node_y,
        mode='markers',
        hoverinfo='text',
        marker=dict(
            showscale=True,
            # colorscale options
            # 'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            # 'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            # 'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale='YlGnBu',
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                # The nested title dict replaces the deprecated `titleside`
                # attribute, which newer Plotly releases no longer accept.
                title=dict(text='Node Connections', side='right'),
                xanchor='left',
            ),
            line_width=2))

    return node_trace


if __name__ == '__main__':
    # Demo: build a four-node directed chain 1 -> 2 -> 3 -> 4, preview it with
    # matplotlib, then render the same graph with Plotly into plot.html.

    tail = [1, 2, 3]
    head = [2, 3, 4]

    xpos = [0, 1, 2, 3]
    ypos = [0, 0, 0, 0]
    xpos_ypos = list(zip(xpos, ypos))

    ed_ls = list(zip(tail, head))
    # nx.OrderedDiGraph was removed in NetworkX 3.0. A plain DiGraph keeps
    # insertion order because it is backed by ordinary (ordered) dicts.
    G = nx.DiGraph()
    G.add_edges_from(ed_ls)

    # Nodes come back in insertion order (1, 2, 3, 4), so the i-th node
    # receives the i-th coordinate pair.
    pos = OrderedDict(zip(G.nodes, xpos_ypos))
    nx.draw(G, pos=pos, with_labels=True)
    nx.set_node_attributes(G, pos, 'pos')

    plt.show()

    # convert to plotly graph
    edge_trace = get_edge_trace(G)
    node_trace = get_node_trace(G)

    pprint(edge_trace)
    pprint(node_trace)

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            # The nested title dict replaces the deprecated `titlefont_size`
            # shorthand.
            title=dict(
                text='<br>Network graph made with Python',
                font=dict(size=16),
            ),
            showlegend=False,
            hovermode='closest',
            margin=dict(b=20, l=5, r=5, t=40),
            annotations=[dict(
                text="Python code: <a href='https://plot.ly/ipython-notebooks/network-graphs/'> https://plot.ly/ipython-notebooks/network-graphs/</a>",
                showarrow=False,
                xref="paper", yref="paper",
                x=0.005, y=-0.002)],
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    fig.write_html('plot.html', auto_open=True)

Output:

The time-series data for the nodes in the above graph is read from the columns of a data frame and plotted as shown below:

import plotly.graph_objects as go
import numpy as np
import pandas as pd

# Random demo data: column 't' supplies the x values, columns A-D the series.
df = pd.DataFrame(np.random.randint(0, 100, size=(20, 5)), columns=list('tABCD'))

# One (data-frame column, legend name, line colour) entry per node series.
# The legend name is the label of the graph node the series belongs to.
series_specs = (
    ('A', "1", 'deepskyblue'),
    ('B', "2", 'dimgray'),
    ('C', "3", 'blue'),
    ('D', "4", 'red'),
)

fig = go.Figure()
for column, label, colour in series_specs:
    fig.add_trace(go.Scatter(
        x=df.t,
        y=df[column],
        name=label,
        line_color=colour,
        opacity=0.8))

fig.write_html('ts.html', auto_open=True)

I want to link the above two plots, i.e. I want to make the plots interactive. For instance, I'd like to have two subplots, with the time-series plot on the left and the NetworkX graph on the right, and to display the time-series plots corresponding to the nodes that are selected on the NetworkX graph. For example, if the nodes labelled 1 and 4 are selected, the time-series data corresponding to name=1 and name=4 should be displayed on the left.

Any suggestions on how to do this will be really helpful.

EDIT: I found that lines in the plot can be selected and deselected by clicking on the legend. Likewise, I'd like to select and deselect lines by clicking on the nodes in the NetworkX graph.

EDIT2: For creating subplots

import plotly.graph_objects as go
import numpy as np
import pandas as pd
from plotly.subplots import make_subplots

# Two independent random data sets: df feeds the left subplot, df2 the right.
# Column 't' supplies the x values, columns A-D the four series.
df = pd.DataFrame(np.random.randint(0, 100, size=(20, 5)), columns=list('tABCD'))
df2 = pd.DataFrame(np.random.randint(0, 100, size=(20, 5)), columns=list('tABCD'))

fig = make_subplots(rows=1, cols=2)

# One (data-frame column, legend name, line colour) entry per series.
series_specs = (
    ('A', "1", 'deepskyblue'),
    ('B', "2", 'dimgray'),
    ('C', "3", 'blue'),
    ('D', "4", 'red'),
)

# (data frame, subplot column, show in legend). Only the left-hand traces get
# a legend entry; the right-hand ones share the same legendgroup, so clicking
# a legend item toggles the matching series in both subplots.
panels = (
    (df, 1, True),
    (df2, 2, False),
)

for frame, subplot_col, in_legend in panels:
    for column, label, colour in series_specs:
        fig.add_trace(
            go.Scatter(
                x=frame.t,
                y=frame[column],
                name=label,
                line_color=colour,
                opacity=0.8,
                legendgroup='group' + label,
                showlegend=in_legend,
            ),
            row=1, col=subplot_col,
        )

fig.write_html('ts.html', auto_open=True)

I would like to know how to adapt the solution provided below so that it works with subplots.


回答1:


Write a post_script: JavaScript code that will be added to the resulting div and run right after the plot has been created.

In post_script,

  • add an event listener for plotly_click event
  • check if clicked point is a marker
  • toggle the visible property of the corresponding scatter trace between true and 'legendonly'
import numpy as np
import pandas as pd
import networkx as nx
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from pprint import pprint
from collections import OrderedDict


def get_edge_trace(G):
    """Build the Plotly line trace for the edges of ``G`` on the second axes.

    Each node of ``G`` must carry a ``"pos"`` attribute holding an ``(x, y)``
    pair. For every edge the two endpoint coordinates are emitted followed by
    ``None``, so that consecutive edges are not joined into one polyline.

    Returns:
        A ``go.Scatter`` trace in ``"lines"`` mode, bound to axes ``x2``/``y2``,
        with hover and legend entry disabled.
    """
    xs, ys = [], []
    for src, dst in G.edges():
        src_x, src_y = G.nodes[src]["pos"]
        dst_x, dst_y = G.nodes[dst]["pos"]
        # The trailing None separates this segment from the next one.
        xs.extend((src_x, dst_x, None))
        ys.extend((src_y, dst_y, None))

    return go.Scatter(
        x=xs,
        y=ys,
        mode="lines",
        xaxis="x2",
        yaxis="y2",
        showlegend=False,
        hoverinfo="none",
        line=dict(width=0.5, color="#888"),
    )


def get_node_trace(G):
    """Build the Plotly marker trace for the nodes of ``G`` on the second axes.

    Each node of ``G`` must carry a ``"pos"`` attribute holding an ``(x, y)``
    pair; nodes are emitted in the iteration order of ``G.nodes()``, so the
    point number of a marker equals the node's position in that order.

    Returns:
        A ``go.Scatter`` trace in ``"markers"`` mode, bound to axes
        ``x2``/``y2`` and hidden from the legend. ``marker.color`` is left as
        an empty list.
    """
    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = G.nodes[node]["pos"]
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode="markers",
        hoverinfo="text",
        xaxis="x2",
        yaxis="y2",
        showlegend=False,
        marker=dict(
            showscale=True,
            # colorscale options
            # 'Greys' | 'YlGnBu' | 'Greens' | 'YlOrRd' | 'Bluered' | 'RdBu' |
            # 'Reds' | 'Blues' | 'Picnic' | 'Rainbow' | 'Portland' | 'Jet' |
            # 'Hot' | 'Blackbody' | 'Earth' | 'Electric' | 'Viridis' |
            colorscale="YlGnBu",
            reversescale=True,
            color=[],
            size=10,
            colorbar=dict(
                thickness=15,
                # The nested title dict replaces the deprecated `titleside`
                # attribute, which newer Plotly releases no longer accept.
                title=dict(text="Node Connections", side="right"),
                x=0.95,
            ),
            line_width=2,
        ),
    )

    return node_trace


if __name__ == "__main__":
    # Demo: time-series plot on the left, node graph on the right. Clicking a
    # node marker toggles the time-series trace belonging to that node.

    tail = [1, 2, 3]
    head = [2, 3, 4]

    xpos = [0, 1, 2, 3]
    ypos = [0, 0, 0, 0]
    xpos_ypos = list(zip(xpos, ypos))

    ed_ls = list(zip(tail, head))
    # nx.OrderedDiGraph was removed in NetworkX 3.0. A plain DiGraph keeps
    # insertion order because it is backed by ordinary (ordered) dicts.
    G = nx.DiGraph()
    G.add_edges_from(ed_ls)

    # Nodes come back in insertion order (1, 2, 3, 4), so the i-th node
    # receives the i-th coordinate pair.
    pos = OrderedDict(zip(G.nodes, xpos_ypos))
    nx.set_node_attributes(G, pos, "pos")

    fig = make_subplots(rows=1, cols=2)
    fig.layout.update(
        dict(
            # The nested title dict replaces the deprecated `titlefont_size`
            # shorthand.
            title=dict(
                text="<br>Network graph made with Python",
                font=dict(size=16),
            ),
            hovermode="closest",
            annotations=[
                dict(
                    text="Python code: <a href='https://plot.ly/ipython-notebooks/network-graphs/'> https://plot.ly/ipython-notebooks/network-graphs/</a>",
                    showarrow=False,
                    xref="paper",
                    yref="paper",
                    x=0,
                    y=0,
                    yshift=-0.1,
                )
            ],
            # Left subplot (time series) takes 60% of the width; the graph
            # sits to its right, leaving room for the colour bar.
            yaxis=dict(domain=[0.1, 1]),
            xaxis=dict(domain=[0, 0.6]),
            xaxis2=dict(
                domain=[0.7, 0.94],
                showgrid=False,
                zeroline=False,
                showticklabels=False,
            ),
            yaxis2=dict(showgrid=False, zeroline=False, showticklabels=False),
        )
    )

    df = pd.DataFrame(
        np.random.randint(0, 100, size=(20, 5)), columns=list("tABCD")
    )

    # The time-series traces MUST be added first and in node order: the click
    # handler below uses a node's point number as the index of the trace to
    # toggle (node 1 -> trace 0, node 2 -> trace 1, ...).
    data = [
        go.Scatter(
            x=df.t, y=df["A"], name="1", line_color="deepskyblue", opacity=0.8
        ),
        go.Scatter(
            x=df.t, y=df["B"], name="2", line_color="dimgray", opacity=0.8
        ),
        go.Scatter(
            x=df.t, y=df["C"], name="3", line_color="blue", opacity=0.8
        ),
        go.Scatter(x=df.t, y=df["D"], name="4", line_color="red", opacity=0.8),
    ]

    fig.add_traces(data=data, rows=[1, 1, 1, 1], cols=[1, 1, 1, 1])

    fig.add_traces(
        data=[get_edge_trace(G), get_node_trace(G)], rows=[1, 1], cols=[2, 2]
    )

    # JavaScript run right after the plot is created; write_html substitutes
    # the id of the generated div for the '{plot_id}' placeholder. Only clicks
    # on a 'markers' trace (the graph nodes) are handled; the clicked node's
    # point number selects the time-series trace whose visibility is toggled
    # between true and 'legendonly'.
    post_script = """
    var gd = document.getElementById('{plot_id}');
    gd.on('plotly_click', function(data) {
        var nodeIndex = null;
        for (var i = 0; i < data.points.length; i++) {
            if (data.points[i].fullData.mode === 'markers') {
                nodeIndex = data.points[i].pointNumber;
            }
        }
        if (nodeIndex === null) return;
        var visible = gd.calcdata[nodeIndex][0].trace.visible;
        var update = {'visible': visible === true ? 'legendonly' : true};
        Plotly.restyle(gd, update, [nodeIndex]);
        return false;
    });
    """
    fig.write_html("plot.html", post_script=post_script, auto_open=True)



来源:https://stackoverflow.com/questions/60596968/visualizing-time-series-data-of-graph-nodes-in-plotly

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!