# dijkstra_algorithm.py
"""Python 3 implementation of Dijkstra's algorithm for finding the shortest
path between nodes in a graph. Written as a learning exercise, so lots of
comments and only minimal error handling.
"""
  5. from collections import deque
  6. import math
  7. INFINITY = float("inf")
  8. class Graph:
  9. def __init__(self):
  10. """Reads graph definition and stores it. Each line of the graph
  11. definition file defines an edge by specifying the start node,
  12. end node, and distance, delimited by spaces.
  13. Stores the graph definition in two properties which are used by
  14. Dijkstra's algorithm in the shortest_path method:
  15. self.nodes = set of all unique nodes in the graph
  16. self.adjacency_list = dict that maps each node to an unordered set of
  17. (neighbor, distance) tuples.
  18. """
  19. # Read the graph definition file and store in graph_edges as a list of
  20. # lists of [from_node, to_node, distance]. This data structure is not
  21. # used by Dijkstra's algorithm, it's just an intermediate step in the
  22. # create of self.nodes and self.adjacency_list.
  23. self.points={}
  24. self.graph_edges = []
  25. self.nodes = set()
  26. def AddVertex(self,name,point):
  27. self.points[name]=point
  28. def AddEdge(self,name1,name2):
  29. if self.points.get(name1)==None or self.points.get(name2)==None:
  30. print("node :%s or %s not exis"%(name1,name2))
  31. return False
  32. pt1=self.points[name1]
  33. pt2=self.points[name2]
  34. distance=math.sqrt(math.pow(pt1[0]-pt2[0],2)+math.pow(pt1[1]-pt2[1],2))+0.000001
  35. self.graph_edges.append((name1, name2, distance))
  36. self.nodes.update([name1, name2])
  37. self.adjacency_list = {node: set() for node in self.nodes}
  38. for edge in self.graph_edges:
  39. self.adjacency_list[edge[0]].add((edge[1], edge[2]))
  40. return True
  41. def __getitem__(self, item):
  42. return self.points[item]
  43. def shortest_path(self, start_node, end_node):
  44. """Uses Dijkstra's algorithm to determine the shortest path from
  45. start_node to end_node. Returns (path, distance).
  46. """
  47. unvisited_nodes = self.nodes.copy() # All nodes are initially unvisited.
  48. # Create a dictionary of each node's distance from start_node. We will
  49. # update each node's distance whenever we find a shorter path.
  50. distance_from_start = {
  51. node: (0 if node == start_node else INFINITY) for node in self.nodes
  52. }
  53. # Initialize previous_node, the dictionary that maps each node to the
  54. # node it was visited from when the the shortest path to it was found.
  55. previous_node = {node: None for node in self.nodes}
  56. while unvisited_nodes:
  57. # Set current_node to the unvisited node with shortest distance
  58. # calculated so far.
  59. current_node = min(
  60. unvisited_nodes, key=lambda node: distance_from_start[node]
  61. )
  62. unvisited_nodes.remove(current_node)
  63. # If current_node's distance is INFINITY, the remaining unvisited
  64. # nodes are not connected to start_node, so we're done.
  65. if distance_from_start[current_node] == INFINITY:
  66. break
  67. # For each neighbor of current_node, check whether the total distance
  68. # to the neighbor via current_node is shorter than the distance we
  69. # currently have for that node. If it is, update the neighbor's values
  70. # for distance_from_start and previous_node.
  71. for neighbor, distance in self.adjacency_list[current_node]:
  72. new_path = distance_from_start[current_node] + distance
  73. if new_path < distance_from_start[neighbor]:
  74. distance_from_start[neighbor] = new_path
  75. previous_node[neighbor] = current_node
  76. if current_node == end_node:
  77. break # we've visited the destination node, so we're done
  78. # To build the path to be returned, we iterate through the nodes from
  79. # end_node back to start_node. Note the use of a deque, which can
  80. # appendleft with O(1) performance.
  81. path = deque()
  82. current_node = end_node
  83. while previous_node[current_node] is not None:
  84. path.appendleft(current_node)
  85. current_node = previous_node[current_node]
  86. path.appendleft(start_node)
  87. return path, distance_from_start[end_node]
  88. '''
  89. def main():
  90. """Runs a few simple tests to verify the implementation.
  91. """
  92. verify_algorithm(
  93. filename="simple_graph.txt",
  94. start="A",
  95. end="G",
  96. path=["A", "D", "E", "G"],
  97. distance=11,
  98. )
  99. verify_algorithm(
  100. filename="seattle_area.txt",
  101. start="Renton",
  102. end="Redmond",
  103. path=["Renton", "Factoria", "Bellevue", "Northup", "Redmond"],
  104. distance=16,
  105. )
  106. verify_algorithm(
  107. filename="seattle_area.txt",
  108. start="Seattle",
  109. end="Redmond",
  110. path=["Seattle", "Eastlake", "Northup", "Redmond"],
  111. distance=15,
  112. )
  113. verify_algorithm(
  114. filename="seattle_area.txt",
  115. start="Eastlake",
  116. end="Issaquah",
  117. path=["Eastlake", "Seattle", "SoDo", "Factoria", "Issaquah"],
  118. distance=21,
  119. )
  120. def verify_algorithm(filename, start, end, path, distance):
  121. """Helper function to run simple tests and print results to console.
  122. filename = graph definition file
  123. start/end = path to be calculated
  124. path = expected shorted path
  125. distance = expected distance of path
  126. """
  127. graph = Graph(filename)
  128. returned_path, returned_distance = graph.shortest_path(start, end)
  129. assert list(returned_path) == path
  130. assert returned_distance == distance
  131. print('\ngraph definition file: {0}'.format(filename))
  132. print(' start/end nodes: {0} -> {1}'.format(start, end))
  133. print(' shortest path: {0}'.format(path))
  134. print(' total distance: {0}'.format(distance))
  135. if __name__ == "__main__":
  136. main()'''