dijkstra_algorithm.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
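"""Python 3 implementation of Dijkstra's algorithm for finding the shortest
path between nodes in a graph. Written as a learning exercise, so lots of
comments and only minimal error handling.

Nodes are named 2-D points; each directed edge is weighted by the Euclidean
distance between its endpoints (see Graph.AddEdge).
"""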
  5. from collections import deque
  6. import math
  7. INFINITY = float("inf")
  8. class Graph:
  9. def __init__(self):
  10. """Reads graph definition and stores it. Each line of the graph
  11. definition file defines an edge by specifying the start node,
  12. end node, and distance, delimited by spaces.
  13. Stores the graph definition in two properties which are used by
  14. Dijkstra's algorithm in the shortest_path method:
  15. self.nodes = set of all unique nodes in the graph
  16. self.adjacency_list = dict that maps each node to an unordered set of
  17. (neighbor, distance) tuples.
  18. """
  19. # Read the graph definition file and store in graph_edges as a list of
  20. # lists of [from_node, to_node, distance]. This data structure is not
  21. # used by Dijkstra's algorithm, it's just an intermediate step in the
  22. # create of self.nodes and self.adjacency_list.
  23. self.points = {} # dict,{id:[x, y]}
  24. self.graph_edges = [] # list, [(node_id1, node_id2, distance)]
  25. self.nodes = set() # dict,{id:[node_id1, node_id2]}
  26. def AddVertex(self, name, point):
  27. self.points[name] = point
  28. def ResetVertexValue(self, name, point):
  29. """
  30. point : [x, y]
  31. """
  32. self.points[name] = point
  33. def AddEdge(self, name1, name2):
  34. if self.points.get(name1) == None or self.points.get(name2) == None:
  35. print("node :%s or %s not exis" % (name1, name2))
  36. return False
  37. pt1 = self.points[name1]
  38. pt2 = self.points[name2]
  39. distance = math.sqrt(math.pow(pt1[0] - pt2[0], 2) + math.pow(pt1[1] - pt2[1], 2)) + 0.000001
  40. self.graph_edges.append((name1, name2, distance))
  41. self.nodes.update([name1, name2])
  42. self.adjacency_list = {node: set() for node in self.nodes}
  43. for edge in self.graph_edges:
  44. self.adjacency_list[edge[0]].add((edge[1], edge[2]))
  45. return True
  46. def __getitem__(self, item):
  47. return self.points[item]
  48. def shortest_path(self, start_node, end_node):
  49. """Uses Dijkstra's algorithm to determine the shortest path from
  50. start_node to end_node. Returns (path, distance).
  51. """
  52. unvisited_nodes = self.nodes.copy() # All nodes are initially unvisited.
  53. # Create a dictionary of each node's distance from start_node. We will
  54. # update each node's distance whenever we find a shorter path.
  55. distance_from_start = {
  56. node: (0 if node == start_node else INFINITY) for node in self.nodes
  57. }
  58. # Initialize previous_node, the dictionary that maps each node to the
  59. # node it was visited from when the the shortest path to it was found.
  60. previous_node = {node: None for node in self.nodes}
  61. while unvisited_nodes:
  62. # Set current_node to the unvisited node with shortest distance
  63. # calculated so far.
  64. current_node = min(
  65. unvisited_nodes, key=lambda node: distance_from_start[node]
  66. )
  67. unvisited_nodes.remove(current_node)
  68. # If current_node's distance is INFINITY, the remaining unvisited
  69. # nodes are not connected to start_node, so we're done.
  70. if distance_from_start[current_node] == INFINITY:
  71. break
  72. # For each neighbor of current_node, check whether the total distance
  73. # to the neighbor via current_node is shorter than the distance we
  74. # currently have for that node. If it is, update the neighbor's values
  75. # for distance_from_start and previous_node.
  76. for neighbor, distance in self.adjacency_list[current_node]:
  77. new_path = distance_from_start[current_node] + distance
  78. if new_path < distance_from_start[neighbor]:
  79. distance_from_start[neighbor] = new_path
  80. previous_node[neighbor] = current_node
  81. if current_node == end_node:
  82. break # we've visited the destination node, so we're done
  83. # To build the path to be returned, we iterate through the nodes from
  84. # end_node back to start_node. Note the use of a deque, which can
  85. # appendleft with O(1) performance.
  86. path = deque()
  87. current_node = end_node
  88. while previous_node[current_node] is not None:
  89. path.appendleft(current_node)
  90. current_node = previous_node[current_node]
  91. path.appendleft(start_node)
  92. return path, distance_from_start[end_node]
  # NOTE(review): everything below is wrapped in a ''' string literal, so none
  # of it runs. It is also stale: verify_algorithm() calls Graph(filename),
  # but Graph.__init__ in this file takes no arguments, and edges are now
  # built with AddVertex()/AddEdge() rather than read from .txt definition
  # files. Update it to the current API, or remove it.
  93. '''
  94. def main():
  95. """Runs a few simple tests to verify the implementation.
  96. """
  97. verify_algorithm(
  98. filename="simple_graph.txt",
  99. start="A",
  100. end="G",
  101. path=["A", "D", "E", "G"],
  102. distance=11,
  103. )
  104. verify_algorithm(
  105. filename="seattle_area.txt",
  106. start="Renton",
  107. end="Redmond",
  108. path=["Renton", "Factoria", "Bellevue", "Northup", "Redmond"],
  109. distance=16,
  110. )
  111. verify_algorithm(
  112. filename="seattle_area.txt",
  113. start="Seattle",
  114. end="Redmond",
  115. path=["Seattle", "Eastlake", "Northup", "Redmond"],
  116. distance=15,
  117. )
  118. verify_algorithm(
  119. filename="seattle_area.txt",
  120. start="Eastlake",
  121. end="Issaquah",
  122. path=["Eastlake", "Seattle", "SoDo", "Factoria", "Issaquah"],
  123. distance=21,
  124. )
  125. def verify_algorithm(filename, start, end, path, distance):
  126. """Helper function to run simple tests and print results to console.
  127. filename = graph definition file
  128. start/end = path to be calculated
  129. path = expected shorted path
  130. distance = expected distance of path
  131. """
  132. graph = Graph(filename)
  133. returned_path, returned_distance = graph.shortest_path(start, end)
  134. assert list(returned_path) == path
  135. assert returned_distance == distance
  136. print('\ngraph definition file: {0}'.format(filename))
  137. print(' start/end nodes: {0} -> {1}'.format(start, end))
  138. print(' shortest path: {0}'.format(path))
  139. print(' total distance: {0}'.format(distance))
  140. if __name__ == "__main__":
  141. main()'''