# mqttsas.py (11 KB)
  1. """
  2. *******************************************************************
  3. Copyright (c) 2013, 2020 IBM Corp.
  4. All rights reserved. This program and the accompanying materials
  5. are made available under the terms of the Eclipse Public License v2.0
  6. and Eclipse Distribution License v1.0 which accompany this distribution.
  7. The Eclipse Public License is available at
  8. https://www.eclipse.org/legal/epl-2.0/
  9. and the Eclipse Distribution License is available at
  10. http://www.eclipse.org/org/documents/edl-v10.php.
  11. Contributors:
  12. Ian Craggs - initial implementation and/or documentation
  13. Ian Craggs - add MQTTV5 support
  14. *******************************************************************
  15. """
  16. from __future__ import print_function
  17. import socket
  18. import sys
  19. import select
  20. import traceback
  21. import datetime
  22. import os
  23. import base64
  24. import hashlib
  25. import logging
  26. try:
  27. import socketserver
  28. import MQTTV311 # Trace MQTT traffic - Python 3 version
  29. import MQTTV5
  30. except:
  31. traceback.print_exc()
  32. import SocketServer as socketserver
  33. import MQTTV3112 as MQTTV311 # Trace MQTT traffic - Python 2 version
  34. import MQTTV5
  35. MQTT = MQTTV311
  36. logging = True
  37. myWindow = None
  38. class BufferedSockets:
  39. def __init__(self, socket):
  40. self.socket = socket
  41. self.buffer = bytearray()
  42. self.websockets = False
  43. def close(self):
  44. self.socket.shutdown(socket.SHUT_RDWR)
  45. self.socket.close()
  46. def rebuffer(self, data):
  47. self.buffer = data + self.buffer
  48. def wsrecv(self):
  49. try:
  50. header1 = ord(self.socket.recv(1))
  51. header2 = ord(self.socket.recv(1))
  52. except:
  53. return
  54. opcode = (header1 & 0x0f)
  55. maskbit = (header2 & 0x80) == 0x80
  56. length = (header2 & 0x7f) # works for 0 to 125 inclusive
  57. if length == 126: # for 126 to 65535 inclusive
  58. lb1 = ord(self.socket.recv(1))
  59. lb2 = ord(self.socket.recv(1))
  60. length = lb1*256 + lb2
  61. elif length == 127:
  62. length = 0
  63. for i in range(0, 8):
  64. length += ord(self.socket.recv(1)) * 2**((7 - i)*8)
  65. assert maskbit == True
  66. if maskbit:
  67. mask = self.socket.recv(4)
  68. mpayload = bytearray()
  69. while len(mpayload) < length:
  70. mpayload += self.socket.recv(length - len(mpayload))
  71. buffer = bytearray()
  72. if maskbit:
  73. mi = 0
  74. for i in mpayload:
  75. buffer.append(i ^ mask[mi])
  76. mi = (mi+1) % 4
  77. else:
  78. buffer = mpayload
  79. self.buffer += buffer
  80. def recv(self, bufsize):
  81. if self.websockets:
  82. while len(self.buffer) < bufsize:
  83. self.wsrecv()
  84. out = self.buffer[:bufsize]
  85. self.buffer = self.buffer[bufsize:]
  86. else:
  87. if bufsize <= len(self.buffer):
  88. out = self.buffer[:bufsize]
  89. self.buffer = self.buffer[bufsize:]
  90. else:
  91. out = self.buffer + \
  92. self.socket.recv(bufsize - len(self.buffer))
  93. self.buffer = bytes()
  94. return out
  95. def __getattr__(self, name):
  96. return getattr(self.socket, name)
  97. def send(self, data):
  98. header = bytearray()
  99. if self.websockets:
  100. header.append(0x82) # opcode
  101. l = len(data)
  102. if l < 126:
  103. header.append(l)
  104. elif l < 65536:
  105. """ If 126, the following 2 bytes interpreted as a 16-bit unsigned integer are
  106. the payload length.
  107. """
  108. header += bytearray([126, l // 256, l % 256])
  109. elif l < 2**64:
  110. """ If 127, the following 8 bytes interpreted as a 64-bit unsigned integer (the
  111. most significant bit MUST be 0) are the payload length.
  112. """
  113. mybytes = [127]
  114. for i in range(0, 7):
  115. divisor = 2**((7 - i)*8)
  116. mybytes.append(l // divisor)
  117. l %= divisor
  118. mybytes.append(l) # units
  119. header += bytearray(mybytes)
  120. totaldata = header + data
  121. # Ensure the entire packet is sent by calling send again if necessary
  122. sent = self.socket.send(totaldata)
  123. while sent < len(totaldata):
  124. sent += self.socket.send(totaldata[sent:])
  125. return sent
  126. def timestamp():
  127. now = datetime.datetime.now()
  128. return now.strftime('%Y%m%d %H%M%S')+str(float("."+str(now.microsecond)))[1:]
# Client connections that MyHandler.handle has stopped reading from.  A client
# is appended when it publishes b"TERMINATE_SERVER" on "MQTTSAS topic".
suspended = []
  130. class MyHandler(socketserver.StreamRequestHandler):
  131. def getheaders(self, data):
  132. "return headers: keys are converted to upper case so that checks are case insensitive"
  133. headers = {}
  134. lines = data.splitlines()
  135. for curline in lines[1:]:
  136. if curline.find(":") != -1:
  137. key, value = curline.split(": ", 1)
  138. headers[key.upper()] = value # headers are case insensitive
  139. return headers
  140. def handshake(self, client):
  141. GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  142. data = client.recv(1024).decode('utf-8')
  143. headers = self.getheaders(data)
  144. digest = base64.b64encode(hashlib.sha1(
  145. (headers['SEC-WEBSOCKET-KEY'] + GUID).encode("utf-8")).digest())
  146. resp = b"HTTP/1.1 101 Switching Protocols\r\n" +\
  147. b"Upgrade: websocket\r\n" +\
  148. b"Connection: Upgrade\r\n" +\
  149. b"Sec-WebSocket-Protocol: mqtt\r\n" +\
  150. b"Sec-WebSocket-Accept: " + digest + b"\r\n\r\n"
  151. return client.send(resp)
  152. def handle(self):
  153. global MQTT
  154. if not hasattr(self, "ids"):
  155. self.ids = {}
  156. if not hasattr(self, "versions"):
  157. self.versions = {}
  158. inbuf = True
  159. first = True
  160. i = o = e = None
  161. try:
  162. clients = BufferedSockets(self.request)
  163. sock_no = clients.fileno()
  164. brokers = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  165. brokers.connect((brokerhost, brokerport))
  166. terminated = False
  167. while inbuf != None and not terminated:
  168. (i, o, e) = select.select([clients, brokers], [], [])
  169. for s in i:
  170. if s in suspended:
  171. print("suspended")
  172. if s == clients and s not in suspended:
  173. if first:
  174. char = clients.recv(1)
  175. clients.rebuffer(char)
  176. if char == b"G": # should be websocket connection
  177. self.handshake(clients)
  178. clients.websockets = True
  179. print("Switching to websockets for socket %d" % sock_no)
  180. inbuf = MQTT.getPacket(clients) # get one packet
  181. if inbuf == None:
  182. break
  183. try:
  184. # if connect, this could be MQTTV3 or MQTTV5
  185. if inbuf[0] >> 4 == 1: # connect packet
  186. protocol_string = b'MQTT'
  187. pos = inbuf.find(protocol_string)
  188. if pos != -1:
  189. version = inbuf[pos +
  190. len(protocol_string)]
  191. if version == 5:
  192. MQTT = MQTTV5
  193. else:
  194. MQTT = MQTTV311
  195. packet = MQTT.unpackPacket(inbuf)
  196. if hasattr(packet.fh, "MessageType"):
  197. packet_type = packet.fh.MessageType
  198. publish_type = MQTT.PUBLISH
  199. connect_type = MQTT.CONNECT
  200. else:
  201. packet_type = packet.fh.PacketType
  202. publish_type = MQTT.PacketTypes.PUBLISH
  203. connect_type = MQTT.PacketTypes.CONNECT
  204. if packet_type == publish_type and \
  205. packet.topicName == "MQTTSAS topic" and \
  206. packet.data == b"TERMINATE":
  207. print("Terminating client", self.ids[id(clients)])
  208. brokers.close()
  209. clients.close()
  210. terminated = True
  211. break
  212. elif packet_type == publish_type and \
  213. packet.topicName == "MQTTSAS topic" and \
  214. packet.data == b"TERMINATE_SERVER":
  215. print("Suspending client ", self.ids[id(clients)])
  216. suspended.append(clients)
  217. elif packet_type == connect_type:
  218. self.ids[id(clients)
  219. ] = packet.ClientIdentifier
  220. self.versions[id(clients)] = 3
  221. print(timestamp(), "C to S",
  222. self.ids[id(clients)], str(packet))
  223. #print([hex(b) for b in inbuf])
  224. # print(inbuf)
  225. except:
  226. traceback.print_exc()
  227. brokers.send(inbuf) # pass it on
  228. elif s == brokers:
  229. inbuf = MQTT.getPacket(brokers) # get one packet
  230. if inbuf == None:
  231. break
  232. try:
  233. print(timestamp(), "S to C", self.ids[id(clients)], str(MQTT.unpackPacket(inbuf)))
  234. except:
  235. traceback.print_exc()
  236. clients.send(inbuf)
  237. print(timestamp()+" client " + self.ids[id(clients)]+" connection closing")
  238. first = False
  239. except:
  240. print(repr((i, o, e)), repr(inbuf))
  241. traceback.print_exc()
  242. if id(clients) in self.ids.keys():
  243. del self.ids[id(clients)]
  244. elif id(clients) in self.versions.keys():
  245. del self.versions[id(clients)]
  246. class ThreadingTCPServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
  247. pass
  248. def run():
  249. global brokerhost, brokerport
  250. myhost = '127.0.0.1'
  251. if len(sys.argv) > 1:
  252. brokerhost = sys.argv[1]
  253. else:
  254. brokerhost = '127.0.0.1'
  255. if len(sys.argv) > 2:
  256. brokerport = int(sys.argv[2])
  257. else:
  258. brokerport = 1883
  259. if len(sys.argv) > 3:
  260. myport = int(sys.argv[3])
  261. else:
  262. if brokerhost == myhost:
  263. myport = brokerport + 1
  264. else:
  265. myport = 1883
  266. print("Listening on port", str(myport)+", broker on port", brokerport)
  267. s = ThreadingTCPServer(("127.0.0.1", myport), MyHandler)
  268. s.serve_forever()
# Start the proxy only when this file is executed as a script, not on import.
if __name__ == "__main__":
    run()