Dijkstra's algorithm in python

后端 未结 11 1419
无人及你
无人及你 2021-02-01 09:32

I am trying to implement Dijkstra's algorithm in Python using arrays. This is my implementation:

def extract(Q, w):
    m=0
    minimum=w[0]
    for i in range(l         


        
相关标签:
11条回答
  • 2021-02-01 09:49

    Implementation based on CLRS 2nd Ed. Chapter 24.3

    `d` holds the deltas (best-known distances) and `p` holds the predecessors.

    import heapq
    
    def dijkstra(g, s, t):
        """Return the shortest-path distance from `s` to `t` in graph `g`.

        Implementation based on CLRS 2nd Ed. Chapter 24.3; `d` holds the
        deltas (best-known distances) and `p` the predecessors.

        :param g: adjacency mapping ``{node: [(neighbor, weight), ...], ...}``
            with non-negative weights. Nodes that only appear as neighbors
            (no key of their own) are treated as having no outgoing edges.
        :param s: source node
        :param t: target node
        :return: length of the shortest path from `s` to `t`, or
            ``float("inf")`` when `t` is unreachable
        """
        # float("inf") replaces sys.maxint: `sys` was never imported here and
        # `maxint` no longer exists in Python 3.
        inf = float("inf")
        q = []
        d = {k: inf for k in g}
        p = {}

        d[s] = 0
        heapq.heappush(q, (0, s))

        while q:
            last_w, curr_v = heapq.heappop(q)
            if last_w > d[curr_v]:
                # Stale queue entry: curr_v was already reached more cheaply.
                continue
            for n, n_w in g.get(curr_v, ()):
                cand_w = last_w + n_w  # equivalent to d[curr_v] + n_w
                # .get() so neighbors that are not keys of `g` still work.
                if cand_w < d.get(n, inf):
                    d[n] = cand_w
                    p[n] = curr_v
                    heapq.heappush(q, (cand_w, n))

        print("predecessors: ", p)
        print("delta: ", d)
        return d.get(t, inf)
    
    def test():
        """Smoke-test dijkstra() on a small five-node directed graph."""
        og = {
            "s": [("t", 10), ("y", 5)],
            "t": [("y", 2), ("x", 1)],
            "y": [("t", 3), ("x", 9), ("z", 2)],
            "z": [("x", 6), ("s", 7)],
            "x": [("z", 4)],
        }
        # Cheapest route is s -> y -> t -> x, costing 5 + 3 + 1.
        expected = 9
        assert dijkstra(og, "s", "x") == expected
    
    
    # Run the self-test only when this file is executed directly, not on import.
    if __name__ == "__main__":
        test()
    

    The implementation assumes that every node is a key of the graph. If a node (e.g. "x" in the example above) were not defined as a key in `og`, the deltas `d` would be missing that key and the check `cand_w < d[n]` would not work correctly.

    0 讨论(0)
  • 2021-02-01 09:53

    I wrote it in a more verbose form to make it clearer for a novice reader:

    def get_parent(pos):
        """Return the index of the parent of heap slot `pos` (-1 for the root)."""
        # floor((pos + 1) / 2) - 1 simplifies to floor((pos - 1) / 2).
        parent_index = (pos - 1) // 2
        return parent_index
    
    
    def get_children(pos):
        """Return the ``(left, right)`` child indices of heap slot `pos`."""
        first_child = 2 * pos + 1
        return first_child, first_child + 1
    
    
    def swap(array, a, b):
        """Exchange ``array[a]`` and ``array[b]`` in place."""
        held = array[a]
        array[a] = array[b]
        array[b] = held
    
    
    class Heap:
        """Binary min-heap kept in a flat list; items are ordered with ``<``."""

        def __init__(self):
            self._array = []

        def peek(self):
            """Return the smallest item without removing it, or None if empty."""
            if not self._array:
                return None
            return self._array[0]

        def _get_smallest_child(self, parent):
            """Index of the smaller existing child of `parent`, or -1 if none."""
            size = len(self._array)
            existing = [child for child in get_children(parent) if child < size]
            if not existing:
                return -1
            return min(existing, key=self._array.__getitem__)

        def _sift_down(self):
            """Move the root item down until no child is smaller than it."""
            node = 0
            while True:
                child = self._get_smallest_child(node)
                if child == -1 or not self._array[child] < self._array[node]:
                    return
                swap(self._array, child, node)
                node = child

        def pop(self):
            """Remove and return the smallest item, or None if the heap is empty."""
            if not self._array:
                return None
            last = len(self._array) - 1
            swap(self._array, 0, last)
            smallest = self._array.pop()
            self._sift_down()
            return smallest

        def _sift_up(self):
            """Move the most recently appended item up past any larger parents."""
            child = len(self._array) - 1
            while True:
                parent = get_parent(child)
                if parent == -1 or not self._array[child] < self._array[parent]:
                    return
                swap(self._array, child, parent)
                child = parent

        def add(self, item):
            """Insert `item` into the heap."""
            self._array.append(item)
            self._sift_up()

        def __bool__(self):
            return len(self._array) > 0
    
    
    def backtrack(best_parents, start, end):
        """Rebuild the path from `start` to `end` by following parent links.

        :param best_parents: mapping ``node -> parent`` (``None`` for no parent)
        :param start: first node of the wanted path
        :param end: last node of the wanted path
        :return: ``[start, ..., end]``, or None when the parent chain starting
            at `end` never reaches `start`
        """
        # A node trivially reaches itself; without this check the walk below
        # steps past `start` to its None parent and wrongly reports no path.
        if end == start:
            return [start]
        if end not in best_parents:
            return None
        cursor = end
        path = [cursor]
        while cursor in best_parents:
            cursor = best_parents[cursor]
            path.append(cursor)
            if cursor == start:
                return path[::-1]
        return None
    
    
    def dijkstra(weighted_graph, start, end):
        """
        Calculate the shortest path for a directed weighted graph.

        Node can be virtually any hashable datatype (nodes are never compared
        with each other, so they need not be orderable).

        :param start: starting node
        :param end: ending node
        :param weighted_graph: {"node1": {"node2": weight, ...}, ...} with
                non-negative weights; nodes that appear only as edge targets
                are treated as having no outgoing edges
        :return: ["START", ... nodes between ..., "END"] or None, if there is no
                path
        """
        from itertools import count

        if start == end:
            # A node trivially reaches itself.
            return [start]

        inf = float("inf")
        distances = {i: inf for i in weighted_graph}
        best_parents = {i: None for i in weighted_graph}

        # Heap entries are (distance, sequence number, node). The unique
        # sequence number breaks distance ties, so the heap never has to
        # compare two nodes.
        tie_breaker = count()
        to_visit = Heap()
        to_visit.add((0, next(tie_breaker), start))
        distances[start] = 0

        visited = set()

        while to_visit:
            src_distance, _, source = to_visit.pop()
            # Stale entry: a shorter distance to `source` was found after it
            # was queued.
            if src_distance > distances[source]:
                continue
            if source == end:
                break
            visited.add(source)
            for target, distance in weighted_graph.get(source, {}).items():
                if target in visited:
                    continue
                new_dist = src_distance + distance
                if distances.get(target, inf) > new_dist:
                    distances[target] = new_dist
                    best_parents[target] = source
                    to_visit.add((new_dist, next(tie_breaker), target))

        return backtrack(best_parents, start, end)
    
    0 讨论(0)
  • 2021-02-01 09:53

    I broke the Wikipedia description down into the following pseudocode on my blog, rebrained.com:

    Initial state:

    1. give nodes two properties - node.visited and node.distance
    2. set node.distance = infinity for all nodes except set start node to zero
    3. set node.visited = false for all nodes
    4. set current node = start node.

    Current node loop:

    1. if current node = end node, finish and return current.distance & path
    2. for all unvisited neighbors, calc their tentative distance (current.distance + edge to neighbor).
    3. if tentative distance < neighbor's set distance, overwrite it.
    4. set current.visited = true.
    5. set current = remaining unvisited node with smallest node.distance

    http://rebrained.com/?p=392

    import sys
    def shortestpath(graph,start,end,visited=None,distances=None,predecessors=None):
        """Find the shortest path between the start and end nodes of a graph.

        :param graph: adjacency mapping ``{node: {neighbor: weight, ...}, ...}``
            with non-negative weights. `end` may be missing from the keys as
            long as some edge leads to it.
        :param start: starting node
        :param end: target node
        :param visited: settled nodes (internal state; normally left unset)
        :param distances: tentative distances (internal state; normally left unset)
        :param predecessors: best predecessor per node (internal state;
            normally left unset)
        :return: ``(distance, path)``, where `path` lists the nodes from
            `start` to `end`
        :raises ValueError: if `end` cannot be reached from `start`
        """
        # Fresh containers per call. The former mutable defaults ([] / {})
        # were shared between calls and leaked state into later searches.
        if visited is None:
            visited = []
        if distances is None:
            distances = {}
        if predecessors is None:
            predecessors = {}
        inf = float("inf")
        # detect if first time through, set current distance to zero
        if not visited:
            distances[start] = 0
        current = start
        # Iterative form of the original tail recursion (no recursion limit).
        while current != end:
            # process neighbors as per algorithm, keep track of predecessors
            for neighbor, weight in graph.get(current, {}).items():
                if neighbor not in visited:
                    tentativedist = distances[current] + weight
                    if tentativedist < distances.get(neighbor, inf):
                        distances[neighbor] = tentativedist
                        predecessors[neighbor] = current
            # neighbors processed, now mark the current node as visited
            visited.append(current)
            # closest unvisited node that has actually been reached
            unvisiteds = {k: distances[k] for k in graph
                          if k not in visited and k in distances}
            if end in distances and end not in graph and end not in visited:
                # `end` has no outgoing edges listed but can still be the goal.
                unvisiteds[end] = distances[end]
            if not unvisiteds:
                raise ValueError("no path from %r to %r" % (start, end))
            current = min(unvisiteds, key=unvisiteds.get)
        # found the end node: walk predecessors back to the start
        path = []
        node = end
        while node is not None:
            path.append(node)
            node = predecessors.get(node, None)
        return distances[end], path[::-1]
    if __name__ == "__main__":
        # Demo on a small undirected graph (every edge is listed both ways).
        graph = {'a': {'w': 14, 'x': 7, 'y': 9},
                 'b': {'w': 9, 'z': 6},
                 'w': {'a': 14, 'b': 9, 'y': 2},
                 'x': {'a': 7, 'y': 10, 'z': 15},
                 'y': {'a': 9, 'w': 2, 'x': 10, 'z': 11},
                 'z': {'b': 6, 'x': 15, 'y': 11}}
        # Expected output:
        #     (0, ['a'])
        #     (20, ['a', 'y', 'w', 'b'])
        print(shortestpath(graph, 'a', 'a'))
        print(shortestpath(graph, 'a', 'b'))
    
    0 讨论(0)
  • 2021-02-01 09:57

    As others have pointed out, because the variable names are not descriptive, it is almost impossible to debug your code.

    Following the Wikipedia article on Dijkstra's algorithm, one can implement it along these lines (and in countless other ways):

    nodes = ('A', 'B', 'C', 'D', 'E', 'F', 'G')
    distances = {
        'B': {'A': 5, 'D': 1, 'G': 2},
        'A': {'B': 5, 'D': 3, 'E': 12, 'F' :5},
        'D': {'B': 1, 'G': 1, 'E': 1, 'A': 3},
        'G': {'B': 2, 'D': 1, 'C': 2},
        'C': {'G': 2, 'E': 1, 'F': 16},
        'E': {'A': 12, 'D': 1, 'C': 1, 'F': 2},
        'F': {'A': 5, 'E': 2, 'C': 16}}

    # Tentative distances of nodes not yet settled; None stands for +infinity.
    unvisited = {node: None for node in nodes}
    visited = {}  # settled node -> final shortest distance from the start
    current = 'B'
    currentDistance = 0
    unvisited[current] = currentDistance

    while True:
        for neighbour, distance in distances[current].items():
            if neighbour not in unvisited: continue
            newDistance = currentDistance + distance
            if unvisited[neighbour] is None or unvisited[neighbour] > newDistance:
                unvisited[neighbour] = newDistance
        visited[current] = currentDistance
        del unvisited[current]
        # Candidates are the reached-but-unsettled nodes. Compare with None
        # explicitly: a truthiness test would also drop a real distance of 0.
        candidates = [node for node in unvisited.items() if node[1] is not None]
        if not candidates:
            # Nothing left, or everything left is unreachable from the start.
            break
        current, currentDistance = min(candidates, key = lambda x: x[1])

    print(visited)
    

    This code is more verbose than necessary; I hope that by comparing your code with mine you can spot some differences.

    The result is:

    {'E': 2, 'D': 1, 'G': 2, 'F': 4, 'A': 4, 'C': 3, 'B': 0}
    
    0 讨论(0)
  • 2021-02-01 10:01

    I implemented Dijkstra's algorithm using a priority queue, and I also implemented the min-heap myself. I hope this helps.

    from collections import defaultdict
    
    
    class MinPQ:
        """
        Binary min-heap priority queue.

        Each heap element is in form (key value, object handle): heap
        operations compare whole elements with ``<``, so the key value decides
        the order and the object handle points to the corresponding
        application object.

        Invariant: ``len(self._minheap) == self._heapsize``, i.e. the backing
        list holds exactly the live heap elements (no stale leftovers).
        """

        def __init__(self, array=()):
            # Copy the input (any iterable) so the caller's data is untouched.
            self._minheap = list(array)
            self._length = len(self._minheap)
            self._heapsize = 0
            self._build_min_heap()

        def _left(self, idx):
            return 2 * idx + 1

        def _right(self, idx):
            return 2 * idx + 2

        def _parent(self, idx):
            # Integer floor division; only meaningful for idx > 0.
            return (idx - 1) // 2

        def _min_heapify(self, idx):
            """Sift the element at `idx` down until its subtree is a min-heap."""
            left = self._left(idx)
            right = self._right(idx)
            min_idx = idx
            if left < self._heapsize and self._minheap[left] < self._minheap[min_idx]:
                min_idx = left
            if right < self._heapsize and self._minheap[right] < self._minheap[min_idx]:
                min_idx = right
            if min_idx != idx:
                self._minheap[idx], self._minheap[min_idx] = self._minheap[min_idx], self._minheap[idx]
                self._min_heapify(min_idx)

        def _build_min_heap(self):
            """Heapify the whole backing list bottom-up."""
            self._heapsize = self._length
            for i in range(self._heapsize // 2 - 1, -1, -1):
                self._min_heapify(i)

        def decrease_key(self, idx, new_key):
            """Store `new_key` at `idx` and sift it up toward the root.

            `new_key` must not be larger than the element it replaces: only an
            upward sift is performed, so a larger key would break heap order.
            """
            while idx > 0 and new_key < self._minheap[self._parent(idx)]:
                self._minheap[idx] = self._minheap[self._parent(idx)]
                idx = self._parent(idx)
            self._minheap[idx] = new_key

        def extract_min(self):
            """Remove and return the smallest element.

            :raises IndexError: if the queue is empty
            """
            if self._heapsize < 1:
                raise IndexError("extract_min from an empty priority queue")
            minimum = self._minheap[0]
            # Pop the last element so the backing list shrinks with the heap,
            # then move it to the root and restore heap order.
            last = self._minheap.pop()
            self._heapsize -= 1
            if self._heapsize > 0:
                self._minheap[0] = last
                self._min_heapify(0)
            return minimum

        def insert(self, item):
            """Add `item` to the queue."""
            self._minheap.append(item)
            self._heapsize += 1
            self.decrease_key(self._heapsize - 1, item)

        @property
        def minimum(self):
            """The smallest element, without removing it.

            :raises IndexError: if the queue is empty
            """
            if self._heapsize < 1:
                raise IndexError("minimum of an empty priority queue")
            return self._minheap[0]

        def is_empty(self):
            return self._heapsize == 0

        def __str__(self):
            return str(self._minheap)

        __repr__ = __str__

        def __len__(self):
            return self._heapsize
    
    
    class DiGraph:
        """Directed weighted graph stored as an adjacency list.

        ``adj_list`` maps each tail node to a list of ``(head, weight)`` pairs.
        """

        def __init__(self, edges=None):
            self.adj_list = defaultdict(list)
            self.add_weighted_edges(edges)

        @property
        def nodes(self):
            """All nodes appearing as a tail or a head of some edge (unordered list)."""
            nodes = set(self.adj_list)
            for neighbors in self.adj_list.values():
                nodes.update(neighbor for neighbor, _ in neighbors)
            return list(nodes)

        def add_weighted_edges(self, edges):
            """Add each ``(tail, head, weight)`` triple in `edges`; None is a no-op."""
            if edges is None:
                return None
            for edge in edges:
                self.add_weighted_edge(edge)

        def add_weighted_edge(self, edge):
            """Add one ``(tail, head, weight)`` edge."""
            node1, node2, weight = edge
            self.adj_list[node1].append((node2, weight))

        def weight(self, tail, head):
            """Weight of the first ``tail -> head`` edge, or None if there is none."""
            # .get() rather than indexing: indexing the defaultdict would insert
            # an empty entry for an unknown tail, silently adding it to `nodes`.
            for node, weight in self.adj_list.get(tail, ()):
                if node == head:
                    return weight
            return None
    
    
    def relax(min_heapq, dist, graph, u, v, weight=None):
        """Relax edge ``u -> v``: lower ``dist[v]`` if going through `u` is shorter.

        On improvement the pair ``(new distance, v)`` is pushed onto `min_heapq`.

        :param weight: weight of the edge being relaxed. When omitted it is
            looked up once with ``graph.weight(u, v)``, which returns the first
            ``u -> v`` edge. Pass it explicitly to relax a specific parallel edge.
        """
        if weight is None:
            weight = graph.weight(u, v)
        candidate = dist[u] + weight
        if dist[v] > candidate:
            dist[v] = candidate
            min_heapq.insert((candidate, v))
    
    
    def dijkstra(graph, start):
        """Single-source shortest distances on a DiGraph with non-negative weights.

        :param graph: a DiGraph (reads ``graph.nodes`` and ``graph.adj_list``)
        :param start: source node
        :return: dict mapping every node (plus `start`) to its shortest distance
            from `start`; unreachable nodes map to ``float('inf')``
        """
        # initialize
        dist = dict.fromkeys(graph.nodes, float('inf'))
        dist[start] = 0
        min_heapq = MinPQ()
        min_heapq.insert((0, start))

        while not min_heapq.is_empty():
            distance, u = min_heapq.extract_min()
            # we may add a node multiple times to the priority queue, but we
            # process it only once (later, larger entries are stale)
            if distance > dist[u]:
                continue
            # .get() avoids inserting empty entries into the defaultdict.
            for neighbor, weight in graph.adj_list.get(u, ()):
                # Relax with this edge's own weight so that a cheaper parallel
                # edge is not shadowed by the first u -> neighbor edge.
                candidate = distance + weight
                if candidate < dist[neighbor]:
                    dist[neighbor] = candidate
                    min_heapq.insert((candidate, neighbor))

        return dist
    
    0 讨论(0)
提交回复
热议问题