dijkstra_algorithm.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
"""Python 3 implementation of Dijkstra's algorithm for finding the shortest
path between nodes in a graph. Written as a learning exercise, so lots of
comments and little error handling.
"""
import heapq
import math
from collections import deque
  7. INFINITY = float("inf")
  8. class Graph:
  9. def __init__(self):
  10. """Reads graph definition and stores it. Each line of the graph
  11. definition file defines an edge by specifying the start node,
  12. end node, and distance, delimited by spaces.
  13. Stores the graph definition in two properties which are used by
  14. Dijkstra's algorithm in the shortest_path method:
  15. self.nodes = set of all unique nodes in the graph
  16. self.adjacency_list = dict that maps each node to an unordered set of
  17. (neighbor, distance) tuples.
  18. """
  19. # Read the graph definition file and store in graph_edges as a list of
  20. # lists of [from_node, to_node, distance]. This data structure is not
  21. # used by Dijkstra's algorithm, it's just an intermediate step in the
  22. # create of self.nodes and self.adjacency_list.
  23. self.points = {} # dict,{id:[x, y]}
  24. self.graph_edges = [] # list, [(node_id1, node_id2, distance)]
  25. self.nodes = set() # dict,{id:[node_id1, node_id2]}
  26. def AddVertex(self, name, point):
  27. self.points[name] = point
  28. def ResetVertexValue(self, name, point):
  29. """
  30. point : [x, y]
  31. """
  32. self.points[name] = point
  33. for i in range(len(self.graph_edges)):
  34. name1,name2,_=self.graph_edges[i]
  35. if name1==name or name2==name:
  36. pt1=self.points[name1]
  37. pt2=self.points[name2]
  38. distance = math.sqrt(math.pow(pt1[0] - pt2[0], 2) + math.pow(pt1[1] - pt2[1], 2)) + 0.000001
  39. self.graph_edges[i]=[name1,name2,distance]
  40. self.nodes.update([name1, name2])
  41. for item in self.adjacency_list[name1]:
  42. neighbor,_=item
  43. if neighbor==name2:
  44. self.adjacency_list[name1].discard(item)
  45. self.adjacency_list[name1].add((neighbor,distance))
  46. break
  47. def AddEdge(self, name1, name2):
  48. if self.points.get(name1) == None or self.points.get(name2) == None:
  49. print("node :%s or %s not exis" % (name1, name2))
  50. return False
  51. pt1 = self.points[name1]
  52. pt2 = self.points[name2]
  53. distance = math.sqrt(math.pow(pt1[0] - pt2[0], 2) + math.pow(pt1[1] - pt2[1], 2)) + 0.000001
  54. self.graph_edges.append((name1, name2, distance))
  55. self.nodes.update([name1, name2])
  56. self.adjacency_list = {node: set() for node in self.nodes}
  57. for edge in self.graph_edges:
  58. self.adjacency_list[edge[0]].add((edge[1], edge[2]))
  59. return True
  60. def __getitem__(self, item):
  61. return self.points[item]
  62. def shortest_path(self, start_node, end_node):
  63. """Uses Dijkstra's algorithm to determine the shortest path from
  64. start_node to end_node. Returns (path, distance).
  65. """
  66. unvisited_nodes = self.nodes.copy() # All nodes are initially unvisited.
  67. # Create a dictionary of each node's distance from start_node. We will
  68. # update each node's distance whenever we find a shorter path.
  69. distance_from_start = {
  70. node: (0 if node == start_node else INFINITY) for node in self.nodes
  71. }
  72. # Initialize previous_node, the dictionary that maps each node to the
  73. # node it was visited from when the the shortest path to it was found.
  74. previous_node = {node: None for node in self.nodes}
  75. while unvisited_nodes:
  76. # Set current_node to the unvisited node with shortest distance
  77. # calculated so far.
  78. current_node = min(
  79. unvisited_nodes, key=lambda node: distance_from_start[node]
  80. )
  81. unvisited_nodes.remove(current_node)
  82. # If current_node's distance is INFINITY, the remaining unvisited
  83. # nodes are not connected to start_node, so we're done.
  84. if distance_from_start[current_node] == INFINITY:
  85. break
  86. # For each neighbor of current_node, check whether the total distance
  87. # to the neighbor via current_node is shorter than the distance we
  88. # currently have for that node. If it is, update the neighbor's values
  89. # for distance_from_start and previous_node.
  90. for neighbor, distance in self.adjacency_list[current_node]:
  91. new_path = distance_from_start[current_node] + distance
  92. if new_path < distance_from_start[neighbor]:
  93. distance_from_start[neighbor] = new_path
  94. previous_node[neighbor] = current_node
  95. if current_node == end_node:
  96. break # we've visited the destination node, so we're done
  97. # To build the path to be returned, we iterate through the nodes from
  98. # end_node back to start_node. Note the use of a deque, which can
  99. # appendleft with O(1) performance.
  100. path = deque()
  101. current_node = end_node
  102. while previous_node[current_node] is not None:
  103. path.appendleft(current_node)
  104. current_node = previous_node[current_node]
  105. path.appendleft(start_node)
  106. return path, distance_from_start[end_node]
  107. '''
  108. def main():
  109. """Runs a few simple tests to verify the implementation.
  110. """
  111. verify_algorithm(
  112. filename="simple_graph.txt",
  113. start="A",
  114. end="G",
  115. path=["A", "D", "E", "G"],
  116. distance=11,
  117. )
  118. verify_algorithm(
  119. filename="seattle_area.txt",
  120. start="Renton",
  121. end="Redmond",
  122. path=["Renton", "Factoria", "Bellevue", "Northup", "Redmond"],
  123. distance=16,
  124. )
  125. verify_algorithm(
  126. filename="seattle_area.txt",
  127. start="Seattle",
  128. end="Redmond",
  129. path=["Seattle", "Eastlake", "Northup", "Redmond"],
  130. distance=15,
  131. )
  132. verify_algorithm(
  133. filename="seattle_area.txt",
  134. start="Eastlake",
  135. end="Issaquah",
  136. path=["Eastlake", "Seattle", "SoDo", "Factoria", "Issaquah"],
  137. distance=21,
  138. )
  139. def verify_algorithm(filename, start, end, path, distance):
  140. """Helper function to run simple tests and print results to console.
  141. filename = graph definition file
  142. start/end = path to be calculated
  143. path = expected shorted path
  144. distance = expected distance of path
  145. """
  146. graph = Graph(filename)
  147. returned_path, returned_distance = graph.shortest_path(start, end)
  148. assert list(returned_path) == path
  149. assert returned_distance == distance
  150. print('\ngraph definition file: {0}'.format(filename))
  151. print(' start/end nodes: {0} -> {1}'.format(start, end))
  152. print(' shortest path: {0}'.format(path))
  153. print(' total distance: {0}'.format(distance))
  154. if __name__ == "__main__":
  155. main()'''