How to use the `pos` argument in `networkx` to create a flowchart-style Graph? (Python 3)

前端 未结 1 1838
眼角桃花
眼角桃花 2020-12-29 06:57

I am trying to create a linear network graph using Python (preferably with matplotlib and networkx, although I would be int… [the rest of the question text is truncated in this copy])

相关标签:
1条回答
  • 2020-12-29 07:23

    While networkx has decent plotting facilities for exploratory data analysis, it is not the tool for making publication-quality figures, for various reasons that I don't want to go into here. I hence rewrote that part of the code base from scratch and made a stand-alone drawing module called netgraph that can be found here (like the original, it is purely based on matplotlib). The API is very, very similar and well documented, so it should not be too hard to mold to your purposes.

    Building on that, I get the following result:

    I chose colour to denote the edge strength because, with colour, you can
    1) indicate negative values, and
    2) distinguish small values better.
    However, you can also pass an edge width to netgraph instead (see netgraph.draw_edges()).

    The different order of the branches is a result of your data structure (a dict), which does not encode an explicit ordering of the nodes. You would have to amend your data structure and the function _parse_input() below to fix that issue.

    Code:

    import itertools
    import numpy as np
    import matplotlib.pyplot as plt
    import netgraph; reload(netgraph)
    
    def plot_layered_network(weight_matrices,
                             distance_between_layers=2,
                             distance_between_nodes=1,
                             layer_labels=None,
                             **kwargs):
        """
        Convenience function to plot a layered (feed-forward) network.

        Arguments:
        ----------
            weight_matrices: [w1, w2, ..., wn]
                list of weight matrices defining the connectivity between consecutive layers;
                each weight matrix is a 2-D ndarray with rows indexing sources and columns indexing targets;
                the number of sources of each matrix has to match the number of targets
                of the preceding matrix

            distance_between_layers: float
                spacing between consecutive layers along the x-axis

            distance_between_nodes: float
                spacing between nodes within a layer

            layer_labels: [str1, str2, ..., strn+1] or None
                labels of the n+1 layers, drawn as x-axis tick labels

            **kwargs: passed to netgraph.draw()

        Returns:
        --------
            ax: matplotlib axis instance

        """
        nodes_per_layer = _get_nodes_per_layer(weight_matrices)

        node_positions = _get_node_positions(nodes_per_layer,
                                             distance_between_layers,
                                             distance_between_nodes)

        w = _combine_weight_matrices(weight_matrices, nodes_per_layer)

        ax = netgraph.draw(w, node_positions, **kwargs)

        if layer_labels is not None:
            # Layers sit at x = k * distance_between_layers (see _get_node_positions),
            # so one tick per layer puts each label under its column of nodes.
            ax.set_xticks(distance_between_layers * np.arange(len(weight_matrices) + 1))
            ax.set_xticklabels(layer_labels)
            ax.xaxis.set_ticks_position('bottom')

        return ax
    
    def _get_nodes_per_layer(weight_matrices):
        """Return the number of nodes in each layer described by `weight_matrices`.

        Layer k has ``weight_matrices[k].shape[0]`` nodes (the sources of the k-th
        matrix); the final layer has the targets of the last matrix. An empty list
        of matrices yields an empty list of layers.

        Raises:
            ValueError: if a matrix's number of sources differs from the number of
                targets of the matrix before it (the two must describe the same layer).
        """
        nodes_per_layer = []
        previous_targets = None
        for index, w in enumerate(weight_matrices):
            sources, targets = w.shape
            if previous_targets is not None and sources != previous_targets:
                raise ValueError(
                    "weight matrix %d has %d sources but the preceding matrix "
                    "has %d targets" % (index, sources, previous_targets))
            nodes_per_layer.append(sources)
            previous_targets = targets
        # The last layer only appears as the target side of the final matrix.
        if previous_targets is not None:
            nodes_per_layer.append(previous_targets)
        return nodes_per_layer
    
    def _get_node_positions(nodes_per_layer,
                            distance_between_layers,
                            distance_between_nodes):
        """Lay the nodes out on a grid with one column of nodes per layer.

        Returns an (N, 2) float array, N = sum(nodes_per_layer). Column 0 is the
        layer coordinate (layer index * distance_between_layers); column 1 is the
        position within the layer (node index * distance_between_nodes). Rows are
        ordered layer by layer, matching the node order of the combined weight matrix.
        """
        layer_coords = np.concatenate([
            np.full(count, layer * distance_between_layers, dtype=float)
            for layer, count in enumerate(nodes_per_layer)
        ])
        within_layer_coords = np.concatenate([
            distance_between_nodes * np.arange(0., count)
            for count in nodes_per_layer
        ])
        return np.column_stack((layer_coords, within_layer_coords))
    
    def _combine_weight_matrices(weight_matrices, nodes_per_layer):
        """Assemble per-layer weight matrices into one square adjacency matrix.

        Nodes are numbered layer by layer. Matrix k is written into the block whose
        rows are the nodes of layer k and whose columns are the nodes of layer k+1.
        Every other entry is NaN, which marks "no edge".

        Arguments:
            weight_matrices: list of 2-D arrays (rows = sources, columns = targets).
            nodes_per_layer: node count of every layer, as returned by
                _get_nodes_per_layer().

        Returns:
            (N, N) float64 ndarray with N = sum(nodes_per_layer).
        """
        total_nodes = int(sum(nodes_per_layer))
        # np.float was only an alias of the builtin float and was removed in NumPy 1.24.
        w = np.full((total_nodes, total_nodes), np.nan, dtype=float)

        row_start = 0
        for ii, ww in enumerate(weight_matrices):
            # The targets of layer ii start right after the nodes of layer ii itself.
            col_start = row_start + nodes_per_layer[ii]
            w[row_start:row_start + ww.shape[0], col_start:col_start + ww.shape[1]] = ww
            row_start += nodes_per_layer[ii]

        return w
    
    def test():
        """Demo: draw a random 4 -> 5 -> 6 -> 3 layered network with lettered nodes."""
        import string

        layer_shapes = [(4, 5), (5, 6), (6, 3)]
        random_weights = [np.random.rand(*shape) for shape in layer_shapes]

        # 4 + 5 + 6 + 3 = 18 nodes, labelled 'a' .. 'r'.
        letters = dict(enumerate(string.ascii_lowercase[:18]))

        _, axis = plt.subplots(1, 1)
        plot_layered_network(
            random_weights,
            layer_labels=['start', 'step 1', 'step 2', 'finish'],
            ax=axis,
            node_size=20,
            node_edge_width=2,
            node_labels=letters,
            edge_width=5,
        )
        plt.show()
    
    def test_example(input_dict):
        """Demo: parse the nested-dict description `input_dict` and plot it."""
        matrices, labels = _parse_input(input_dict)

        drawing_options = dict(
            node_size=300,
            node_edge_width=10,
            node_labels=labels,
            edge_width=50,
        )

        _, axis = plt.subplots(1, 1)
        plot_layered_network(
            matrices,
            layer_labels=['', '1', '2', '3', '4'],
            distance_between_layers=10,
            distance_between_nodes=8,
            ax=axis,
            **drawing_options
        )
        plt.show()
    
    def _parse_input(input_dict):
        """Convert the nested-dict network description into weight matrices and labels.

        `input_dict` maps step numbers 1..n to ``{target: {source: weight}}``. The
        sources of step 1 are every key referenced inside ``input_dict[1]``; the
        sources of step k+1 are the targets of step k.

        Returns:
            weight_matrices: list of n float arrays. Matrix k has one row per
                source and one column per target of step k+1. Entry [i, j] is
                ``input_dict[k+1][target_j][source_i]``, or NaN when that weight is
                absent.
            node_labels: dict mapping each global node index (numbered layer by
                layer, matching _combine_weight_matrices) to the node's name.
        """
        weight_matrices = []
        layer_names = []

        # First-layer nodes in first-seen order. Iterating a set here would make
        # the node order (and the drawing) vary between runs because of string
        # hash randomisation.
        sources = list(dict.fromkeys(
            source
            for incoming in input_dict[1].values()
            for source in incoming
        ))

        for step in range(1, len(input_dict) + 1):
            inner_dict = input_dict[step]
            targets = list(inner_dict)

            # np.float was removed in NumPy 1.24; the builtin float is equivalent.
            w = np.full((len(sources), len(targets)), np.nan, dtype=float)
            for jj, target in enumerate(targets):
                incoming = inner_dict[target]
                for ii, source in enumerate(sources):
                    if source in incoming:
                        w[ii, jj] = incoming[source]

            weight_matrices.append(w)
            layer_names.append(sources)
            sources = targets

        # After the loop `sources` holds the targets of the final step (the last layer).
        layer_names.append(sources)
        node_labels = dict(enumerate(itertools.chain.from_iterable(layer_names)))

        return weight_matrices, node_labels
    
    # --------------------------------------------------------------------------------
    # script
    # --------------------------------------------------------------------------------
    
    if __name__ == "__main__":
        # Example data: step k maps each target node to {source node: weight}.
        # (Call test() instead to draw a randomly generated network.)
        input_dict = {
            1: {
                "Group 1": {"sample_0": 0.5, "sample_1": 0.5, "sample_2": 0, "sample_3": 0, "sample_4": 0},
                "Group 2": {"sample_0": 0, "sample_1": 0, "sample_2": 1, "sample_3": 0, "sample_4": 0},
                "Group 3": {"sample_0": 0, "sample_1": 0, "sample_2": 0, "sample_3": 0.5, "sample_4": 0.5},
            },
            2: {
                "Group 1": {"Group 1": 1, "Group 2": 0, "Group 3": 0},
                "Group 2": {"Group 1": 0, "Group 2": 1, "Group 3": 0},
                "Group 3": {"Group 1": 0, "Group 2": 0, "Group 3": 1},
            },
            3: {
                "Group 1": {"Group 1": 0.25, "Group 2": 0, "Group 3": 0.75},
                "Group 2": {"Group 1": 0.25, "Group 2": 0.75, "Group 3": 0},
            },
            4: {
                "Group 1": {"Group 1": 1, "Group 2": 0},
                "Group 2": {"Group 1": 0.25, "Group 2": 0.75},
            },
        }

        test_example(input_dict)
    
    0 讨论(0)
提交回复
热议问题