diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a525891e..06a48391 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,6 +7,11 @@ # default_language_version: # python: python3 repos: + - repo: https://github.com/psf/black + rev: 23.12.1 + hooks: + - id: black + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.5.0 hooks: @@ -37,3 +42,10 @@ repos: rev: v0.15 hooks: - id: validate-pyproject + + - repo: https://github.com/codespell-project/codespell + rev: v2.2.6 + hooks: + - id: codespell + additional_dependencies: + - tomli diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 468cef87..b0ee782d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -28,7 +28,7 @@ Please read the [Eclipse Foundation policy on accepting contributions via Git](h cherry-picked to the release branch. The only changes that goes directly to the release branch (``1.4``, - ``1.5``, ...) are bug fixe that does not apply to ``master`` (e.g. because + ``1.5``, ...) are bug fixes that does not apply to ``master`` (e.g. because there are fixed on master by a refactoring, or any other huge change we do not want to cherry-pick to the release branch). 4. Create a new branch from the latest ```master``` branch diff --git a/ChangeLog.txt b/ChangeLog.txt index c86f40c4..4b2ebc3b 100644 --- a/ChangeLog.txt +++ b/ChangeLog.txt @@ -3,6 +3,16 @@ v2.0.0 - 2023-xx-xx - **BREAKING** Drop support for Python 2.7, Python 3.5 and Python 3.6 Minimum tested version is Python 3.7 +- **BREAKING** connnect_srv changed it signature to take an additional bind_port parameter. + This is a breaking change, but in previous version connect_srv was broken anyway. + Closes #493. +- Add types to Client class, which caused few change which should be compatible. + Known risk of breaking changes: + - Use enum for returned error code (like MQTT_ERR_SUCCESS). It use an IntEnum + which should be a drop-in replacement. Excepted if someone is doing "rc is 0" instead of "rc == 0". + - reason in on_connect callback when using MQTTv5 is now always a ReasonCode object. It used to possibly be + an integer with the value 132. + - MQTTMessage field "dup" and "retain" used to be integer with value 0 and 1. They are now boolean. - Add on_pre_connect() callback, which is called immediately before a connection attempt is made. diff --git a/examples/aws_iot.py b/examples/aws_iot.py index e65aab71..60686003 100755 --- a/examples/aws_iot.py +++ b/examples/aws_iot.py @@ -10,7 +10,7 @@ def get_amazon_auth_headers(access_key, secret_key, region, host, port, headers=None): - """ Get the amazon auth headers for working with the amazon websockets + """Get the amazon auth headers for working with the amazon websockets protocol Requires a lot of extra stuff: @@ -47,8 +47,8 @@ def getSignatureKey(key, dateStamp, regionName, serviceName): algorithm = "AWS4-HMAC-SHA256" t = datetime.datetime.utcnow() - amzdate = t.strftime('%Y%m%dT%H%M%SZ') - datestamp = t.strftime("%Y%m%d") # Date w/o time, used in credential scope + amzdate = t.strftime("%Y%m%dT%H%M%SZ") + datestamp = t.strftime("%Y%m%d") # Date w/o time, used in credential scope if headers is None: headers = { @@ -61,12 +61,16 @@ def getSignatureKey(key, dateStamp, regionName, serviceName): "Sec-Websocket-Protocol": "mqtt", } - headers.update({ - "X-Amz-Date": amzdate, - }) + headers.update( + { + "X-Amz-Date": amzdate, + } + ) # get into 'canonical' form - lowercase, sorted alphabetically - canonical_headers = "\n".join(sorted("{}:{}".format(i.lower(), j).strip() for i, j in headers.items())) + canonical_headers = "\n".join( + sorted("{}:{}".format(i.lower(), j).strip() for i, j in headers.items()) + ) # Headers to sign - alphabetical order signed_headers = ";".join(sorted(i.lower().strip() for i in headers.keys())) @@ -88,14 +92,24 @@ def getSignatureKey(key, dateStamp, regionName, serviceName): # now actually hash request and sign hashed_request = hashlib.sha256(canonical_request).hexdigest() - credential_scope = "{datestamp:s}/{region:s}/{service:s}/aws4_request".format(**locals()) - string_to_sign = "{algorithm:s}\n{amzdate:s}\n{credential_scope:s}\n{hashed_request:s}".format(**locals()) + credential_scope = "{datestamp:s}/{region:s}/{service:s}/aws4_request".format( + **locals() + ) + string_to_sign = ( + "{algorithm:s}\n{amzdate:s}\n{credential_scope:s}\n{hashed_request:s}".format( + **locals() + ) + ) signing_key = getSignatureKey(secret_key, datestamp, region, service) - signature = hmac.new(signing_key, (string_to_sign).encode('utf-8'), hashlib.sha256).hexdigest() + signature = hmac.new( + signing_key, (string_to_sign).encode("utf-8"), hashlib.sha256 + ).hexdigest() # create auth header - authorization_header = "{algorithm:s} Credential={access_key:s}/{credential_scope:s}, SignedHeaders={signed_headers:s}, Signature={signature:s}".format(**locals()) + authorization_header = "{algorithm:s} Credential={access_key:s}/{credential_scope:s}, SignedHeaders={signed_headers:s}, Signature={signature:s}".format( + **locals() + ) # get final header string headers["Authorization"] = authorization_header diff --git a/examples/client_mqtt_clear_retain.py b/examples/client_mqtt_clear_retain.py index 5e100c5d..48022c19 100755 --- a/examples/client_mqtt_clear_retain.py +++ b/examples/client_mqtt_clear_retain.py @@ -55,7 +55,8 @@ def on_log(mqttc, userdata, level, string): def print_usage(): print( - "mqtt_clear_retain.py [-d] [-h hostname] [-i clientid] [-k keepalive] [-p port] [-u username [-P password]] [-v] -t topic") + "mqtt_clear_retain.py [-d] [-h hostname] [-i clientid] [-k keepalive] [-p port] [-u username [-P password]] [-v] -t topic" + ) def main(argv): @@ -70,8 +71,20 @@ def main(argv): verbose = False try: - opts, args = getopt.getopt(argv, "dh:i:k:p:P:t:u:v", - ["debug", "id", "keepalive", "port", "password", "topic", "username", "verbose"]) + opts, args = getopt.getopt( + argv, + "dh:i:k:p:P:t:u:v", + [ + "debug", + "id", + "keepalive", + "port", + "password", + "topic", + "username", + "verbose", + ], + ) except getopt.GetoptError: print_usage() sys.exit(2) diff --git a/examples/client_pub_opts.py b/examples/client_pub_opts.py index d09ffb53..fd59880c 100755 --- a/examples/client_pub_opts.py +++ b/examples/client_pub_opts.py @@ -24,22 +24,53 @@ parser = argparse.ArgumentParser() -parser.add_argument('-H', '--host', required=False, default="mqtt.eclipseprojects.io") -parser.add_argument('-t', '--topic', required=False, default="paho/test/opts") -parser.add_argument('-q', '--qos', required=False, type=int,default=0) -parser.add_argument('-c', '--clientid', required=False, default=None) -parser.add_argument('-u', '--username', required=False, default=None) -parser.add_argument('-d', '--disable-clean-session', action='store_true', help="disable 'clean session' (sub + msgs not cleared when client disconnects)") -parser.add_argument('-p', '--password', required=False, default=None) -parser.add_argument('-P', '--port', required=False, type=int, default=None, help='Defaults to 8883 for TLS or 1883 for non-TLS') -parser.add_argument('-N', '--nummsgs', required=False, type=int, default=1, help='send this many messages before disconnecting') -parser.add_argument('-S', '--delay', required=False, type=float, default=1, help='number of seconds to sleep between msgs') -parser.add_argument('-k', '--keepalive', required=False, type=int, default=60) -parser.add_argument('-s', '--use-tls', action='store_true') -parser.add_argument('--insecure', action='store_true') -parser.add_argument('-F', '--cacerts', required=False, default=None) -parser.add_argument('--tls-version', required=False, default=None, help='TLS protocol version, can be one of tlsv1.2 tlsv1.1 or tlsv1\n') -parser.add_argument('-D', '--debug', action='store_true') +parser.add_argument("-H", "--host", required=False, default="mqtt.eclipseprojects.io") +parser.add_argument("-t", "--topic", required=False, default="paho/test/opts") +parser.add_argument("-q", "--qos", required=False, type=int, default=0) +parser.add_argument("-c", "--clientid", required=False, default=None) +parser.add_argument("-u", "--username", required=False, default=None) +parser.add_argument( + "-d", + "--disable-clean-session", + action="store_true", + help="disable 'clean session' (sub + msgs not cleared when client disconnects)", +) +parser.add_argument("-p", "--password", required=False, default=None) +parser.add_argument( + "-P", + "--port", + required=False, + type=int, + default=None, + help="Defaults to 8883 for TLS or 1883 for non-TLS", +) +parser.add_argument( + "-N", + "--nummsgs", + required=False, + type=int, + default=1, + help="send this many messages before disconnecting", +) +parser.add_argument( + "-S", + "--delay", + required=False, + type=float, + default=1, + help="number of seconds to sleep between msgs", +) +parser.add_argument("-k", "--keepalive", required=False, type=int, default=60) +parser.add_argument("-s", "--use-tls", action="store_true") +parser.add_argument("--insecure", action="store_true") +parser.add_argument("-F", "--cacerts", required=False, default=None) +parser.add_argument( + "--tls-version", + required=False, + default=None, + help="TLS protocol version, can be one of tlsv1.2 tlsv1.1 or tlsv1\n", +) +parser.add_argument("-D", "--debug", action="store_true") args, unknown = parser.parse_known_args() @@ -63,6 +94,7 @@ def on_subscribe(mqttc, obj, mid, granted_qos): def on_log(mqttc, obj, level, string): print(string) + usetls = args.use_tls if args.cacerts: @@ -75,27 +107,33 @@ def on_log(mqttc, obj, level, string): else: port = 1883 -mqttc = mqtt.Client(args.clientid,clean_session = not args.disable_clean_session) +mqttc = mqtt.Client(args.clientid, clean_session=not args.disable_clean_session) if usetls: if args.tls_version == "tlsv1.2": - tlsVersion = ssl.PROTOCOL_TLSv1_2 + tlsVersion = ssl.PROTOCOL_TLSv1_2 elif args.tls_version == "tlsv1.1": - tlsVersion = ssl.PROTOCOL_TLSv1_1 + tlsVersion = ssl.PROTOCOL_TLSv1_1 elif args.tls_version == "tlsv1": - tlsVersion = ssl.PROTOCOL_TLSv1 + tlsVersion = ssl.PROTOCOL_TLSv1 elif args.tls_version is None: - tlsVersion = None + tlsVersion = None else: - print ("Unknown TLS version - ignoring") - tlsVersion = None + print("Unknown TLS version - ignoring") + tlsVersion = None if not args.insecure: cert_required = ssl.CERT_REQUIRED else: cert_required = ssl.CERT_NONE - mqttc.tls_set(ca_certs=args.cacerts, certfile=None, keyfile=None, cert_reqs=cert_required, tls_version=tlsVersion) + mqttc.tls_set( + ca_certs=args.cacerts, + certfile=None, + keyfile=None, + cert_reqs=cert_required, + tls_version=tlsVersion, + ) if args.insecure: mqttc.tls_insecure_set(True) @@ -111,18 +149,17 @@ def on_log(mqttc, obj, level, string): if args.debug: mqttc.on_log = on_log -print("Connecting to "+args.host+" port: "+str(port)) +print("Connecting to " + args.host + " port: " + str(port)) mqttc.connect(args.host, port, args.keepalive) mqttc.loop_start() -for x in range (0, args.nummsgs): - msg_txt = '{"msgnum": "'+str(x)+'"}' - print("Publishing: "+msg_txt) +for x in range(0, args.nummsgs): + msg_txt = '{"msgnum": "' + str(x) + '"}' + print("Publishing: " + msg_txt) infot = mqttc.publish(args.topic, msg_txt, qos=args.qos) infot.wait_for_publish() time.sleep(args.delay) mqttc.disconnect() - diff --git a/examples/client_rpc_math.py b/examples/client_rpc_math.py index 918f5b02..6bc53e0f 100755 --- a/examples/client_rpc_math.py +++ b/examples/client_rpc_math.py @@ -33,15 +33,16 @@ # This correlates the outbound request with the returned reply corr_id = b"1" -# This is sent in the message callback when we get the respone +# This is sent in the message callback when we get the response reply = None + # The MQTTv5 callback takes the additional 'props' parameter. def on_connect(mqttc, userdata, flags, rc, props): global client_id, reply_to - print("Connected: '"+str(flags)+"', '"+str(rc)+"', '"+str(props)) - if hasattr(props, 'AssignedClientIdentifier'): + print("Connected: '" + str(flags) + "', '" + str(rc) + "', '" + str(props)) + if hasattr(props, "AssignedClientIdentifier"): client_id = props.AssignedClientIdentifier reply_to = "replies/math/" + client_id mqttc.subscribe(reply_to) @@ -51,9 +52,9 @@ def on_connect(mqttc, userdata, flags, rc, props): def on_message(mqttc, userdata, msg): global reply - print(msg.topic+" "+str(msg.payload)+" "+str(msg.properties)) + print(msg.topic + " " + str(msg.payload) + " " + str(msg.properties)) props = msg.properties - if not hasattr(props, 'CorrelationData'): + if not hasattr(props, "CorrelationData"): print("No correlation ID") # Match the response to the request correlation ID. @@ -69,7 +70,7 @@ def on_message(mqttc, userdata, msg): mqttc.on_message = on_message mqttc.on_connect = on_connect -mqttc.connect(host='localhost', clean_start=True) +mqttc.connect(host="localhost", clean_start=True) mqttc.loop_start() # Wait for connection to set `client_id`, etc. @@ -82,9 +83,9 @@ def on_message(mqttc, userdata, msg): props.ResponseTopic = reply_to # Uncomment to see what got set -#print("Client ID: "+client_id) -#print("Reply To: "+reply_to) -#print(props) +# print("Client ID: "+client_id) +# print("Reply To: "+reply_to) +# print(props) # The requested operation, 'add' or 'mult' func = sys.argv[1] @@ -106,7 +107,6 @@ def on_message(mqttc, userdata, msg): # Extract the response and print it. rsp = json.loads(reply) -print("Response: "+str(rsp)) +print("Response: " + str(rsp)) mqttc.loop_stop() - diff --git a/examples/client_session_present.py b/examples/client_session_present.py index 806d1aff..4b4a6c47 100755 --- a/examples/client_session_present.py +++ b/examples/client_session_present.py @@ -29,7 +29,7 @@ def on_connect(mqttc, obj, flags, rc): print("Second connection:") elif obj == 2: print("Third connection (with clean session=True):") - print(" Session present: " + str(flags['session present'])) + print(" Session present: " + str(flags["session present"])) print(" Connection result: " + str(rc)) mqttc.disconnect() diff --git a/examples/client_sub-class.py b/examples/client_sub-class.py index d5776253..5eb220f8 100755 --- a/examples/client_sub-class.py +++ b/examples/client_sub-class.py @@ -21,21 +21,20 @@ class MyMQTTClass(mqtt.Client): - def on_connect(self, mqttc, obj, flags, rc): - print("rc: "+str(rc)) + print("rc: " + str(rc)) def on_connect_fail(self, mqttc, obj): print("Connect failed") def on_message(self, mqttc, obj, msg): - print(msg.topic+" "+str(msg.qos)+" "+str(msg.payload)) + print(msg.topic + " " + str(msg.qos) + " " + str(msg.payload)) def on_publish(self, mqttc, obj, mid): - print("mid: "+str(mid)) + print("mid: " + str(mid)) def on_subscribe(self, mqttc, obj, mid, granted_qos): - print("Subscribed: "+str(mid)+" "+str(granted_qos)) + print("Subscribed: " + str(mid) + " " + str(granted_qos)) def on_log(self, mqttc, obj, level, string): print(string) @@ -57,4 +56,4 @@ def run(self): mqttc = MyMQTTClass() rc = mqttc.run() -print("rc: "+str(rc)) +print("rc: " + str(rc)) diff --git a/examples/client_sub-srv.py b/examples/client_sub-srv.py index f3ebd386..af728800 100755 --- a/examples/client_sub-srv.py +++ b/examples/client_sub-srv.py @@ -25,18 +25,23 @@ def on_connect(mqttc, obj, flags, rc): print("Connected to %s:%s" % (mqttc._host, mqttc._port)) + def on_message(mqttc, obj, msg): - print(msg.topic+" "+str(msg.qos)+" "+str(msg.payload)) + print(msg.topic + " " + str(msg.qos) + " " + str(msg.payload)) + def on_publish(mqttc, obj, mid): - print("mid: "+str(mid)) + print("mid: " + str(mid)) + def on_subscribe(mqttc, obj, mid, granted_qos): - print("Subscribed: "+str(mid)+" "+str(granted_qos)) + print("Subscribed: " + str(mid) + " " + str(granted_qos)) + def on_log(mqttc, obj, level, string): print(string) + # If you want to use a specific client id, use # mqttc = mqtt.Client("client-id") # but note that the client id must be unique on the broker. Leaving the client @@ -47,7 +52,7 @@ def on_log(mqttc, obj, level, string): mqttc.on_publish = on_publish mqttc.on_subscribe = on_subscribe # Uncomment to enable debug messages -#mqttc.on_log = on_log +# mqttc.on_log = on_log mqttc.connect_srv("mosquitto.org", 60) mqttc.subscribe("$SYS/broker/version", 0) @@ -56,4 +61,4 @@ def on_log(mqttc, obj, level, string): while rc == 0: rc = mqttc.loop() -print("rc: "+str(rc)) +print("rc: " + str(rc)) diff --git a/examples/client_sub-ws.py b/examples/client_sub-ws.py index 085d4850..446912f6 100755 --- a/examples/client_sub-ws.py +++ b/examples/client_sub-ws.py @@ -23,20 +23,25 @@ def on_connect(mqttc, obj, flags, rc): - print("rc: "+str(rc)) + print("rc: " + str(rc)) + def on_message(mqttc, obj, msg): - print(msg.topic+" "+str(msg.qos)+" "+str(msg.payload)) + print(msg.topic + " " + str(msg.qos) + " " + str(msg.payload)) + def on_publish(mqttc, obj, mid): - print("mid: "+str(mid)) + print("mid: " + str(mid)) + def on_subscribe(mqttc, obj, mid, granted_qos): - print("Subscribed: "+str(mid)+" "+str(granted_qos)) + print("Subscribed: " + str(mid) + " " + str(granted_qos)) + def on_log(mqttc, obj, level, string): print(string) + # If you want to use a specific client id, use # mqttc = mqtt.Client("client-id") # but note that the client id must be unique on the broker. Leaving the client diff --git a/examples/client_sub_opts.py b/examples/client_sub_opts.py index b69ceb51..bd5c6604 100755 --- a/examples/client_sub_opts.py +++ b/examples/client_sub_opts.py @@ -23,20 +23,37 @@ parser = argparse.ArgumentParser() -parser.add_argument('-H', '--host', required=False, default="mqtt.eclipseprojects.io") -parser.add_argument('-t', '--topic', required=False, default="$SYS/#") -parser.add_argument('-q', '--qos', required=False, type=int, default=0) -parser.add_argument('-c', '--clientid', required=False, default=None) -parser.add_argument('-u', '--username', required=False, default=None) -parser.add_argument('-d', '--disable-clean-session', action='store_true', help="disable 'clean session' (sub + msgs not cleared when client disconnects)") -parser.add_argument('-p', '--password', required=False, default=None) -parser.add_argument('-P', '--port', required=False, type=int, default=None, help='Defaults to 8883 for TLS or 1883 for non-TLS') -parser.add_argument('-k', '--keepalive', required=False, type=int, default=60) -parser.add_argument('-s', '--use-tls', action='store_true') -parser.add_argument('--insecure', action='store_true') -parser.add_argument('-F', '--cacerts', required=False, default=None) -parser.add_argument('--tls-version', required=False, default=None, help='TLS protocol version, can be one of tlsv1.2 tlsv1.1 or tlsv1\n') -parser.add_argument('-D', '--debug', action='store_true') +parser.add_argument("-H", "--host", required=False, default="mqtt.eclipseprojects.io") +parser.add_argument("-t", "--topic", required=False, default="$SYS/#") +parser.add_argument("-q", "--qos", required=False, type=int, default=0) +parser.add_argument("-c", "--clientid", required=False, default=None) +parser.add_argument("-u", "--username", required=False, default=None) +parser.add_argument( + "-d", + "--disable-clean-session", + action="store_true", + help="disable 'clean session' (sub + msgs not cleared when client disconnects)", +) +parser.add_argument("-p", "--password", required=False, default=None) +parser.add_argument( + "-P", + "--port", + required=False, + type=int, + default=None, + help="Defaults to 8883 for TLS or 1883 for non-TLS", +) +parser.add_argument("-k", "--keepalive", required=False, type=int, default=60) +parser.add_argument("-s", "--use-tls", action="store_true") +parser.add_argument("--insecure", action="store_true") +parser.add_argument("-F", "--cacerts", required=False, default=None) +parser.add_argument( + "--tls-version", + required=False, + default=None, + help="TLS protocol version, can be one of tlsv1.2 tlsv1.1 or tlsv1\n", +) +parser.add_argument("-D", "--debug", action="store_true") args, unknown = parser.parse_known_args() @@ -60,6 +77,7 @@ def on_subscribe(mqttc, obj, mid, granted_qos): def on_log(mqttc, obj, level, string): print(string) + usetls = args.use_tls if args.cacerts: @@ -72,27 +90,33 @@ def on_log(mqttc, obj, level, string): else: port = 1883 -mqttc = mqtt.Client(args.clientid,clean_session = not args.disable_clean_session) +mqttc = mqtt.Client(args.clientid, clean_session=not args.disable_clean_session) if usetls: if args.tls_version == "tlsv1.2": - tlsVersion = ssl.PROTOCOL_TLSv1_2 + tlsVersion = ssl.PROTOCOL_TLSv1_2 elif args.tls_version == "tlsv1.1": - tlsVersion = ssl.PROTOCOL_TLSv1_1 + tlsVersion = ssl.PROTOCOL_TLSv1_1 elif args.tls_version == "tlsv1": - tlsVersion = ssl.PROTOCOL_TLSv1 + tlsVersion = ssl.PROTOCOL_TLSv1 elif args.tls_version is None: - tlsVersion = None + tlsVersion = None else: - print ("Unknown TLS version - ignoring") - tlsVersion = None + print("Unknown TLS version - ignoring") + tlsVersion = None if not args.insecure: cert_required = ssl.CERT_REQUIRED else: cert_required = ssl.CERT_NONE - mqttc.tls_set(ca_certs=args.cacerts, certfile=None, keyfile=None, cert_reqs=cert_required, tls_version=tlsVersion) + mqttc.tls_set( + ca_certs=args.cacerts, + certfile=None, + keyfile=None, + cert_reqs=cert_required, + tls_version=tlsVersion, + ) if args.insecure: mqttc.tls_insecure_set(True) @@ -108,7 +132,7 @@ def on_log(mqttc, obj, level, string): if args.debug: mqttc.on_log = on_log -print("Connecting to "+args.host+" port: "+str(port)) +print("Connecting to " + args.host + " port: " + str(port)) mqttc.connect(args.host, port, args.keepalive) mqttc.subscribe(args.topic, args.qos) diff --git a/examples/context.py b/examples/context.py index faef26aa..9e5bdbe9 100755 --- a/examples/context.py +++ b/examples/context.py @@ -14,11 +14,7 @@ cmd_subfolder = os.path.realpath( os.path.abspath( os.path.join( - os.path.split( - inspect.getfile(inspect.currentframe()) - )[0], - "..", - "src" + os.path.split(inspect.getfile(inspect.currentframe()))[0], "..", "src" ) ) ) diff --git a/examples/loop_asyncio.py b/examples/loop_asyncio.py index f4f22c79..7d58d44c 100755 --- a/examples/loop_asyncio.py +++ b/examples/loop_asyncio.py @@ -8,7 +8,7 @@ import paho.mqtt.client as mqtt -client_id = 'paho-mqtt-python/issue72/' + str(uuid.uuid4()) +client_id = "paho-mqtt-python/issue72/" + str(uuid.uuid4()) topic = client_id print("Using client_id / topic: " + client_id) @@ -88,14 +88,14 @@ async def main(self): aioh = AsyncioHelper(self.loop, self.client) - self.client.connect('mqtt.eclipseprojects.io', 1883, 60) + self.client.connect("mqtt.eclipseprojects.io", 1883, 60) self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) for c in range(3): await asyncio.sleep(5) print("Publishing") self.got_message = self.loop.create_future() - self.client.publish(topic, b'Hello' * 40000, qos=1) + self.client.publish(topic, b"Hello" * 40000, qos=1) msg = await self.got_message print("Got response with {} bytes".format(len(msg))) self.got_message = None diff --git a/examples/loop_select.py b/examples/loop_select.py index 328b1030..dfe5acbb 100755 --- a/examples/loop_select.py +++ b/examples/loop_select.py @@ -7,7 +7,7 @@ import paho.mqtt.client as mqtt -client_id = 'paho-mqtt-python/issue72/' + str(uuid.uuid4()) +client_id = "paho-mqtt-python/issue72/" + str(uuid.uuid4()) topic = client_id print("Using client_id / topic: " + client_id) @@ -37,13 +37,11 @@ def do_select(self): if not sock: raise Exception("Socket is gone") - print("Selecting for reading" + (" and writing" if self.client.want_write() else "")) - r, w, e = select( - [sock], - [sock] if self.client.want_write() else [], - [], - 1 + print( + "Selecting for reading" + + (" and writing" if self.client.want_write() else "") ) + r, w, e = select([sock], [sock] if self.client.want_write() else [], [], 1) if sock in r: print("Socket is readable, calling loop_read") @@ -65,7 +63,7 @@ def main(self): self.client.on_message = self.on_message self.client.on_disconnect = self.on_disconnect - self.client.connect('mqtt.eclipseprojects.io', 1883, 60) + self.client.connect("mqtt.eclipseprojects.io", 1883, 60) print("Socket opened") self.client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) @@ -75,7 +73,7 @@ def main(self): if self.state in {0, 2, 4}: if time() - self.t >= 5: print("Publishing") - self.client.publish(topic, b'Hello' * 40000) + self.client.publish(topic, b"Hello" * 40000) self.state += 1 if self.state == 6: diff --git a/examples/loop_trio.py b/examples/loop_trio.py index 8efe6cbf..95c63660 100755 --- a/examples/loop_trio.py +++ b/examples/loop_trio.py @@ -7,7 +7,7 @@ import paho.mqtt.client as mqtt -client_id = 'paho-mqtt-python/issue72/' + str(uuid.uuid4()) +client_id = "paho-mqtt-python/issue72/" + str(uuid.uuid4()) topic = client_id print("Using client_id / topic: " + client_id) @@ -45,7 +45,7 @@ def on_socket_open(self, client, userdata, sock): self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) def on_socket_register_write(self, client, userdata, sock): - print('large write request') + print("large write request") self._event_large_write.set() def on_socket_unregister_write(self, client, userdata, sock): @@ -62,13 +62,13 @@ def on_message(self, client, userdata, msg): print("Got response with {} bytes".format(len(msg.payload))) def on_disconnect(self, client, userdata, rc): - print('Disconnect result {}'.format(rc)) + print("Disconnect result {}".format(rc)) async def test_write(self, cancel_scope: trio.CancelScope): for c in range(3): await trio.sleep(5) print("Publishing") - self.client.publish(topic, b'Hello' * 40000, qos=1) + self.client.publish(topic, b"Hello" * 40000, qos=1) cancel_scope.cancel() async def main(self): @@ -79,7 +79,7 @@ async def main(self): trio_helper = TrioAsyncHelper(self.client) - self.client.connect('mqtt.eclipseprojects.io', 1883, 60) + self.client.connect("mqtt.eclipseprojects.io", 1883, 60) async with trio.open_nursery() as nursery: nursery.start_soon(trio_helper.read_loop) diff --git a/examples/publish_multiple.py b/examples/publish_multiple.py index 7c19a130..a879ff9e 100755 --- a/examples/publish_multiple.py +++ b/examples/publish_multiple.py @@ -19,5 +19,8 @@ import paho.mqtt.publish as publish -msgs = [{'topic': "paho/test/multiple", 'payload': "multiple 1"}, ("paho/test/multiple", "multiple 2", 0, False)] +msgs = [ + {"topic": "paho/test/multiple", "payload": "multiple 1"}, + ("paho/test/multiple", "multiple 2", 0, False), +] publish.multiple(msgs, hostname="mqtt.eclipseprojects.io") diff --git a/examples/publish_utf8-27.py b/examples/publish_utf8-27.py index 74dee2b7..fa9f6cc1 100755 --- a/examples/publish_utf8-27.py +++ b/examples/publish_utf8-27.py @@ -19,6 +19,6 @@ import paho.mqtt.publish as publish -topic = u"paho/test/single/ô" -payload = u"bôô" +topic = "paho/test/single/ô" +payload = "bôô" publish.single(topic, payload, hostname="mqtt.eclipseprojects.io") diff --git a/examples/publish_utf8-3.py b/examples/publish_utf8-3.py index 76c11b4d..279b7580 100755 --- a/examples/publish_utf8-3.py +++ b/examples/publish_utf8-3.py @@ -19,6 +19,6 @@ import paho.mqtt.publish as publish -topic = u"paho/test/single/ô" -payload = u'German umlauts like "ä" ü"ö" are not supported' +topic = "paho/test/single/ô" +payload = 'German umlauts like "ä" ü"ö" are not supported' publish.single(topic, payload, hostname="test.mosquitto.org") diff --git a/examples/server_rpc_math.py b/examples/server_rpc_math.py index e7d3ee4e..b5f0e124 100755 --- a/examples/server_rpc_math.py +++ b/examples/server_rpc_math.py @@ -26,25 +26,29 @@ # The math functions exported + def add(nums): sum = 0 for x in nums: sum += x return sum + def mult(nums): prod = 1 for x in nums: prod *= x return prod + # Remember that the MQTTv5 callback takes the additional 'props' parameter. def on_connect(mqttc, userdata, flags, rc, props): - print("Connected: '"+str(flags)+"', '"+str(rc)+"', '"+str(props)) + print("Connected: '" + str(flags) + "', '" + str(rc) + "', '" + str(props)) if not flags["session present"]: print("Subscribing to math requests") mqttc.subscribe("requests/math/#") + # Each incoming message should be an RPC request on the # 'requests/math/#' topic. def on_message(mqttc, userdata, msg): @@ -52,7 +56,7 @@ def on_message(mqttc, userdata, msg): # Get the response properties, abort if they're not given props = msg.properties - if not hasattr(props, 'ResponseTopic') or not hasattr(props, 'CorrelationData'): + if not hasattr(props, "ResponseTopic") or not hasattr(props, "CorrelationData"): print("No reply requested") return @@ -71,13 +75,14 @@ def on_message(mqttc, userdata, msg): # Now we have the result, res, so send it back on the 'reply_to' # topic using the same correlation ID as the request. - print("Sending response "+str(res)+" on '"+reply_to+"': "+str(corr_id)) + print("Sending response " + str(res) + " on '" + reply_to + "': " + str(corr_id)) props = mqtt.Properties(PacketTypes.PUBLISH) props.CorrelationData = corr_id payload = json.dumps(res) mqttc.publish(reply_to, payload, qos=1, properties=props) + def on_log(mqttc, obj, level, string): print(string) @@ -90,8 +95,8 @@ def on_log(mqttc, obj, level, string): mqttc.on_connect = on_connect # Uncomment to enable debug messages -#mqttc.on_log = on_log +# mqttc.on_log = on_log -#mqttc.connect("mqtt.eclipseprojects.io", 1883, 60) +# mqttc.connect("mqtt.eclipseprojects.io", 1883, 60) mqttc.connect(host="localhost", clean_start=False) mqttc.loop_forever() diff --git a/examples/subscribe_callback.py b/examples/subscribe_callback.py index f12bccc2..3591e985 100755 --- a/examples/subscribe_callback.py +++ b/examples/subscribe_callback.py @@ -23,4 +23,5 @@ def print_msg(client, userdata, message): print("%s : %s" % (message.topic, message.payload)) + subscribe.callback(print_msg, "#", hostname="mqtt.eclipseprojects.io") diff --git a/examples/subscribe_simple.py b/examples/subscribe_simple.py index 87adb9ff..3fd15064 100755 --- a/examples/subscribe_simple.py +++ b/examples/subscribe_simple.py @@ -19,9 +19,11 @@ import paho.mqtt.subscribe as subscribe -topics = ['#'] +topics = ["#"] -m = subscribe.simple(topics, hostname="mqtt.eclipseprojects.io", retained=False, msg_count=2) +m = subscribe.simple( + topics, hostname="mqtt.eclipseprojects.io", retained=False, msg_count=2 +) for a in m: print(a.topic) print(a.payload) diff --git a/pyproject.toml b/pyproject.toml index cf8ba363..70f457bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ proxy = [ [project.urls] Homepage = "http://eclipse.org/paho" +[tool.codespell] +skip = "*.key" + [tool.hatch.version] path = "src/paho/mqtt/__init__.py" @@ -67,6 +70,15 @@ include = [ "src/paho", ] +[tool.mypy] + +[[tool.mypy.overrides]] +module = "paho.mqtt.client" +# check_untyped_defs = true +# disallow_untyped_calls = true +# disallow_incomplete_defs = true +disallow_untyped_defs = true + [tool.pytest.ini_options] addopts = ["-r", "xs"] testpaths = "tests src" diff --git a/src/paho/mqtt/client.py b/src/paho/mqtt/client.py index f6336beb..25849ac7 100644 --- a/src/paho/mqtt/client.py +++ b/src/paho/mqtt/client.py @@ -19,6 +19,7 @@ import base64 import collections +import enum import errno import hashlib import logging @@ -30,6 +31,7 @@ import struct import threading import time +import typing import urllib.parse import urllib.request import uuid @@ -42,13 +44,13 @@ try: import ssl except ImportError: - ssl = None + ssl = None # type: ignore[assignment] try: - import socks + import socks # type: ignore[import-untyped] except ImportError: - socks = None + socks = None # type: ignore[assignment] try: @@ -65,11 +67,39 @@ HAVE_DNS = False -if platform.system() == 'Windows': - EAGAIN = errno.WSAEWOULDBLOCK +if platform.system() == "Windows": + EAGAIN = errno.WSAEWOULDBLOCK # type: ignore[attr-defined] else: EAGAIN = errno.EAGAIN +if typing.TYPE_CHECKING: + try: + from typing import TypedDict + except ImportError: + from typing_extensions import TypedDict + + from typing_extensions import Literal + + class _InPacket(TypedDict): + command: int + have_remaining: int + remaining_count: typing.List[int] + remaining_mult: int + remaining_length: int + packet: bytearray + to_process: int + pos: int + + class _OutPacket(TypedDict): + command: int + mid: int + qos: int + pos: int + to_process: int + packet: bytes + info: typing.Optional["MQTTMessageInfo"] + + MQTTv31 = 3 MQTTv311 = 4 MQTTv5 = 5 @@ -131,25 +161,49 @@ mqtt_ms_send_pubrec = 8 mqtt_ms_queued = 9 + # Error values -MQTT_ERR_AGAIN = -1 -MQTT_ERR_SUCCESS = 0 -MQTT_ERR_NOMEM = 1 -MQTT_ERR_PROTOCOL = 2 -MQTT_ERR_INVAL = 3 -MQTT_ERR_NO_CONN = 4 -MQTT_ERR_CONN_REFUSED = 5 -MQTT_ERR_NOT_FOUND = 6 -MQTT_ERR_CONN_LOST = 7 -MQTT_ERR_TLS = 8 -MQTT_ERR_PAYLOAD_SIZE = 9 -MQTT_ERR_NOT_SUPPORTED = 10 -MQTT_ERR_AUTH = 11 -MQTT_ERR_ACL_DENIED = 12 -MQTT_ERR_UNKNOWN = 13 -MQTT_ERR_ERRNO = 14 -MQTT_ERR_QUEUE_SIZE = 15 -MQTT_ERR_KEEPALIVE = 16 +class MQTTErrorCode(enum.IntEnum): + MQTT_ERR_AGAIN = -1 + MQTT_ERR_SUCCESS = 0 + MQTT_ERR_NOMEM = 1 + MQTT_ERR_PROTOCOL = 2 + MQTT_ERR_INVAL = 3 + MQTT_ERR_NO_CONN = 4 + MQTT_ERR_CONN_REFUSED = 5 + MQTT_ERR_NOT_FOUND = 6 + MQTT_ERR_CONN_LOST = 7 + MQTT_ERR_TLS = 8 + MQTT_ERR_PAYLOAD_SIZE = 9 + MQTT_ERR_NOT_SUPPORTED = 10 + MQTT_ERR_AUTH = 11 + MQTT_ERR_ACL_DENIED = 12 + MQTT_ERR_UNKNOWN = 13 + MQTT_ERR_ERRNO = 14 + MQTT_ERR_QUEUE_SIZE = 15 + MQTT_ERR_KEEPALIVE = 16 + + +# This probably do the same as @global_enum, but this decorator require Python 3.11 +MQTT_ERR_AGAIN = MQTTErrorCode.MQTT_ERR_AGAIN +MQTT_ERR_SUCCESS = MQTTErrorCode.MQTT_ERR_SUCCESS +MQTT_ERR_NOMEM = MQTTErrorCode.MQTT_ERR_NOMEM +MQTT_ERR_PROTOCOL = MQTTErrorCode.MQTT_ERR_PROTOCOL +MQTT_ERR_INVAL = MQTTErrorCode.MQTT_ERR_INVAL +MQTT_ERR_NO_CONN = MQTTErrorCode.MQTT_ERR_NO_CONN +MQTT_ERR_CONN_REFUSED = MQTTErrorCode.MQTT_ERR_CONN_REFUSED +MQTT_ERR_NOT_FOUND = MQTTErrorCode.MQTT_ERR_NOT_FOUND +MQTT_ERR_CONN_LOST = MQTTErrorCode.MQTT_ERR_CONN_LOST +MQTT_ERR_TLS = MQTTErrorCode.MQTT_ERR_TLS +MQTT_ERR_PAYLOAD_SIZE = MQTTErrorCode.MQTT_ERR_PAYLOAD_SIZE +MQTT_ERR_NOT_SUPPORTED = MQTTErrorCode.MQTT_ERR_NOT_SUPPORTED +MQTT_ERR_AUTH = MQTTErrorCode.MQTT_ERR_AUTH +MQTT_ERR_ACL_DENIED = MQTTErrorCode.MQTT_ERR_ACL_DENIED +MQTT_ERR_UNKNOWN = MQTTErrorCode.MQTT_ERR_UNKNOWN +MQTT_ERR_ERRNO = MQTTErrorCode.MQTT_ERR_ERRNO +MQTT_ERR_QUEUE_SIZE = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE +MQTT_ERR_KEEPALIVE = MQTTErrorCode.MQTT_ERR_KEEPALIVE + MQTT_CLIENT = 0 MQTT_BRIDGE = 1 @@ -159,12 +213,57 @@ sockpair_data = b"0" +# Payload support all those type and will be converted to bytes: +# * str are utf8 encoded +# * int/float are converted to string and utf8 encoded (e.g. 1 is converted to b"1") +# * None is converted to a zero-length payload (i.e. b"") +PayloadType = typing.Union[str, bytes, bytearray, int, float, None] + +HTTPHeader = typing.Dict[str, str] +WebSocketHeaders = typing.Union[typing.Callable[[HTTPHeader], HTTPHeader], HTTPHeader] + +SocketLike = typing.Union[socket.socket, "ssl.SSLSocket", "WebsocketWrapper"] + + +CallbackOnConnect = typing.Union[ + typing.Callable[["Client", typing.Any, ReasonCodes, Properties], None], + typing.Callable[["Client", typing.Any, MQTTErrorCode], None], +] +CallbackOnConnectFail = typing.Callable[["Client", typing.Any], None] +CallbackOnDisconnect = typing.Union[ + typing.Callable[ + ["Client", typing.Any, typing.Dict[str, typing.Any], ReasonCodes, Properties], + None, + ], + typing.Callable[ + ["Client", typing.Any, typing.Dict[str, typing.Any], MQTTErrorCode], None + ], +] +CallbackOnLog = typing.Callable[["Client", typing.Any, int, str], None] +CallbackOnMessage = typing.Callable[["Client", typing.Any, "MQTTMessage"], None] +CallbackOnPreConnect = typing.Callable[["Client", typing.Any], None] +CallbackOnPublish = typing.Callable[["Client", typing.Any, int], None] +CallbackOnSocket = typing.Callable[["Client", typing.Any, SocketLike], None] +CallbackOnSubscribe = typing.Union[ + typing.Callable[ + ["Client", typing.Any, Properties, typing.List[ReasonCodes], Properties], None + ], + typing.Callable[["Client", typing.Any, int, typing.Tuple[int, ...]], None], +] +CallbackOnUnsubscribe = typing.Union[ + typing.Callable[["Client", typing.Any, Properties, ReasonCodes], None], + typing.Callable[["Client", typing.Any, int], None], +] + +# This is needed for typing because class Client redefined the name "socket" +_socket = socket + class WebsocketConnectionError(ValueError): pass -def error_string(mqtt_errno): +def error_string(mqtt_errno: MQTTErrorCode) -> str: """Return the error string associated with an mqtt error number.""" if mqtt_errno == MQTT_ERR_SUCCESS: return "No error." @@ -204,7 +303,7 @@ def error_string(mqtt_errno): return "Unknown error." -def connack_string(connack_code): +def connack_string(connack_code: int) -> str: """Return the string associated with a CONNACK result.""" if connack_code == CONNACK_ACCEPTED: return "Connection Accepted." @@ -222,7 +321,9 @@ def connack_string(connack_code): return "Connection Refused: unknown reason." -def base62(num, base=string.digits + string.ascii_letters, padding=1): +def base62( + num: int, base: str = string.digits + string.ascii_letters, padding: int = 1 +) -> str: """Convert a number to base-62 representation.""" if num < 0: raise ValueError("Number must be positive or zero") @@ -231,10 +332,10 @@ def base62(num, base=string.digits + string.ascii_letters, padding=1): num, rest = divmod(num, 62) digits.append(base[rest]) digits.extend(base[0] for _ in range(len(digits), padding)) - return ''.join(reversed(digits)) + return "".join(reversed(digits)) -def topic_matches_sub(sub, topic): +def topic_matches_sub(sub: str, topic: str) -> bool: """Check whether a topic matches a subscription. For example: @@ -251,24 +352,22 @@ def topic_matches_sub(sub, topic): return False -def _socketpair_compat(): +def _socketpair_compat() -> typing.Tuple[socket.socket, socket.socket]: """TCP/IP socketpair including Windows support""" - listensock = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + listensock = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) listensock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) listensock.bind(("127.0.0.1", 0)) listensock.listen(1) iface, port = listensock.getsockname() - sock1 = socket.socket( - socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) - sock1.setblocking(0) + sock1 = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_IP) + sock1.setblocking(False) try: sock1.connect(("127.0.0.1", port)) except BlockingIOError: pass sock2, address = listensock.accept() - sock2.setblocking(0) + sock2.setblocking(False) listensock.close() return (sock1, sock2) @@ -279,26 +378,26 @@ class MQTTMessageInfo: message has been published, and/or wait until it is published. """ - __slots__ = 'mid', '_published', '_condition', 'rc', '_iterpos' + __slots__ = "mid", "_published", "_condition", "rc", "_iterpos" - def __init__(self, mid): + def __init__(self, mid: int): self.mid = mid self._published = False self._condition = threading.Condition() - self.rc = 0 + self.rc: MQTTErrorCode = MQTTErrorCode.MQTT_ERR_SUCCESS self._iterpos = 0 - def __str__(self): + def __str__(self) -> str: return str((self.rc, self.mid)) - def __iter__(self): + def __iter__(self) -> typing.Iterator[typing.Union[MQTTErrorCode, int]]: self._iterpos = 0 return self - def __next__(self): + def __next__(self) -> typing.Union[MQTTErrorCode, int]: return self.next() - def next(self): + def next(self) -> typing.Union[MQTTErrorCode, int]: if self._iterpos == 0: self._iterpos = 1 return self.rc @@ -308,7 +407,7 @@ def next(self): else: raise StopIteration - def __getitem__(self, index): + def __getitem__(self, index: int) -> typing.Union[MQTTErrorCode, int]: if index == 0: return self.rc elif index == 1: @@ -316,12 +415,12 @@ def __getitem__(self, index): else: raise IndexError("index out of range") - def _set_as_published(self): + def _set_as_published(self) -> None: with self._condition: self._published = True self._condition.notify() - def wait_for_publish(self, timeout=None): + def wait_for_publish(self, timeout: typing.Optional[float] = None) -> None: """Block until the message associated with this object is published, or until the timeout occurs. If timeout is None, this will never time out. Set timeout to a positive number of seconds, e.g. 1.2, to enable the @@ -334,37 +433,38 @@ def wait_for_publish(self, timeout=None): reason. """ if self.rc == MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") elif self.rc == MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") timeout_time = None if timeout is None else time_func() + timeout - timeout_tenth = None if timeout is None else timeout / 10. - def timed_out(): - return False if timeout is None else time_func() > timeout_time + timeout_tenth = None if timeout is None else timeout / 10.0 + + def timed_out() -> bool: + return False if timeout_time is None else time_func() > timeout_time with self._condition: while not self._published and not timed_out(): self._condition.wait(timeout_tenth) - def is_published(self): + def is_published(self) -> bool: """Returns True if the message associated with this object has been published, else returns False.""" - if self.rc == MQTT_ERR_QUEUE_SIZE: - raise ValueError('Message is not queued due to ERR_QUEUE_SIZE') - elif self.rc == MQTT_ERR_AGAIN: + if self.rc == MQTTErrorCode.MQTT_ERR_QUEUE_SIZE: + raise ValueError("Message is not queued due to ERR_QUEUE_SIZE") + elif self.rc == MQTTErrorCode.MQTT_ERR_AGAIN: pass elif self.rc > 0: - raise RuntimeError(f'Message publish failed: {error_string(self.rc)}') + raise RuntimeError(f"Message publish failed: {error_string(self.rc)}") with self._condition: return self._published class MQTTMessage: - """ This is a class that describes an incoming or outgoing message. It is + """This is a class that describes an incoming or outgoing message. It is passed to the on_message callback as the message parameter. Members: @@ -377,10 +477,21 @@ class MQTTMessage: properties: Properties class. In MQTT v5.0, the properties associated with the message. """ - __slots__ = 'timestamp', 'state', 'dup', 'mid', '_topic', 'payload', 'qos', 'retain', 'info', 'properties' - - def __init__(self, mid=0, topic=b""): - self.timestamp = 0 + __slots__ = ( + "timestamp", + "state", + "dup", + "mid", + "_topic", + "payload", + "qos", + "retain", + "info", + "properties", + ) + + def __init__(self, mid: int = 0, topic: bytes = b""): + self.timestamp = 0.0 self.state = mqtt_ms_invalid self.dup = False self.mid = mid @@ -389,23 +500,24 @@ def __init__(self, mid=0, topic=b""): self.qos = 0 self.retain = False self.info = MQTTMessageInfo(mid) + self.properties: typing.Optional[Properties] = None - def __eq__(self, other): + def __eq__(self, other: object) -> bool: """Override the default Equals behavior""" if isinstance(other, self.__class__): return self.mid == other.mid return False - def __ne__(self, other): + def __ne__(self, other: object) -> bool: """Define a non-equality test""" return not self.__eq__(other) @property - def topic(self): - return self._topic.decode('utf-8') + def topic(self) -> str: + return self._topic.decode("utf-8") @topic.setter - def topic(self, value): + def topic(self, value: bytes) -> None: self._topic = value @@ -469,9 +581,16 @@ def on_connect(client, userdata, flags, rc): on_socket_register_write, on_socket_unregister_write """ - def __init__(self, client_id="", clean_session=None, userdata=None, - protocol=MQTTv311, transport="tcp", reconnect_on_failure=True, - manual_ack=False ): + def __init__( + self, + client_id: str = "", + clean_session: typing.Optional[bool] = None, + userdata: typing.Any = None, + protocol: int = MQTTv311, + transport: str = "tcp", + reconnect_on_failure: bool = True, + manual_ack: bool = False, + ) -> None: """client_id is the unique client id string used when connecting to the broker. If client_id is zero length or None, then the behaviour is defined by which protocol version is in use. If using MQTT v3.1.1, then @@ -507,50 +626,56 @@ def __init__(self, client_id="", clean_session=None, userdata=None, Normally, when a message is received, the library automatically acknowledges after on_message callback returns. manual_ack=True allows the application to acknowledge receipt after it has completed processing of a message - using a the ack() method. This addresses vulnerabilty to message loss + using a the ack() method. This addresses vulnerability to message loss if applications fails while processing a message, or while it pending locally. """ - if transport.lower() not in ('websockets', 'tcp'): + if transport.lower() not in ("websockets", "tcp"): raise ValueError( - f'transport must be "websockets" or "tcp", not {transport}') + f'transport must be "websockets" or "tcp", not {transport}' + ) self._manual_ack = manual_ack self._transport = transport.lower() self._protocol = protocol self._userdata = userdata - self._sock = None - self._sockpairR, self._sockpairW = (None, None,) + self._sock: typing.Union[ + socket.socket, WebsocketWrapper, "ssl.SSLSocket", None + ] = None + self._sockpairR: typing.Optional[socket.socket] = None + self._sockpairW: typing.Optional[socket.socket] = None self._keepalive = 60 self._connect_timeout = 5.0 self._client_mode = MQTT_CLIENT if protocol == MQTTv5: if clean_session is not None: - raise ValueError('Clean session is not used for MQTT 5.0') + raise ValueError("Clean session is not used for MQTT 5.0") else: if clean_session is None: clean_session = True if not clean_session and (client_id == "" or client_id is None): raise ValueError( - 'A client id must be provided if clean session is False.') + "A client id must be provided if clean session is False." + ) self._clean_session = clean_session # [MQTT-3.1.3-4] Client Id must be UTF-8 encoded string. if client_id == "" or client_id is None: if protocol == MQTTv31: - self._client_id = base62(uuid.uuid4().int, padding=22) + self._client_id = base62(uuid.uuid4().int, padding=22).encode("utf8") else: self._client_id = b"" else: - self._client_id = client_id - if isinstance(self._client_id, str): - self._client_id = self._client_id.encode('utf-8') + if isinstance(client_id, str): + self._client_id = client_id.encode("utf-8") + else: + self._client_id = client_id - self._username = None - self._password = None - self._in_packet = { + self._username: typing.Optional[bytes] = None + self._password: typing.Optional[bytes] = None + self._in_packet: "_InPacket" = { "command": 0, "have_remaining": 0, "remaining_count": [], @@ -558,24 +683,30 @@ def __init__(self, client_id="", clean_session=None, userdata=None, "remaining_length": 0, "packet": bytearray(b""), "to_process": 0, - "pos": 0} - self._out_packet = collections.deque() + "pos": 0, + } + + self._out_packet: typing.Deque["_OutPacket"] = collections.deque() self._last_msg_in = time_func() self._last_msg_out = time_func() self._reconnect_min_delay = 1 self._reconnect_max_delay = 120 - self._reconnect_delay = None + self._reconnect_delay: typing.Optional[int] = None self._reconnect_on_failure = reconnect_on_failure - self._ping_t = 0 + self._ping_t = 0.0 self._last_mid = 0 self._state = mqtt_cs_new - self._out_messages = collections.OrderedDict() - self._in_messages = collections.OrderedDict() + self._out_messages: collections.OrderedDict[ + int, MQTTMessage + ] = collections.OrderedDict() + self._in_messages: collections.OrderedDict[ + int, MQTTMessage + ] = collections.OrderedDict() self._max_inflight_messages = 20 self._inflight_messages = 0 self._max_queued_messages = 0 - self._connect_properties = None - self._will_properties = None + self._connect_properties: typing.Optional[Properties] = None + self._will_properties: typing.Optional[Properties] = None self._will = False self._will_topic = b"" self._will_payload = b"" @@ -586,7 +717,7 @@ def __init__(self, client_id="", clean_session=None, userdata=None, self._port = 1883 self._bind_address = "" self._bind_port = 0 - self._proxy = {} + self._proxy: typing.Any = {} self._in_callback_mutex = threading.Lock() self._callback_mutex = threading.RLock() self._msgtime_mutex = threading.Lock() @@ -594,38 +725,41 @@ def __init__(self, client_id="", clean_session=None, userdata=None, self._in_message_mutex = threading.Lock() self._reconnect_delay_mutex = threading.Lock() self._mid_generate_mutex = threading.Lock() - self._thread = None + self._thread: typing.Optional[threading.Thread] = None self._thread_terminate = False self._ssl = False - self._ssl_context = None + self._ssl_context: typing.Optional["ssl.SSLContext"] = None # Only used when SSL context does not have check_hostname attribute self._tls_insecure = False - self._logger = None + self._logger: typing.Optional[logging.Logger] = None self._registered_write = False # No default callbacks - self._on_log = None - self._on_pre_connect = None - self._on_connect = None - self._on_connect_fail = None - self._on_subscribe = None - self._on_message = None - self._on_publish = None - self._on_unsubscribe = None - self._on_disconnect = None - self._on_socket_open = None - self._on_socket_close = None - self._on_socket_register_write = None - self._on_socket_unregister_write = None + self._on_log: typing.Optional[CallbackOnLog] = None + self._on_pre_connect: typing.Optional[CallbackOnPreConnect] = None + self._on_connect: typing.Optional[CallbackOnConnect] = None + self._on_connect_fail: typing.Optional[CallbackOnConnectFail] = None + self._on_subscribe: typing.Optional[CallbackOnSubscribe] = None + self._on_message: typing.Optional[CallbackOnMessage] = None + self._on_publish: typing.Optional[CallbackOnPublish] = None + self._on_unsubscribe: typing.Optional[CallbackOnUnsubscribe] = None + self._on_disconnect: typing.Optional[CallbackOnDisconnect] = None + self._on_socket_open: typing.Optional[CallbackOnSocket] = None + self._on_socket_close: typing.Optional[CallbackOnSocket] = None + self._on_socket_register_write: typing.Optional[CallbackOnSocket] = None + self._on_socket_unregister_write: typing.Optional[CallbackOnSocket] = None self._websocket_path = "/mqtt" - self._websocket_extra_headers = None + self._websocket_extra_headers: typing.Optional[WebSocketHeaders] = None # for clean_start == MQTT_CLEAN_START_FIRST_ONLY self._mqttv5_first_connect = True - self.suppress_exceptions = False # For callbacks + self.suppress_exceptions = False # For callbacks - def __del__(self): + def __del__(self) -> None: self._reset_sockets() - def _sock_recv(self, bufsize): + def _sock_recv(self, bufsize: int) -> bytes: + if self._sock is None: + raise ConnectionError("self._sock is None") + try: return self._sock.recv(bufsize) except ssl.SSLWantReadError as err: @@ -634,11 +768,13 @@ def _sock_recv(self, bufsize): self._call_socket_register_write() raise BlockingIOError() from err except AttributeError as err: - self._easy_log( - MQTT_LOG_DEBUG, "socket was None: %s", err) + self._easy_log(MQTT_LOG_DEBUG, "socket was None: %s", err) raise ConnectionError() from err - def _sock_send(self, buf): + def _sock_send(self, buf: bytes) -> int: + if self._sock is None: + raise ConnectionError("self._sock is None") + try: return self._sock.send(buf) except ssl.SSLWantReadError as err: @@ -650,7 +786,7 @@ def _sock_send(self, buf): self._call_socket_register_write() raise BlockingIOError() from err - def _sock_close(self): + def _sock_close(self) -> None: """Close the connection to the server.""" if not self._sock: return @@ -664,7 +800,7 @@ def _sock_close(self): # In case a callback fails, still close the socket to avoid leaking the file descriptor. sock.close() - def _reset_sockets(self, sockpair_only=False): + def _reset_sockets(self, sockpair_only: bool = False) -> None: if not sockpair_only: self._sock_close() @@ -675,13 +811,20 @@ def _reset_sockets(self, sockpair_only=False): self._sockpairW.close() self._sockpairW = None - def reinitialise(self, client_id="", clean_session=True, userdata=None): + def reinitialise( + self, + client_id: str = "", + clean_session: bool = True, + userdata: typing.Any = None, + ) -> None: self._reset_sockets() - self.__init__(client_id, clean_session, userdata) + self.__init__(client_id, clean_session, userdata) # type: ignore[misc] - def ws_set_options(self, path="/mqtt", headers=None): - """ Set the path and headers for a websocket connection + def ws_set_options( + self, path: str = "/mqtt", headers: typing.Optional[WebSocketHeaders] = None + ) -> None: + """Set the path and headers for a websocket connection path is a string starting with / which should be the endpoint of the mqtt connection on the remote server @@ -698,9 +841,12 @@ def ws_set_options(self, path="/mqtt", headers=None): self._websocket_extra_headers = headers else: raise ValueError( - "'headers' option to ws_set_options has to be either a dictionary or callable") + "'headers' option to ws_set_options has to be either a dictionary or callable" + ) - def tls_set_context(self, context=None): + def tls_set_context( + self, context: typing.Optional["ssl.SSLContext"] = None + ) -> None: """Configure network encryption and authentication context. Enables SSL/TLS support. context : an ssl.SSLContext object. By default this is given by @@ -708,7 +854,7 @@ def tls_set_context(self, context=None): Must be called before connect() or connect_async().""" if self._ssl_context is not None: - raise ValueError('SSL/TLS has already been configured.') + raise ValueError("SSL/TLS has already been configured.") if context is None: context = ssl.create_default_context() @@ -717,10 +863,19 @@ def tls_set_context(self, context=None): self._ssl_context = context # Ensure _tls_insecure is consistent with check_hostname attribute - if hasattr(context, 'check_hostname'): + if hasattr(context, "check_hostname"): self._tls_insecure = not context.check_hostname - def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tls_version=None, ciphers=None, keyfile_password=None): + def tls_set( + self, + ca_certs: typing.Optional[str] = None, + certfile: typing.Optional[str] = None, + keyfile: typing.Optional[str] = None, + cert_reqs: typing.Optional["ssl.VerifyMode"] = None, + tls_version: typing.Optional[int] = None, + ciphers: typing.Optional[str] = None, + keyfile_password: typing.Optional[str] = None, + ) -> None: """Configure network encryption and authentication options. Enables SSL/TLS support. ca_certs : a string path to the Certificate Authority certificate files @@ -761,15 +916,16 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl Must be called before connect() or connect_async().""" if ssl is None: - raise ValueError('This platform has no SSL/TLS.') + raise ValueError("This platform has no SSL/TLS.") - if not hasattr(ssl, 'SSLContext'): + if not hasattr(ssl, "SSLContext"): # Require Python version that has SSL context support in standard library raise ValueError( - 'Python 2.7.9 and 3.2 are the minimum supported versions for TLS.') + "Python 2.7.9 and 3.2 are the minimum supported versions for TLS." + ) - if ca_certs is None and not hasattr(ssl.SSLContext, 'load_default_certs'): - raise ValueError('ca_certs must not be None.') + if ca_certs is None and not hasattr(ssl.SSLContext, "load_default_certs"): + raise ValueError("ca_certs must not be None.") # Create SSLContext object if tls_version is None: @@ -789,7 +945,7 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl if certfile is not None: context.load_cert_chain(certfile, keyfile, keyfile_password) - if cert_reqs == ssl.CERT_NONE and hasattr(context, 'check_hostname'): + if cert_reqs == ssl.CERT_NONE and hasattr(context, "check_hostname"): context.check_hostname = False context.verify_mode = ssl.CERT_REQUIRED if cert_reqs is None else cert_reqs @@ -809,7 +965,7 @@ def tls_set(self, ca_certs=None, certfile=None, keyfile=None, cert_reqs=None, tl # But with ssl.CERT_NONE, we can not check_hostname self.tls_insecure_set(True) - def tls_insecure_set(self, value): + def tls_insecure_set(self, value: bool) -> None: """Configure verification of the server hostname in the server certificate. If value is set to true, it is impossible to guarantee that the host @@ -826,17 +982,18 @@ def tls_insecure_set(self, value): if self._ssl_context is None: raise ValueError( - 'Must configure SSL context before using tls_insecure_set.') + "Must configure SSL context before using tls_insecure_set." + ) self._tls_insecure = value # Ensure check_hostname is consistent with _tls_insecure attribute - if hasattr(self._ssl_context, 'check_hostname'): + if hasattr(self._ssl_context, "check_hostname"): # Rely on SSLContext to check host name # If verify_mode is CERT_NONE then the host name will never be checked self._ssl_context.check_hostname = not value - def proxy_set(self, **proxy_args): + def proxy_set(self, **proxy_args: typing.Any) -> None: """Configure proxying of MQTT connection. Enables support for SOCKS or HTTP proxies. @@ -861,8 +1018,8 @@ def proxy_set(self, **proxy_args): else: self._proxy = proxy_args - def enable_logger(self, logger=None): - """ Enables a logger to send log messages to """ + def enable_logger(self, logger: typing.Optional[logging.Logger] = None) -> None: + """Enables a logger to send log messages to""" if logger is None: if self._logger is not None: # Do not replace existing logger @@ -870,11 +1027,21 @@ def enable_logger(self, logger=None): logger = logging.getLogger(__name__) self._logger = logger - def disable_logger(self): + def disable_logger(self) -> None: self._logger = None - def connect(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect( + self, + host: str, + port: int = 1883, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, "Literal[3]" + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Connect to a remote broker. This is a blocking call that establishes the underlying connection and transmits a CONNECT packet. @@ -901,12 +1068,22 @@ def connect(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, if properties: raise ValueError("Properties only apply to MQTT V5") - self.connect_async(host, port, keepalive, - bind_address, bind_port, clean_start, properties) + self.connect_async( + host, port, keepalive, bind_address, bind_port, clean_start, properties + ) return self.reconnect() - def connect_srv(self, domain=None, keepalive=60, bind_address="", - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect_srv( + self, + domain: typing.Optional[str] = None, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, "Literal[3]" + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Connect to a remote broker. domain is the DNS domain to search for SRV records; if None, @@ -916,23 +1093,27 @@ def connect_srv(self, domain=None, keepalive=60, bind_address="", if HAVE_DNS is False: raise ValueError( - 'No DNS resolver library found, try "pip install dnspython" or "pip3 install dnspython3".') + 'No DNS resolver library found, try "pip install dnspython" or "pip3 install dnspython3".' + ) if domain is None: domain = socket.getfqdn() - domain = domain[domain.find('.') + 1:] + domain = domain[domain.find(".") + 1 :] try: - rr = f'_mqtt._tcp.{domain}' + rr = f"_mqtt._tcp.{domain}" if self._ssl: # IANA specifies secure-mqtt (not mqtts) for port 8883 - rr = f'_secure-mqtt._tcp.{domain}' + rr = f"_secure-mqtt._tcp.{domain}" answers = [] for answer in dns.resolver.query(rr, dns.rdatatype.SRV): addr = answer.target.to_text()[:-1] - answers.append( - (addr, answer.port, answer.priority, answer.weight)) - except (dns.resolver.NXDOMAIN, dns.resolver.NoAnswer, dns.resolver.NoNameservers) as err: + answers.append((addr, answer.port, answer.priority, answer.weight)) + except ( + dns.resolver.NXDOMAIN, + dns.resolver.NoAnswer, + dns.resolver.NoNameservers, + ) as err: raise ValueError(f"No answer/NXDOMAIN for SRV in {domain}") from err # FIXME: doesn't account for weight @@ -940,14 +1121,32 @@ def connect_srv(self, domain=None, keepalive=60, bind_address="", host, port, prio, weight = answer try: - return self.connect(host, port, keepalive, bind_address, clean_start, properties) + return self.connect( + host, + port, + keepalive, + bind_address, + bind_port, + clean_start, + properties, + ) except Exception: # noqa: S110 pass raise ValueError("No SRV hosts responded") - def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_port=0, - clean_start=MQTT_CLEAN_START_FIRST_ONLY, properties=None): + def connect_async( + self, + host: str, + port: int = 1883, + keepalive: int = 60, + bind_address: str = "", + bind_port: int = 0, + clean_start: typing.Union[ + bool, "Literal[3]" + ] = MQTT_CLEAN_START_FIRST_ONLY, # type: ignore + properties: typing.Optional[Properties] = None, + ) -> None: """Connect to a remote broker asynchronously. This is a non-blocking connect call that can be used with loop_start() to provide very quick start. @@ -967,13 +1166,13 @@ def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_por MQTT connect packet. Use the Properties class. """ if host is None or len(host) == 0: - raise ValueError('Invalid host.') + raise ValueError("Invalid host.") if port <= 0: - raise ValueError('Invalid port number.') + raise ValueError("Invalid port number.") if keepalive < 0: - raise ValueError('Keepalive must be >=0.') + raise ValueError("Keepalive must be >=0.") if bind_port < 0: - raise ValueError('Invalid bind port number.') + raise ValueError("Invalid bind port number.") self._host = host self._port = port @@ -984,27 +1183,26 @@ def connect_async(self, host, port=1883, keepalive=60, bind_address="", bind_por self._connect_properties = properties self._state = mqtt_cs_connect_async + def reconnect_delay_set(self, min_delay: int = 1, max_delay: int = 120) -> None: + """Configure the exponential reconnect delay - def reconnect_delay_set(self, min_delay=1, max_delay=120): - """ Configure the exponential reconnect delay - - When connection is lost, wait initially min_delay seconds and - double this time every attempt. The wait is capped at max_delay. - Once the client is fully connected (e.g. not only TCP socket, but - received a success CONNACK), the wait timer is reset to min_delay. + When connection is lost, wait initially min_delay seconds and + double this time every attempt. The wait is capped at max_delay. + Once the client is fully connected (e.g. not only TCP socket, but + received a success CONNACK), the wait timer is reset to min_delay. """ with self._reconnect_delay_mutex: self._reconnect_min_delay = min_delay self._reconnect_max_delay = max_delay self._reconnect_delay = None - def reconnect(self): + def reconnect(self) -> MQTTErrorCode: """Reconnect the client after a disconnect. Can only be called after connect()/connect_async().""" if len(self._host) == 0: - raise ValueError('Invalid host.') + raise ValueError("Invalid host.") if self._port <= 0: - raise ValueError('Invalid port number.') + raise ValueError("Invalid port number.") self._in_packet = { "command": 0, @@ -1014,7 +1212,8 @@ def reconnect(self): "remaining_length": 0, "packet": bytearray(b""), "to_process": 0, - "pos": 0} + "pos": 0, + } self._out_packet = collections.deque() @@ -1022,7 +1221,7 @@ def reconnect(self): self._last_msg_in = time_func() self._last_msg_out = time_func() - self._ping_t = 0 + self._ping_t = 0.0 self._state = mqtt_cs_new self._sock_close() @@ -1038,20 +1237,24 @@ def reconnect(self): on_pre_connect(self, self._userdata) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_pre_connect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_pre_connect: %s", err + ) if not self.suppress_exceptions: raise - sock = self._create_socket_connection() + tcp_sock = self._create_socket_connection() if self._ssl: - # SSL is only supported when SSLContext is available (implies Python >= 2.7.9 or >= 3.2) + if self._ssl_context is None: + raise ValueError( + "Impossible condition. _ssl_context should never be None if _ssl is True" + ) verify_host = not self._tls_insecure try: # Try with server_hostname, even it's not supported in certain scenarios - sock = self._ssl_context.wrap_socket( - sock, + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, server_hostname=self._host, do_handshake_on_connect=False, ) @@ -1060,35 +1263,52 @@ def reconnect(self): raise except ValueError: # Python version requires SNI in order to handle server_hostname, but SNI is not available - sock = self._ssl_context.wrap_socket( - sock, + ssl_sock = self._ssl_context.wrap_socket( + tcp_sock, do_handshake_on_connect=False, ) else: # If SSL context has already checked hostname, then don't need to do it again - if (hasattr(self._ssl_context, 'check_hostname') and - self._ssl_context.check_hostname): + if ( + hasattr(self._ssl_context, "check_hostname") + and self._ssl_context.check_hostname # type: ignore + ): verify_host = False - sock.settimeout(self._keepalive) - sock.do_handshake() + ssl_sock.settimeout(self._keepalive) + ssl_sock.do_handshake() if verify_host: - ssl.match_hostname(sock.getpeercert(), self._host) + # TODO: this type error is a true error: + # error: Module has no attribute "match_hostname" [attr-defined] + # Python 3.12 no longer have this method. + ssl.match_hostname(ssl_sock.getpeercert(), self._host) # type: ignore + + sock_without_ws: typing.Union[socket.socket, "ssl.SSLSocket"] = ssl_sock + else: + sock_without_ws = tcp_sock if self._transport == "websockets": - sock.settimeout(self._keepalive) - sock = WebsocketWrapper(sock, self._host, self._port, self._ssl, - self._websocket_path, self._websocket_extra_headers) + sock_without_ws.settimeout(self._keepalive) + ws_sock = WebsocketWrapper( + sock_without_ws, + self._host, + self._port, + self._ssl, + self._websocket_path, + self._websocket_extra_headers, + ) + self._sock = ws_sock + else: + self._sock = sock_without_ws - self._sock = sock - self._sock.setblocking(0) + self._sock.setblocking(False) # type: ignore[attr-defined] self._registered_write = False self._call_socket_open() return self._send_connect(self._keepalive) - def loop(self, timeout=1.0, max_packets=1): + def loop(self, timeout: float = 1.0, max_packets: int = 1) -> MQTTErrorCode: """Process network events. It is strongly recommended that you use loop_start(), or @@ -1120,9 +1340,9 @@ def loop(self, timeout=1.0, max_packets=1): return self._loop(timeout) - def _loop(self, timeout=1.0): + def _loop(self, timeout: float = 1.0) -> MQTTErrorCode: if timeout < 0.0: - raise ValueError('Invalid timeout.') + raise ValueError("Invalid timeout.") try: packet = self._out_packet.popleft() @@ -1133,8 +1353,8 @@ def _loop(self, timeout=1.0): # used to check if there are any bytes left in the (SSL) socket pending_bytes = 0 - if hasattr(self._sock, 'pending'): - pending_bytes = self._sock.pending() + if hasattr(self._sock, "pending"): + pending_bytes = self._sock.pending() # type: ignore[union-attr] # if bytes are pending do not wait in select if pending_bytes > 0: @@ -1151,15 +1371,15 @@ def _loop(self, timeout=1.0): socklist = select.select(rlist, wlist, [], timeout) except TypeError: # Socket isn't correct type, in likelihood connection is lost - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST except ValueError: # Can occur if we just reconnected but rlist/wlist contain a -1 for # some reason. - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST except Exception: # Note that KeyboardInterrupt, etc. can still terminate since they # are not derived from Exception - return MQTT_ERR_UNKNOWN + return MQTTErrorCode.MQTT_ERR_UNKNOWN if self._sock in socklist[0] or pending_bytes > 0: rc = self.loop_read() @@ -1173,7 +1393,7 @@ def _loop(self, timeout=1.0): # Clear sockpairR - only ever a single byte written. try: # Read many bytes at once - this allows up to 10000 calls to - # publish() inbetween calls to loop(). + # publish() in between calls to loop(). self._sockpairR.recv(10000) except BlockingIOError: pass @@ -1185,7 +1405,14 @@ def _loop(self, timeout=1.0): return self.loop_misc() - def publish(self, topic, payload=None, qos=0, retain=False, properties=None): + def publish( + self, + topic: str, + payload: PayloadType = None, + qos: int = 0, + retain: bool = False, + properties: typing.Optional[Properties] = None, + ) -> MQTTMessageInfo: """Publish a message on a topic. This causes a message to be sent to the broker and subsequently from @@ -1226,41 +1453,51 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): the length of the payload is greater than 268435455 bytes.""" if self._protocol != MQTTv5: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") - topic = topic.encode('utf-8') + topic_bytes = topic.encode("utf-8") - if self._topic_wildcard_len_check(topic) != MQTT_ERR_SUCCESS: - raise ValueError('Publish topic cannot contain wildcards.') + if ( + self._topic_wildcard_len_check(topic_bytes) + != MQTTErrorCode.MQTT_ERR_SUCCESS + ): + raise ValueError("Publish topic cannot contain wildcards.") if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if isinstance(payload, str): - local_payload = payload.encode('utf-8') + local_payload = payload.encode("utf-8") elif isinstance(payload, (bytes, bytearray)): local_payload = payload elif isinstance(payload, (int, float)): - local_payload = str(payload).encode('ascii') + local_payload = str(payload).encode("ascii") elif payload is None: - local_payload = b'' + local_payload = b"" else: - raise TypeError( - 'payload must be a string, bytearray, int, float or None.') + raise TypeError("payload must be a string, bytearray, int, float or None.") if len(local_payload) > 268435455: - raise ValueError('Payload too large.') + raise ValueError("Payload too large.") local_mid = self._mid_generate() if qos == 0: info = MQTTMessageInfo(local_mid) rc = self._send_publish( - local_mid, topic, local_payload, qos, retain, False, info, properties) + local_mid, + topic_bytes, + local_payload, + qos, + retain, + False, + info, + properties, + ) info.rc = rc return info else: - message = MQTTMessage(local_mid, topic) + message = MQTTMessage(local_mid, topic_bytes) message.timestamp = time_func() message.payload = local_payload message.qos = qos @@ -1269,27 +1506,41 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): message.properties = properties with self._out_message_mutex: - if self._max_queued_messages > 0 and len(self._out_messages) >= self._max_queued_messages: - message.info.rc = MQTT_ERR_QUEUE_SIZE + if ( + self._max_queued_messages > 0 + and len(self._out_messages) >= self._max_queued_messages + ): + message.info.rc = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE return message.info if local_mid in self._out_messages: - message.info.rc = MQTT_ERR_QUEUE_SIZE + message.info.rc = MQTTErrorCode.MQTT_ERR_QUEUE_SIZE return message.info self._out_messages[message.mid] = message - if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages: + if ( + self._max_inflight_messages == 0 + or self._inflight_messages < self._max_inflight_messages + ): self._inflight_messages += 1 if qos == 1: message.state = mqtt_ms_wait_for_puback elif qos == 2: message.state = mqtt_ms_wait_for_pubrec - rc = self._send_publish(message.mid, topic, message.payload, message.qos, message.retain, - message.dup, message.info, message.properties) + rc = self._send_publish( + message.mid, + topic_bytes, + message.payload, + message.qos, + message.retain, + message.dup, + message.info, + message.properties, + ) # remove from inflight messages so it will be send after a connection is made - if rc is MQTT_ERR_NO_CONN: + if rc == MQTTErrorCode.MQTT_ERR_NO_CONN: self._inflight_messages -= 1 message.state = mqtt_ms_publish @@ -1297,10 +1548,12 @@ def publish(self, topic, payload=None, qos=0, retain=False, properties=None): return message.info else: message.state = mqtt_ms_queued - message.info.rc = MQTT_ERR_SUCCESS + message.info.rc = MQTTErrorCode.MQTT_ERR_SUCCESS return message.info - def username_pw_set(self, username, password=None): + def username_pw_set( + self, username: typing.Optional[str], password: typing.Optional[str] = None + ) -> None: """Set a username and optionally a password for broker authentication. Must be called before connect() to have any effect. @@ -1314,12 +1567,13 @@ def username_pw_set(self, username, password=None): """ # [MQTT-3.1.3-11] User name must be UTF-8 encoded string - self._username = None if username is None else username.encode('utf-8') - self._password = password - if isinstance(self._password, str): - self._password = self._password.encode('utf-8') + self._username = None if username is None else username.encode("utf-8") + if isinstance(password, str): + self._password = password.encode("utf-8") + else: + self._password = password - def enable_bridge_mode(self): + def enable_bridge_mode(self) -> None: """Sets the client in a bridge mode instead of client mode. Must be called before connect() to have any effect. @@ -1335,7 +1589,7 @@ def enable_bridge_mode(self): """ self._client_mode = MQTT_BRIDGE - def is_connected(self): + def is_connected(self) -> bool: """Returns the current status of the connection True if connection exists @@ -1343,7 +1597,11 @@ def is_connected(self): """ return self._state == mqtt_cs_connected - def disconnect(self, reasoncode=None, properties=None): + def disconnect( + self, + reasoncode: typing.Optional[ReasonCodes] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: """Disconnect a connected client from the broker. reasoncode: (MQTT v5.0 only) a ReasonCodes instance setting the MQTT v5.0 reasoncode to be sent with the disconnect. It is optional, the receiver @@ -1358,7 +1616,19 @@ def disconnect(self, reasoncode=None, properties=None): return self._send_disconnect(reasoncode, properties) - def subscribe(self, topic, qos=0, options=None, properties=None): + def subscribe( + self, + topic: typing.Union[ + str, + typing.Tuple[str, int], + typing.Tuple[str, SubscribeOptions], + typing.List[typing.Tuple[str, int]], + typing.List[typing.Tuple[str, SubscribeOptions]], + ], + qos: int = 0, + options: typing.Optional[SubscribeOptions] = None, + properties: typing.Optional[Properties] = None, + ) -> typing.Tuple[MQTTErrorCode, typing.Optional[int]]: """Subscribe the client to one or more topics. This function may be called in three different ways (and a further three for MQTT v5.0): @@ -1442,31 +1712,34 @@ def subscribe(self, topic, qos=0, options=None, properties=None): if isinstance(topic, tuple): if self._protocol == MQTTv5: - topic, options = topic + topic, options = topic # type: ignore if not isinstance(options, SubscribeOptions): raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') + "Subscribe options must be instance of SubscribeOptions class." + ) else: - topic, qos = topic + topic, qos = topic # type: ignore if isinstance(topic, (bytes, str)): if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if self._protocol == MQTTv5: if options is None: # if no options are provided, use the QoS passed instead options = SubscribeOptions(qos=qos) elif qos != 0: raise ValueError( - 'Subscribe options and qos parameters cannot be combined.') + "Subscribe options and qos parameters cannot be combined." + ) if not isinstance(options, SubscribeOptions): raise ValueError( - 'Subscribe options must be instance of SubscribeOptions class.') - topic_qos_list = [(topic.encode('utf-8'), options)] + "Subscribe options must be instance of SubscribeOptions class." + ) + topic_qos_list = [(topic.encode("utf-8"), options)] else: if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') - topic_qos_list = [(topic.encode('utf-8'), qos)] + raise ValueError("Invalid topic.") + topic_qos_list = [(topic.encode("utf-8"), qos)] # type: ignore elif isinstance(topic, list): topic_qos_list = [] if self._protocol == MQTTv5: @@ -1474,29 +1747,34 @@ def subscribe(self, topic, qos=0, options=None, properties=None): if not isinstance(o, SubscribeOptions): # then the second value should be QoS if o < 0 or o > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") o = SubscribeOptions(qos=o) - topic_qos_list.append((t.encode('utf-8'), o)) + topic_qos_list.append((t.encode("utf-8"), o)) else: for t, q in topic: - if q < 0 or q > 2: - raise ValueError('Invalid QoS level.') + if isinstance(q, SubscribeOptions) or q < 0 or q > 2: + raise ValueError("Invalid QoS level.") if t is None or len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_qos_list.append((t.encode('utf-8'), q)) + raise ValueError("Invalid topic.") + topic_qos_list.append((t.encode("utf-8"), q)) # type: ignore if topic_qos_list is None: raise ValueError("No topic specified, or incorrect topic type.") - if any(self._filter_wildcard_len_check(topic) != MQTT_ERR_SUCCESS for topic, _ in topic_qos_list): - raise ValueError('Invalid subscription filter.') + if any( + self._filter_wildcard_len_check(topic) != MQTT_ERR_SUCCESS + for topic, _ in topic_qos_list + ): + raise ValueError("Invalid subscription filter.") if self._sock is None: return (MQTT_ERR_NO_CONN, None) return self._send_subscribe(False, topic_qos_list, properties) - def unsubscribe(self, topic, properties=None): + def unsubscribe( + self, topic: str, properties: typing.Optional[Properties] = None + ) -> typing.Tuple[MQTTErrorCode, typing.Optional[int]]: """Unsubscribe the client from one or more topics. topic: A single string, or list of strings that are the subscription @@ -1516,27 +1794,27 @@ def unsubscribe(self, topic, properties=None): """ topic_list = None if topic is None: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if isinstance(topic, (bytes, str)): if len(topic) == 0: - raise ValueError('Invalid topic.') - topic_list = [topic.encode('utf-8')] + raise ValueError("Invalid topic.") + topic_list = [topic.encode("utf-8")] elif isinstance(topic, list): topic_list = [] for t in topic: if len(t) == 0 or not isinstance(t, (bytes, str)): - raise ValueError('Invalid topic.') - topic_list.append(t.encode('utf-8')) + raise ValueError("Invalid topic.") + topic_list.append(t.encode("utf-8")) if topic_list is None: raise ValueError("No topic specified, or incorrect topic type.") if self._sock is None: - return (MQTT_ERR_NO_CONN, None) + return (MQTTErrorCode.MQTT_ERR_NO_CONN, None) return self._send_unsubscribe(False, topic_list, properties) - def loop_read(self, max_packets=1): + def loop_read(self, max_packets: int = 1) -> MQTTErrorCode: """Process read network events. Use in place of calling loop() if you wish to handle your client reads as part of your own application. @@ -1545,7 +1823,7 @@ def loop_read(self, max_packets=1): Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN max_packets = len(self._out_messages) + len(self._in_messages) if max_packets < 1: @@ -1553,15 +1831,15 @@ def loop_read(self, max_packets=1): for _ in range(0, max_packets): if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN rc = self._packet_read() if rc > 0: - return self._loop_rc_handle(rc) - elif rc == MQTT_ERR_AGAIN: - return MQTT_ERR_SUCCESS - return MQTT_ERR_SUCCESS + return self._loop_rc_handle(rc) # type: ignore + elif rc == MQTTErrorCode.MQTT_ERR_AGAIN: + return MQTTErrorCode.MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def loop_write(self, max_packets=1): + def loop_write(self, max_packets: int = 1) -> MQTTErrorCode: """Process write network events. Use in place of calling loop() if you wish to handle your client writes as part of your own application. @@ -1572,23 +1850,23 @@ def loop_write(self, max_packets=1): Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN try: rc = self._packet_write() - if rc == MQTT_ERR_AGAIN: - return MQTT_ERR_SUCCESS + if rc == MQTTErrorCode.MQTT_ERR_AGAIN: + return MQTTErrorCode.MQTT_ERR_SUCCESS elif rc > 0: - return self._loop_rc_handle(rc) + return self._loop_rc_handle(rc) # type: ignore else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS finally: if self.want_write(): self._call_socket_register_write() else: self._call_socket_unregister_write() - def want_write(self): + def want_write(self) -> bool: """Call to determine if there is network data waiting to be written. Useful if you are calling select() yourself rather than using loop(). """ @@ -1599,13 +1877,13 @@ def want_write(self): except IndexError: return False - def loop_misc(self): + def loop_misc(self) -> MQTTErrorCode: """Process miscellaneous network events. Use in place of calling loop() if you wish to call select() or equivalent on. Do not use if you are using the threaded interface loop_start().""" if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN now = time_func() self._check_keepalive() @@ -1616,46 +1894,53 @@ def loop_misc(self): self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: - rc = MQTT_ERR_KEEPALIVE + rc = MQTTErrorCode.MQTT_ERR_KEEPALIVE self._do_on_disconnect(rc) - return MQTT_ERR_CONN_LOST + return MQTTErrorCode.MQTT_ERR_CONN_LOST - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def max_inflight_messages_set(self, inflight): + def max_inflight_messages_set(self, inflight: int) -> None: """Set the maximum number of messages with QoS>0 that can be part way through their network flow at once. Defaults to 20.""" if inflight < 0: - raise ValueError('Invalid inflight.') + raise ValueError("Invalid inflight.") self._max_inflight_messages = inflight - def max_queued_messages_set(self, queue_size): + def max_queued_messages_set(self, queue_size: int) -> "Client": """Set the maximum number of messages in the outgoing message queue. 0 means unlimited.""" if queue_size < 0: - raise ValueError('Invalid queue size.') + raise ValueError("Invalid queue size.") if not isinstance(queue_size, int): - raise ValueError('Invalid type of queue size.') + raise ValueError("Invalid type of queue size.") self._max_queued_messages = queue_size return self - def message_retry_set(self, retry): + def message_retry_set(self, retry): # type: ignore """No longer used, remove in version 2.0""" pass - def user_data_set(self, userdata): + def user_data_set(self, userdata: typing.Any) -> None: """Set the user data variable passed to callbacks. May be any data type.""" self._userdata = userdata - def user_data_get(self): + def user_data_get(self) -> typing.Any: """Get the user data variable passed to callbacks. May be any data type.""" return self._userdata - def will_set(self, topic, payload=None, qos=0, retain=False, properties=None): + def will_set( + self, + topic: str, + payload: PayloadType = None, + qos: int = 0, + retain: bool = False, + properties: typing.Optional[Properties] = None, + ) -> None: """Set a Will to be sent by the broker in case the client disconnects unexpectedly. This must be called before connect() to have any effect. @@ -1676,35 +1961,35 @@ def will_set(self, topic, payload=None, qos=0, retain=False, properties=None): zero string length. """ if topic is None or len(topic) == 0: - raise ValueError('Invalid topic.') + raise ValueError("Invalid topic.") if qos < 0 or qos > 2: - raise ValueError('Invalid QoS level.') + raise ValueError("Invalid QoS level.") if properties and not isinstance(properties, Properties): raise ValueError( - "The properties argument must be an instance of the Properties class.") + "The properties argument must be an instance of the Properties class." + ) if isinstance(payload, str): - self._will_payload = payload.encode('utf-8') + self._will_payload = payload.encode("utf-8") elif isinstance(payload, (bytes, bytearray)): self._will_payload = payload elif isinstance(payload, (int, float)): - self._will_payload = str(payload).encode('ascii') + self._will_payload = str(payload).encode("ascii") elif payload is None: self._will_payload = b"" else: - raise TypeError( - 'payload must be a string, bytearray, int, float or None.') + raise TypeError("payload must be a string, bytearray, int, float or None.") self._will = True - self._will_topic = topic.encode('utf-8') + self._will_topic = topic.encode("utf-8") self._will_qos = qos self._will_retain = retain self._will_properties = properties - def will_clear(self): - """ Removes a will that was previously configured with will_set(). + def will_clear(self) -> None: + """Removes a will that was previously configured with will_set(). Must be called before connect() to have any effect.""" self._will = False @@ -1713,11 +1998,16 @@ def will_clear(self): self._will_qos = 0 self._will_retain = False - def socket(self): + def socket(self) -> typing.Optional[SocketLike]: """Return the socket or ssl object for this client.""" return self._sock - def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False): + def loop_forever( + self, + timeout: float = 1.0, + max_packets: int = 1, + retry_first_connection: bool = False, + ) -> MQTTErrorCode: """This function calls the network loop functions for you in an infinite blocking loop. It is useful for the case where you only want to run the MQTT client loop in your program. @@ -1749,31 +2039,33 @@ def loop_forever(self, timeout=1.0, max_packets=1, retry_first_connection=False) self._handle_on_connect_fail() if not retry_first_connection: raise - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") self._reconnect_wait() else: break while run: - rc = MQTT_ERR_SUCCESS - while rc == MQTT_ERR_SUCCESS: + rc = MQTTErrorCode.MQTT_ERR_SUCCESS + while rc == MQTTErrorCode.MQTT_ERR_SUCCESS: rc = self._loop(timeout) # We don't need to worry about locking here, because we've # either called loop_forever() when in single threaded mode, or # in multi threaded mode when loop_stop() has been called and # so no other threads can access _out_packet or _messages. - if (self._thread_terminate is True + if ( + self._thread_terminate is True and len(self._out_packet) == 0 - and len(self._out_messages) == 0): - rc = 1 + and len(self._out_messages) == 0 + ): + rc = MQTTErrorCode.MQTT_ERR_NOMEM run = False - def should_exit(): + def should_exit() -> bool: + # B023: uses the run variable from the outer scope on purpose return ( - self._state == mqtt_cs_disconnecting or - run is False or # noqa: B023 (uses the run variable from the outer scope on purpose) - self._thread_terminate is True + self._state == mqtt_cs_disconnecting + or run is False # noqa: B023 + or self._thread_terminate is True ) if should_exit() or not self._reconnect_on_failure: @@ -1788,26 +2080,30 @@ def should_exit(): self.reconnect() except (OSError, WebsocketConnectionError): self._handle_on_connect_fail() - self._easy_log( - MQTT_LOG_DEBUG, "Connection failed, retrying") + self._easy_log(MQTT_LOG_DEBUG, "Connection failed, retrying") return rc - def loop_start(self): + def loop_start(self) -> MQTTErrorCode: """This is part of the threaded client interface. Call this once to start a new thread to process network traffic. This provides an alternative to repeatedly calling loop() yourself. """ if self._thread is not None: - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL self._sockpairR, self._sockpairW = _socketpair_compat() self._thread_terminate = False - self._thread = threading.Thread(target=self._thread_main, name=f"paho-mqtt-client-{self._client_id.decode()}") + self._thread = threading.Thread( + target=self._thread_main, + name=f"paho-mqtt-client-{self._client_id.decode()}", + ) self._thread.daemon = True self._thread.start() - def loop_stop(self, force=False): + return MQTTErrorCode.MQTT_ERR_SUCCESS + + def loop_stop(self, force: bool = False) -> MQTTErrorCode: """This is part of the threaded client interface. Call this once to stop the network thread previously created with loop_start(). This call will block until the network thread finishes. @@ -1815,22 +2111,24 @@ def loop_stop(self, force=False): The force parameter is currently ignored. """ if self._thread is None: - return MQTT_ERR_INVAL + return MQTTErrorCode.MQTT_ERR_INVAL self._thread_terminate = True if threading.current_thread() != self._thread: self._thread.join() self._thread = None + return MQTTErrorCode.MQTT_ERR_SUCCESS + @property - def on_log(self): + def on_log(self) -> typing.Optional[CallbackOnLog]: """If implemented, called when the client has log information. Defined to allow debugging.""" return self._on_log @on_log.setter - def on_log(self, func): - """ Define the logging callback implementation. + def on_log(self, func: typing.Optional[CallbackOnLog]) -> None: + """Define the logging callback implementation. Expected signature is: log_callback(client, userdata, level, buf) @@ -1847,21 +2145,22 @@ def on_log(self, func): """ self._on_log = func - def log_callback(self): - def decorator(func): + def log_callback(self) -> typing.Callable[[CallbackOnLog], CallbackOnLog]: + def decorator(func: CallbackOnLog) -> CallbackOnLog: self.on_log = func return func + return decorator @property - def on_pre_connect(self): + def on_pre_connect(self) -> typing.Optional[CallbackOnPreConnect]: """If implemented, called immediately prior to the connection is made request.""" return self._on_pre_connect @on_pre_connect.setter - def on_pre_connect(self, func): - """ Define the pre_connect callback implementation. + def on_pre_connect(self, func: typing.Optional[CallbackOnPreConnect]) -> None: + """Define the pre_connect callback implementation. Expected signature: connect_callback(client, userdata) @@ -1876,21 +2175,24 @@ def on_pre_connect(self, func): with self._callback_mutex: self._on_pre_connect = func - def pre_connect_callback(self): - def decorator(func): + def pre_connect_callback( + self, + ) -> typing.Callable[[CallbackOnPreConnect], CallbackOnPreConnect]: + def decorator(func: CallbackOnPreConnect) -> CallbackOnPreConnect: self.on_pre_connect = func return func + return decorator @property - def on_connect(self): + def on_connect(self) -> typing.Optional[CallbackOnConnect]: """If implemented, called when the broker responds to our connection request.""" return self._on_connect @on_connect.setter - def on_connect(self, func): - """ Define the connect callback implementation. + def on_connect(self, func: typing.Optional[CallbackOnConnect]) -> None: + """Define the connect callback implementation. Expected signature for MQTT v3.1 and v3.1.1 is: connect_callback(client, userdata, flags, rc) @@ -1932,21 +2234,24 @@ def on_connect(self, func): with self._callback_mutex: self._on_connect = func - def connect_callback(self): - def decorator(func): + def connect_callback( + self, + ) -> typing.Callable[[CallbackOnConnect], CallbackOnConnect]: + def decorator(func: CallbackOnConnect) -> CallbackOnConnect: self.on_connect = func return func + return decorator @property - def on_connect_fail(self): + def on_connect_fail(self) -> typing.Optional[CallbackOnConnectFail]: """If implemented, called when the client failed to connect to the broker.""" return self._on_connect_fail @on_connect_fail.setter - def on_connect_fail(self, func): - """ Define the connection failure callback implementation + def on_connect_fail(self, func: typing.Optional[CallbackOnConnectFail]) -> None: + """Define the connection failure callback implementation Expected signature is: on_connect_fail(client, userdata) @@ -1961,21 +2266,24 @@ def on_connect_fail(self, func): with self._callback_mutex: self._on_connect_fail = func - def connect_fail_callback(self): - def decorator(func): + def connect_fail_callback( + self, + ) -> typing.Callable[[CallbackOnConnectFail], CallbackOnConnectFail]: + def decorator(func: CallbackOnConnectFail) -> CallbackOnConnectFail: self.on_connect_fail = func return func + return decorator @property - def on_subscribe(self): + def on_subscribe(self) -> typing.Optional[CallbackOnSubscribe]: """If implemented, called when the broker responds to a subscribe request.""" return self._on_subscribe @on_subscribe.setter - def on_subscribe(self, func): - """ Define the subscribe callback implementation. + def on_subscribe(self, func: typing.Optional[CallbackOnSubscribe]) -> None: + """Define the subscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: subscribe_callback(client, userdata, mid, granted_qos) @@ -2000,14 +2308,17 @@ def on_subscribe(self, func): with self._callback_mutex: self._on_subscribe = func - def subscribe_callback(self): - def decorator(func): + def subscribe_callback( + self, + ) -> typing.Callable[[CallbackOnSubscribe], CallbackOnSubscribe]: + def decorator(func: CallbackOnSubscribe) -> CallbackOnSubscribe: self.on_subscribe = func return func + return decorator @property - def on_message(self): + def on_message(self) -> typing.Optional[CallbackOnMessage]: """If implemented, called when a message has been received on a topic that the client subscribes to. @@ -2017,8 +2328,8 @@ def on_message(self): return self._on_message @on_message.setter - def on_message(self, func): - """ Define the message received callback implementation. + def on_message(self, func: typing.Optional[CallbackOnMessage]) -> None: + """Define the message received callback implementation. Expected signature is: on_message_callback(client, userdata, message) @@ -2035,14 +2346,17 @@ def on_message(self, func): with self._callback_mutex: self._on_message = func - def message_callback(self): - def decorator(func): + def message_callback( + self, + ) -> typing.Callable[[CallbackOnMessage], CallbackOnMessage]: + def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.on_message = func return func + return decorator @property - def on_publish(self): + def on_publish(self) -> typing.Optional[CallbackOnPublish]: """If implemented, called when a message that was to be sent using the publish() call has completed transmission to the broker. @@ -2054,8 +2368,8 @@ def on_publish(self): return self._on_publish @on_publish.setter - def on_publish(self, func): - """ Define the published message callback implementation. + def on_publish(self, func: typing.Optional[CallbackOnPublish]) -> None: + """Define the published message callback implementation. Expected signature is: on_publish_callback(client, userdata, mid) @@ -2072,21 +2386,24 @@ def on_publish(self, func): with self._callback_mutex: self._on_publish = func - def publish_callback(self): - def decorator(func): + def publish_callback( + self, + ) -> typing.Callable[[CallbackOnPublish], CallbackOnPublish]: + def decorator(func: CallbackOnPublish) -> CallbackOnPublish: self.on_publish = func return func + return decorator @property - def on_unsubscribe(self): + def on_unsubscribe(self) -> typing.Optional[CallbackOnUnsubscribe]: """If implemented, called when the broker responds to an unsubscribe request.""" return self._on_unsubscribe @on_unsubscribe.setter - def on_unsubscribe(self, func): - """ Define the unsubscribe callback implementation. + def on_unsubscribe(self, func: typing.Optional[CallbackOnUnsubscribe]) -> None: + """Define the unsubscribe callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: unsubscribe_callback(client, userdata, mid) @@ -2109,21 +2426,23 @@ def on_unsubscribe(self, func): with self._callback_mutex: self._on_unsubscribe = func - def unsubscribe_callback(self): - def decorator(func): + def unsubscribe_callback( + self, + ) -> typing.Callable[[CallbackOnUnsubscribe], CallbackOnUnsubscribe]: + def decorator(func: CallbackOnUnsubscribe) -> CallbackOnUnsubscribe: self.on_unsubscribe = func return func + return decorator @property - def on_disconnect(self): - """If implemented, called when the client disconnects from the broker. - """ + def on_disconnect(self) -> typing.Optional[CallbackOnDisconnect]: + """If implemented, called when the client disconnects from the broker.""" return self._on_disconnect @on_disconnect.setter - def on_disconnect(self, func): - """ Define the disconnect callback implementation. + def on_disconnect(self, func: typing.Optional[CallbackOnDisconnect]) -> None: + """Define the disconnect callback implementation. Expected signature for MQTT v3.1.1 and v3.1 is: disconnect_callback(client, userdata, rc) @@ -2146,19 +2465,22 @@ def on_disconnect(self, func): with self._callback_mutex: self._on_disconnect = func - def disconnect_callback(self): - def decorator(func): + def disconnect_callback( + self, + ) -> typing.Callable[[CallbackOnDisconnect], CallbackOnDisconnect]: + def decorator(func: CallbackOnDisconnect) -> CallbackOnDisconnect: self.on_disconnect = func return func + return decorator @property - def on_socket_open(self): + def on_socket_open(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called just after the socket was opend.""" return self._on_socket_open @on_socket_open.setter - def on_socket_open(self, func): + def on_socket_open(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_open callback implementation. This should be used to register the socket to an external event loop for reading. @@ -2176,34 +2498,45 @@ def on_socket_open(self, func): with self._callback_mutex: self._on_socket_open = func - def socket_open_callback(self): - def decorator(func): + def socket_open_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_open = func return func + return decorator - def _call_socket_open(self): + def _call_socket_open(self) -> None: """Call the socket_open callback with the just-opened socket""" with self._callback_mutex: on_socket_open = self.on_socket_open if on_socket_open: with self._in_callback_mutex: + if self._sock is None: + self._easy_log( + MQTT_LOG_ERR, + "socket() is None in _call_socket_open", + ) + return + try: on_socket_open(self, self._userdata, self._sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_open: %s', err) + MQTT_LOG_ERR, "Caught exception in on_socket_open: %s", err + ) if not self.suppress_exceptions: raise @property - def on_socket_close(self): + def on_socket_close(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called just before the socket is closed.""" return self._on_socket_close @on_socket_close.setter - def on_socket_close(self, func): + def on_socket_close(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_close callback implementation. This should be used to unregister the socket from an external event loop for reading. @@ -2221,13 +2554,16 @@ def on_socket_close(self, func): with self._callback_mutex: self._on_socket_close = func - def socket_close_callback(self): - def decorator(func): + def socket_close_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self.on_socket_close = func return func + return decorator - def _call_socket_close(self, sock): + def _call_socket_close(self, sock: SocketLike) -> None: """Call the socket_close callback with the about-to-be-closed socket""" with self._callback_mutex: on_socket_close = self.on_socket_close @@ -2238,17 +2574,18 @@ def _call_socket_close(self, sock): on_socket_close(self, self._userdata, sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_close: %s', err) + MQTT_LOG_ERR, "Caught exception in on_socket_close: %s", err + ) if not self.suppress_exceptions: raise @property - def on_socket_register_write(self): + def on_socket_register_write(self) -> typing.Optional[CallbackOnSocket]: """If implemented, called when the socket needs writing but can't.""" return self._on_socket_register_write @on_socket_register_write.setter - def on_socket_register_write(self, func): + def on_socket_register_write(self, func: typing.Optional[CallbackOnSocket]) -> None: """Define the socket_register_write callback implementation. This should be used to register the socket with an external event loop for writing. @@ -2266,13 +2603,16 @@ def on_socket_register_write(self, func): with self._callback_mutex: self._on_socket_register_write = func - def socket_register_write_callback(self): - def decorator(func): + def socket_register_write_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator(func: CallbackOnSocket) -> CallbackOnSocket: self._on_socket_register_write = func return func + return decorator - def _call_socket_register_write(self): + def _call_socket_register_write(self) -> None: """Call the socket_register_write callback with the unwritable socket""" if not self._sock or self._registered_write: return @@ -2282,21 +2622,27 @@ def _call_socket_register_write(self): if on_socket_register_write: try: - on_socket_register_write( - self, self._userdata, self._sock) + on_socket_register_write(self, self._userdata, self._sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_register_write: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_socket_register_write: %s", + err, + ) if not self.suppress_exceptions: raise @property - def on_socket_unregister_write(self): + def on_socket_unregister_write( + self, + ) -> typing.Optional[CallbackOnSocket]: """If implemented, called when the socket doesn't need writing anymore.""" return self._on_socket_unregister_write @on_socket_unregister_write.setter - def on_socket_unregister_write(self, func): + def on_socket_unregister_write( + self, func: typing.Optional[CallbackOnSocket] + ) -> None: """Define the socket_unregister_write callback implementation. This should be used to unregister the socket from an external event loop for writing. @@ -2314,13 +2660,20 @@ def on_socket_unregister_write(self, func): with self._callback_mutex: self._on_socket_unregister_write = func - def socket_unregister_write_callback(self): - def decorator(func): + def socket_unregister_write_callback( + self, + ) -> typing.Callable[[CallbackOnSocket], CallbackOnSocket]: + def decorator( + func: CallbackOnSocket, + ) -> CallbackOnSocket: self._on_socket_unregister_write = func return func + return decorator - def _call_socket_unregister_write(self, sock=None): + def _call_socket_unregister_write( + self, sock: typing.Optional[SocketLike] = None + ) -> None: """Call the socket_unregister_write callback with the writable socket""" sock = sock or self._sock if not sock or not self._registered_write: @@ -2335,11 +2688,14 @@ def _call_socket_unregister_write(self, sock=None): on_socket_unregister_write(self, self._userdata, sock) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_socket_unregister_write: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_socket_unregister_write: %s", + err, + ) if not self.suppress_exceptions: raise - def message_callback_add(self, sub, callback): + def message_callback_add(self, sub: str, callback: CallbackOnMessage) -> None: """Register a message callback for a specific topic. Messages that match 'sub' will be passed to 'callback'. Any non-matching messages will be passed to the default on_message @@ -2356,13 +2712,16 @@ def message_callback_add(self, sub, callback): with self._callback_mutex: self._on_message_filtered[sub] = callback - def topic_callback(self, sub): - def decorator(func): + def topic_callback( + self, sub: str + ) -> typing.Callable[[CallbackOnMessage], CallbackOnMessage]: + def decorator(func: CallbackOnMessage) -> CallbackOnMessage: self.message_callback_add(sub, func) return func + return decorator - def message_callback_remove(self, sub): + def message_callback_remove(self, sub: str) -> None: """Remove a message callback previously registered with message_callback_add().""" if sub is None: @@ -2378,18 +2737,22 @@ def message_callback_remove(self, sub): # Private functions # ============================================================ - def _loop_rc_handle(self, rc, properties=None): + def _loop_rc_handle( + self, + rc: typing.Union[MQTTErrorCode, ReasonCodes, None], + properties: typing.Optional[Properties] = None, + ) -> typing.Union[MQTTErrorCode, ReasonCodes, None]: if rc: self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS self._do_on_disconnect(rc, properties) return rc - def _packet_read(self): + def _packet_read(self) -> MQTTErrorCode: # This gets called if pselect() indicates that there is network data # available - ie. at least one byte. What we do depends on what data we # already have. @@ -2403,26 +2766,24 @@ def _packet_read(self): # fail due to longer length, so save current data and current position. # After all data is read, send to _mqtt_handle_packet() to deal with. # Finally, free the memory and reset everything to starting conditions. - if self._in_packet['command'] == 0: + if self._in_packet["command"] == 0: try: command = self._sock_recv(1) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) - return MQTT_ERR_CONN_LOST + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) + return MQTTErrorCode.MQTT_ERR_CONN_LOST except TimeoutError as err: - self._easy_log( - MQTT_LOG_ERR, 'timeout on socket: %s', err) - return MQTT_ERR_CONN_LOST + self._easy_log(MQTT_LOG_ERR, "timeout on socket: %s", err) + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(command) == 0: - return MQTT_ERR_CONN_LOST - command, = struct.unpack("!B", command) - self._in_packet['command'] = command + return MQTTErrorCode.MQTT_ERR_CONN_LOST + (command_value,) = struct.unpack("!B", command) + self._in_packet["command"] = command_value - if self._in_packet['have_remaining'] == 0: + if self._in_packet["have_remaining"] == 0: # Read remaining # Algorithm for decoding taken from pseudo code at # http://publib.boulder.ibm.com/infocenter/wmbhelp/v6r0m0/topic/com.ibm.etools.mft.doc/ac10870_.htm @@ -2430,122 +2791,128 @@ def _packet_read(self): try: byte = self._sock_recv(1) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) - return MQTT_ERR_CONN_LOST + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(byte) == 0: - return MQTT_ERR_CONN_LOST - byte, = struct.unpack("!B", byte) - self._in_packet['remaining_count'].append(byte) + return MQTTErrorCode.MQTT_ERR_CONN_LOST + (byte_value,) = struct.unpack("!B", byte) + self._in_packet["remaining_count"].append(byte_value) # Max 4 bytes length for remaining length as defined by protocol. # Anything more likely means a broken/malicious client. - if len(self._in_packet['remaining_count']) > 4: - return MQTT_ERR_PROTOCOL - - self._in_packet['remaining_length'] += ( - byte & 127) * self._in_packet['remaining_mult'] - self._in_packet['remaining_mult'] = self._in_packet['remaining_mult'] * 128 + if len(self._in_packet["remaining_count"]) > 4: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + + self._in_packet["remaining_length"] += ( + byte_value & 127 + ) * self._in_packet["remaining_mult"] + self._in_packet["remaining_mult"] = ( + self._in_packet["remaining_mult"] * 128 + ) - if (byte & 128) == 0: + if (byte_value & 128) == 0: break - self._in_packet['have_remaining'] = 1 - self._in_packet['to_process'] = self._in_packet['remaining_length'] + self._in_packet["have_remaining"] = 1 + self._in_packet["to_process"] = self._in_packet["remaining_length"] - count = 100 # Don't get stuck in this loop if we have a huge message. - while self._in_packet['to_process'] > 0: + count = 100 # Don't get stuck in this loop if we have a huge message. + while self._in_packet["to_process"] > 0: try: - data = self._sock_recv(self._in_packet['to_process']) + data = self._sock_recv(self._in_packet["to_process"]) except BlockingIOError: - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) - return MQTT_ERR_CONN_LOST + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) + return MQTTErrorCode.MQTT_ERR_CONN_LOST else: if len(data) == 0: - return MQTT_ERR_CONN_LOST - self._in_packet['to_process'] -= len(data) - self._in_packet['packet'] += data + return MQTTErrorCode.MQTT_ERR_CONN_LOST + self._in_packet["to_process"] -= len(data) + self._in_packet["packet"] += data count -= 1 if count == 0: with self._msgtime_mutex: self._last_msg_in = time_func() - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN # All data for this packet is read. - self._in_packet['pos'] = 0 + self._in_packet["pos"] = 0 rc = self._packet_handle() # Free data and reset values self._in_packet = { - 'command': 0, - 'have_remaining': 0, - 'remaining_count': [], - 'remaining_mult': 1, - 'remaining_length': 0, - 'packet': bytearray(b""), - 'to_process': 0, - 'pos': 0} + "command": 0, + "have_remaining": 0, + "remaining_count": [], + "remaining_mult": 1, + "remaining_length": 0, + "packet": bytearray(b""), + "to_process": 0, + "pos": 0, + } with self._msgtime_mutex: self._last_msg_in = time_func() return rc - def _packet_write(self): + def _packet_write(self) -> MQTTErrorCode: while True: try: packet = self._out_packet.popleft() except IndexError: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS try: - write_length = self._sock_send( - packet['packet'][packet['pos']:]) + write_length = self._sock_send(packet["packet"][packet["pos"] :]) except (AttributeError, ValueError): self._out_packet.appendleft(packet) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS except BlockingIOError: self._out_packet.appendleft(packet) - return MQTT_ERR_AGAIN + return MQTTErrorCode.MQTT_ERR_AGAIN except ConnectionError as err: self._out_packet.appendleft(packet) - self._easy_log( - MQTT_LOG_ERR, 'failed to receive on socket: %s', err) - return MQTT_ERR_CONN_LOST + self._easy_log(MQTT_LOG_ERR, "failed to receive on socket: %s", err) + return MQTTErrorCode.MQTT_ERR_CONN_LOST if write_length > 0: - packet['to_process'] -= write_length - packet['pos'] += write_length + packet["to_process"] -= write_length + packet["pos"] += write_length - if packet['to_process'] == 0: - if (packet['command'] & 0xF0) == PUBLISH and packet['qos'] == 0: + if packet["to_process"] == 0: + if (packet["command"] & 0xF0) == PUBLISH and packet["qos"] == 0: with self._callback_mutex: on_publish = self.on_publish if on_publish: with self._in_callback_mutex: try: - on_publish( - self, self._userdata, packet['mid']) + on_publish(self, self._userdata, packet["mid"]) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + MQTT_LOG_ERR, + "Caught exception in on_publish: %s", + err, + ) if not self.suppress_exceptions: raise - packet['info']._set_as_published() + # TODO: Something is odd here. I don't see why packet["info"] can't be None. + # A packet could be produced by _handle_connack with qos=0 and no info + # (around line 3645). Ignore the mypy check for now but I feel their is a bug + # somewhere. + packet["info"]._set_as_published() # type: ignore - if (packet['command'] & 0xF0) == DISCONNECT: + if (packet["command"] & 0xF0) == DISCONNECT: with self._msgtime_mutex: self._last_msg_out = time_func() - self._do_on_disconnect(MQTT_ERR_SUCCESS) + self._do_on_disconnect(MQTTErrorCode.MQTT_ERR_SUCCESS) self._sock_close() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: # We haven't finished with this packet @@ -2556,9 +2923,9 @@ def _packet_write(self): with self._msgtime_mutex: self._last_msg_out = time_func() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _easy_log(self, level, fmt, *args): + def _easy_log(self, level: int, fmt: str, *args: typing.Any) -> None: if self.on_log is not None: buf = fmt % args try: @@ -2570,9 +2937,9 @@ def _easy_log(self, level, fmt, *args): level_std = LOGGING_LEVEL[level] self._logger.log(level_std, fmt, *args) - def _check_keepalive(self): + def _check_keepalive(self) -> None: if self._keepalive == 0: - return MQTT_ERR_SUCCESS + return now = time_func() @@ -2580,7 +2947,10 @@ def _check_keepalive(self): last_msg_out = self._last_msg_out last_msg_in = self._last_msg_in - if self._sock is not None and (now - last_msg_out >= self._keepalive or now - last_msg_in >= self._keepalive): + if self._sock is not None and ( + now - last_msg_out >= self._keepalive + or now - last_msg_in >= self._keepalive + ): if self._state == mqtt_cs_connected and self._ping_t == 0: try: self._send_pingreq() @@ -2595,13 +2965,13 @@ def _check_keepalive(self): self._sock_close() if self._state == mqtt_cs_disconnecting: - rc = MQTT_ERR_SUCCESS + rc = MQTTErrorCode.MQTT_ERR_SUCCESS else: - rc = MQTT_ERR_KEEPALIVE + rc = MQTTErrorCode.MQTT_ERR_KEEPALIVE self._do_on_disconnect(rc) - def _mid_generate(self): + def _mid_generate(self) -> int: with self._mid_generate_mutex: self._last_mid += 1 if self._last_mid == 65536: @@ -2609,44 +2979,49 @@ def _mid_generate(self): return self._last_mid @staticmethod - def _topic_wildcard_len_check(topic): + def _topic_wildcard_len_check(topic: bytes) -> MQTTErrorCode: # Search for + or # in a topic. Return MQTT_ERR_INVAL if found. # Also returns MQTT_ERR_INVAL if the topic string is too long. # Returns MQTT_ERR_SUCCESS if everything is fine. - if b'+' in topic or b'#' in topic or len(topic) > 65535: - return MQTT_ERR_INVAL + if b"+" in topic or b"#" in topic or len(topic) > 65535: + return MQTTErrorCode.MQTT_ERR_INVAL else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS @staticmethod - def _filter_wildcard_len_check(sub): - if (len(sub) == 0 or len(sub) > 65535 - or any(b'+' in p or b'#' in p for p in sub.split(b'/') if len(p) > 1) - or b'#/' in sub): - return MQTT_ERR_INVAL + def _filter_wildcard_len_check(sub: bytes) -> MQTTErrorCode: + if ( + len(sub) == 0 + or len(sub) > 65535 + or any(b"+" in p or b"#" in p for p in sub.split(b"/") if len(p) > 1) + or b"#/" in sub + ): + return MQTTErrorCode.MQTT_ERR_INVAL else: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _send_pingreq(self): + def _send_pingreq(self) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PINGREQ") rc = self._send_simple_command(PINGREQ) - if rc == MQTT_ERR_SUCCESS: + if rc == MQTTErrorCode.MQTT_ERR_SUCCESS: self._ping_t = time_func() return rc - def _send_pingresp(self): + def _send_pingresp(self) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PINGRESP") return self._send_simple_command(PINGRESP) - def _send_puback(self, mid): + def _send_puback(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBACK (Mid: %d)", mid) return self._send_command_with_mid(PUBACK, mid, False) - def _send_pubcomp(self, mid): + def _send_pubcomp(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBCOMP (Mid: %d)", mid) return self._send_command_with_mid(PUBCOMP, mid, False) - def _pack_remaining_length(self, packet, remaining_length): + def _pack_remaining_length( + self, packet: bytearray, remaining_length: int + ) -> bytearray: remaining_bytes = [] while True: byte = remaining_length % 128 @@ -2661,21 +3036,31 @@ def _pack_remaining_length(self, packet, remaining_length): # FIXME - this doesn't deal with incorrectly large payloads return packet - def _pack_str16(self, packet, data): + def _pack_str16(self, packet: bytearray, data: typing.Union[bytes, str]) -> None: if isinstance(data, str): - data = data.encode('utf-8') + data = data.encode("utf-8") packet.extend(struct.pack("!H", len(data))) packet.extend(data) - def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, info=None, properties=None): + def _send_publish( + self, + mid: int, + topic: bytes, + payload: bytes = b"", + qos: int = 0, + retain: bool = False, + dup: bool = False, + info: typing.Optional[MQTTMessageInfo] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: # we assume that topic and payload are already properly encoded if not isinstance(topic, bytes): - raise TypeError('topic must be bytes, not str') + raise TypeError("topic must be bytes, not str") if payload and not isinstance(payload, bytes): - raise TypeError('payload must be bytes if set') + raise TypeError("payload must be bytes if set") if self._sock is None: - return MQTT_ERR_NO_CONN + return MQTTErrorCode.MQTT_ERR_NO_CONN command = PUBLISH | ((dup & 0x1) << 3) | (qos << 1) | retain packet = bytearray() @@ -2689,26 +3074,46 @@ def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s (NULL payload)", - dup, qos, retain, mid, topic, properties + dup, + qos, + retain, + mid, + topic, + properties, ) else: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s' (NULL payload)", - dup, qos, retain, mid, topic + dup, + qos, + retain, + mid, + topic, ) else: if self._protocol == MQTTv5: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - dup, qos, retain, mid, topic, properties, payloadlen + dup, + qos, + retain, + mid, + topic, + properties, + payloadlen, ) else: self._easy_log( MQTT_LOG_DEBUG, "Sending PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - dup, qos, retain, mid, topic, payloadlen + dup, + qos, + retain, + mid, + topic, + payloadlen, ) if qos > 0: @@ -2717,7 +3122,7 @@ def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, if self._protocol == MQTTv5: if properties is None: - packed_properties = b'\x00' + packed_properties = b"\x00" else: packed_properties = properties.pack() remaining_length += len(packed_properties) @@ -2736,51 +3141,55 @@ def _send_publish(self, mid, topic, payload=b'', qos=0, retain=False, dup=False, return self._packet_queue(PUBLISH, packet, mid, qos, info) - def _send_pubrec(self, mid): + def _send_pubrec(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREC (Mid: %d)", mid) return self._send_command_with_mid(PUBREC, mid, False) - def _send_pubrel(self, mid): + def _send_pubrel(self, mid: int) -> MQTTErrorCode: self._easy_log(MQTT_LOG_DEBUG, "Sending PUBREL (Mid: %d)", mid) return self._send_command_with_mid(PUBREL | 2, mid, False) - def _send_command_with_mid(self, command, mid, dup): + def _send_command_with_mid(self, command: int, mid: int, dup: int) -> MQTTErrorCode: # For PUBACK, PUBCOMP, PUBREC, and PUBREL if dup: command |= 0x8 remaining_length = 2 - packet = struct.pack('!BBH', command, remaining_length, mid) + packet = struct.pack("!BBH", command, remaining_length, mid) return self._packet_queue(command, packet, mid, 1) - def _send_simple_command(self, command): + def _send_simple_command(self, command: int) -> MQTTErrorCode: # For DISCONNECT, PINGREQ and PINGRESP remaining_length = 0 - packet = struct.pack('!BB', command, remaining_length) + packet = struct.pack("!BB", command, remaining_length) return self._packet_queue(command, packet, 0, 0) - def _send_connect(self, keepalive): + def _send_connect(self, keepalive: int) -> MQTTErrorCode: proto_ver = self._protocol # hard-coded UTF-8 encoded string protocol = b"MQTT" if proto_ver >= MQTTv311 else b"MQIsdp" - remaining_length = 2 + len(protocol) + 1 + \ - 1 + 2 + 2 + len(self._client_id) + remaining_length = 2 + len(protocol) + 1 + 1 + 2 + 2 + len(self._client_id) connect_flags = 0 if self._protocol == MQTTv5: if self._clean_start is True: connect_flags |= 0x02 - elif self._clean_start == MQTT_CLEAN_START_FIRST_ONLY and self._mqttv5_first_connect: + elif ( + self._clean_start == MQTT_CLEAN_START_FIRST_ONLY + and self._mqttv5_first_connect + ): connect_flags |= 0x02 elif self._clean_session: connect_flags |= 0x02 if self._will: - remaining_length += 2 + \ - len(self._will_topic) + 2 + len(self._will_payload) - connect_flags |= 0x04 | ((self._will_qos & 0x03) << 3) | ( - (self._will_retain & 0x01) << 5) + remaining_length += 2 + len(self._will_topic) + 2 + len(self._will_payload) + connect_flags |= ( + 0x04 + | ((self._will_qos & 0x03) << 3) + | ((self._will_retain & 0x01) << 5) + ) if self._username is not None: remaining_length += 2 + len(self._username) @@ -2791,13 +3200,13 @@ def _send_connect(self, keepalive): if self._protocol == MQTTv5: if self._connect_properties is None: - packed_connect_properties = b'\x00' + packed_connect_properties = b"\x00" else: packed_connect_properties = self._connect_properties.pack() remaining_length += len(packed_connect_properties) if self._will: if self._will_properties is None: - packed_will_properties = b'\x00' + packed_will_properties = b"\x00" else: packed_will_properties = self._will_properties.pack() remaining_length += len(packed_will_properties) @@ -2812,10 +3221,16 @@ def _send_connect(self, keepalive): proto_ver |= 0x80 self._pack_remaining_length(packet, remaining_length) - packet.extend(struct.pack( - f"!H{len(protocol)}sBBH", - len(protocol), protocol, proto_ver, connect_flags, keepalive, - )) + packet.extend( + struct.pack( + f"!H{len(protocol)}sBBH", + len(protocol), + protocol, + proto_ver, + connect_flags, + keepalive, + ) + ) if self._protocol == MQTTv5: packet += packed_connect_properties @@ -2847,7 +3262,7 @@ def _send_connect(self, keepalive): (connect_flags & 0x2) >> 1, keepalive, self._client_id, - self._connect_properties + self._connect_properties, ) else: self._easy_log( @@ -2860,16 +3275,22 @@ def _send_connect(self, keepalive): (connect_flags & 0x4) >> 2, (connect_flags & 0x2) >> 1, keepalive, - self._client_id + self._client_id, ) return self._packet_queue(command, packet, 0, 0) - def _send_disconnect(self, reasoncode=None, properties=None): + def _send_disconnect( + self, + reasoncode: typing.Optional[ReasonCodes] = None, + properties: typing.Optional[Properties] = None, + ) -> MQTTErrorCode: if self._protocol == MQTTv5: - self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT reasonCode=%s properties=%s", - reasoncode, - properties - ) + self._easy_log( + MQTT_LOG_DEBUG, + "Sending DISCONNECT reasonCode=%s properties=%s", + reasoncode, + properties, + ) else: self._easy_log(MQTT_LOG_DEBUG, "Sending DISCONNECT") @@ -2898,11 +3319,18 @@ def _send_disconnect(self, reasoncode=None, properties=None): return self._packet_queue(command, packet, 0, 0) - def _send_subscribe(self, dup, topics, properties=None): + def _send_subscribe( + self, + dup: int, + topics: typing.Sequence[ + typing.Tuple[bytes, typing.Union[SubscribeOptions, int]] + ], + properties: typing.Optional[Properties] = None, + ) -> typing.Tuple[MQTTErrorCode, int]: remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_subscribe_properties = b'\x00' + packed_subscribe_properties = b"\x00" else: packed_subscribe_properties = properties.pack() remaining_length += len(packed_subscribe_properties) @@ -2922,9 +3350,9 @@ def _send_subscribe(self, dup, topics, properties=None): for t, q in topics: self._pack_str16(packet, t) if self._protocol == MQTTv5: - packet += q.pack() + packet += q.pack() # type: ignore else: - packet.append(q) + packet.append(q) # type: ignore self._easy_log( MQTT_LOG_DEBUG, @@ -2935,11 +3363,16 @@ def _send_subscribe(self, dup, topics, properties=None): ) return (self._packet_queue(command, packet, local_mid, 1), local_mid) - def _send_unsubscribe(self, dup, topics, properties=None): + def _send_unsubscribe( + self, + dup: int, + topics: typing.List[bytes], + properties: typing.Optional[Properties] = None, + ) -> typing.Tuple[MQTTErrorCode, int]: remaining_length = 2 if self._protocol == MQTTv5: if properties is None: - packed_unsubscribe_properties = b'\x00' + packed_unsubscribe_properties = b"\x00" else: packed_unsubscribe_properties = properties.pack() remaining_length += len(packed_unsubscribe_properties) @@ -2979,21 +3412,24 @@ def _send_unsubscribe(self, dup, topics, properties=None): ) return (self._packet_queue(command, packet, local_mid, 1), local_mid) - def _check_clean_session(self): + def _check_clean_session(self) -> bool: if self._protocol == MQTTv5: if self._clean_start == MQTT_CLEAN_START_FIRST_ONLY: return self._mqttv5_first_connect else: - return self._clean_start + return self._clean_start # type: ignore else: return self._clean_session - def _messages_reconnect_reset_out(self): + def _messages_reconnect_reset_out(self) -> None: with self._out_message_mutex: self._inflight_messages = 0 for m in self._out_messages.values(): m.timestamp = 0 - if self._max_inflight_messages == 0 or self._inflight_messages < self._max_inflight_messages: + if ( + self._max_inflight_messages == 0 + or self._inflight_messages < self._max_inflight_messages + ): if m.qos == 0: m.state = mqtt_ms_publish elif m.qos == 1: @@ -3017,7 +3453,7 @@ def _messages_reconnect_reset_out(self): else: m.state = mqtt_ms_queued - def _messages_reconnect_reset_in(self): + def _messages_reconnect_reset_in(self) -> None: with self._in_message_mutex: if self._check_clean_session(): self._in_messages = collections.OrderedDict() @@ -3030,19 +3466,27 @@ def _messages_reconnect_reset_in(self): # Preserve current state pass - def _messages_reconnect_reset(self): + def _messages_reconnect_reset(self) -> None: self._messages_reconnect_reset_out() self._messages_reconnect_reset_in() - def _packet_queue(self, command, packet, mid, qos, info=None): - mpkt = { - 'command': command, - 'mid': mid, - 'qos': qos, - 'pos': 0, - 'to_process': len(packet), - 'packet': packet, - 'info': info} + def _packet_queue( + self, + command: int, + packet: bytes, + mid: int, + qos: int, + info: typing.Optional[MQTTMessageInfo] = None, + ) -> MQTTErrorCode: + mpkt: "_OutPacket" = { + "command": command, + "mid": mid, + "qos": qos, + "pos": 0, + "to_process": len(packet), + "packet": packet, + "info": info, + } self._out_packet.append(mpkt) @@ -3063,10 +3507,10 @@ def _packet_queue(self, command, packet, mid, qos, info=None): self._call_socket_register_write() - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _packet_handle(self): - cmd = self._in_packet['command'] & 0xF0 + def _packet_handle(self) -> MQTTErrorCode: + cmd = self._in_packet["command"] & 0xF0 if cmd == PINGREQ: return self._handle_pingreq() elif cmd == PINGRESP: @@ -3092,45 +3536,44 @@ def _packet_handle(self): else: # If we don't recognise the command, return an error straight away. self._easy_log(MQTT_LOG_ERR, "Error: Unrecognised command %s", cmd) - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def _handle_pingreq(self): - if self._in_packet['remaining_length'] != 0: - return MQTT_ERR_PROTOCOL + def _handle_pingreq(self) -> MQTTErrorCode: + if self._in_packet["remaining_length"] != 0: + return MQTTErrorCode.MQTT_ERR_PROTOCOL self._easy_log(MQTT_LOG_DEBUG, "Received PINGREQ") return self._send_pingresp() - def _handle_pingresp(self): - if self._in_packet['remaining_length'] != 0: - return MQTT_ERR_PROTOCOL + def _handle_pingresp(self) -> MQTTErrorCode: + if self._in_packet["remaining_length"] != 0: + return MQTTErrorCode.MQTT_ERR_PROTOCOL # No longer waiting for a PINGRESP. self._ping_t = 0 self._easy_log(MQTT_LOG_DEBUG, "Received PINGRESP") - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_connack(self): + def _handle_connack(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: - return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: - return MQTT_ERR_PROTOCOL + if self._in_packet["remaining_length"] < 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + elif self._in_packet["remaining_length"] != 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL if self._protocol == MQTTv5: - (flags, result) = struct.unpack( - "!BB", self._in_packet['packet'][:2]) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"][:2]) if result == 1: # This is probably a failure from a broker that doesn't support # MQTT v5. - reason = 132 # Unsupported protocol version + reason = ReasonCodes(CONNACK >> 4, aName="Unsupported protocol version") properties = None else: reason = ReasonCodes(CONNACK >> 4, identifier=result) properties = Properties(CONNACK >> 4) - properties.unpack(self._in_packet['packet'][2:]) + properties.unpack(self._in_packet["packet"][2:]) else: - (flags, result) = struct.unpack("!BB", self._in_packet['packet']) + (flags, result) = struct.unpack("!BB", self._in_packet["packet"]) if self._protocol == MQTTv311: if result == CONNACK_REFUSED_PROTOCOL_VERSION: if not self._reconnect_on_failure: @@ -3138,21 +3581,24 @@ def _handle_connack(self): self._easy_log( MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting downgrade to MQTT v3.1.", - flags, result + flags, + result, ) # Downgrade to MQTT v3.1 self._protocol = MQTTv31 return self.reconnect() - elif (result == CONNACK_REFUSED_IDENTIFIER_REJECTED - and self._client_id == b''): + elif ( + result == CONNACK_REFUSED_IDENTIFIER_REJECTED and self._client_id == b"" + ): if not self._reconnect_on_failure: return MQTT_ERR_PROTOCOL self._easy_log( MQTT_LOG_DEBUG, "Received CONNACK (%s, %s), attempting to use non-empty CID", - flags, result, + flags, + result, ) - self._client_id = base62(uuid.uuid4().int, padding=22) + self._client_id = base62(uuid.uuid4().int, padding=22).encode("utf8") return self.reconnect() if result == 0: @@ -3161,10 +3607,14 @@ def _handle_connack(self): if self._protocol == MQTTv5: self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s) properties=%s", flags, reason, properties) + MQTT_LOG_DEBUG, + "Received CONNACK (%s, %s) properties=%s", + flags, + reason, + properties, + ) else: - self._easy_log( - MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) + self._easy_log(MQTT_LOG_DEBUG, "Received CONNACK (%s, %s)", flags, result) # it won't be the first successful connect any more self._mqttv5_first_connect = False @@ -3174,23 +3624,22 @@ def _handle_connack(self): if on_connect: flags_dict = {} - flags_dict['session present'] = flags & 0x01 + flags_dict["session present"] = flags & 0x01 with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_connect(self, self._userdata, - flags_dict, reason, properties) + on_connect(self, self._userdata, flags_dict, reason, properties) # type: ignore else: - on_connect( - self, self._userdata, flags_dict, result) + on_connect(self, self._userdata, flags_dict, result) # type: ignore except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_connect: %s", err + ) if not self.suppress_exceptions: raise if result == 0: - rc = 0 + rc = MQTTErrorCode.MQTT_ERR_SUCCESS with self._out_message_mutex: for m in self._out_messages.values(): m.timestamp = time_func() @@ -3202,14 +3651,14 @@ def _handle_connack(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.qos == 1: if m.state == mqtt_ms_publish: @@ -3218,14 +3667,14 @@ def _handle_connack(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.qos == 2: if m.state == mqtt_ms_publish: @@ -3234,53 +3683,51 @@ def _handle_connack(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, - properties=m.properties + properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc elif m.state == mqtt_ms_resend_pubrel: self._inflight_messages += 1 m.state = mqtt_ms_wait_for_pubcomp with self._in_callback_mutex: # Don't call loop_write after _send_publish() rc = self._send_pubrel(m.mid) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc self.loop_write() # Process outgoing messages that have just been queued up return rc elif result > 0 and result < 6: - return MQTT_ERR_CONN_REFUSED + return MQTTErrorCode.MQTT_ERR_CONN_REFUSED else: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def _handle_disconnect(self): + def _handle_disconnect(self) -> "Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]": packet_type = DISCONNECT >> 4 reasonCode = properties = None - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(packet_type) - reasonCode.unpack(self._in_packet['packet']) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"]) + if self._in_packet["remaining_length"] > 3: properties = Properties(packet_type) - props, props_len = properties.unpack( - self._in_packet['packet'][1:]) - self._easy_log(MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", - reasonCode, - properties - ) + props, props_len = properties.unpack(self._in_packet["packet"][1:]) + self._easy_log( + MQTT_LOG_DEBUG, "Received DISCONNECT %s %s", reasonCode, properties + ) self._loop_rc_handle(reasonCode, properties) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_suback(self): + def _handle_suback(self) -> "Literal[MQTTErrorCode.MQTT_ERR_SUCCESS]": self._easy_log(MQTT_LOG_DEBUG, "Received SUBACK") pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (mid, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (mid, packet) = struct.unpack(pack_format, self._in_packet["packet"]) if self._protocol == MQTTv5: properties = Properties(SUBACK >> 4) @@ -3299,42 +3746,39 @@ def _handle_suback(self): with self._in_callback_mutex: # Don't call loop_write after _send_publish() try: if self._protocol == MQTTv5: - on_subscribe( - self, self._userdata, mid, reasoncodes, properties) + on_subscribe(self, self._userdata, mid, reasoncodes, properties) # type: ignore else: - on_subscribe( - self, self._userdata, mid, granted_qos) + on_subscribe(self, self._userdata, mid, granted_qos) # type: ignore except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_subscribe: %s', err) + MQTT_LOG_ERR, "Caught exception in on_subscribe: %s", err + ) if not self.suppress_exceptions: raise - return MQTT_ERR_SUCCESS - - def _handle_publish(self): - rc = 0 + return MQTTErrorCode.MQTT_ERR_SUCCESS - header = self._in_packet['command'] + def _handle_publish(self) -> MQTTErrorCode: + header = self._in_packet["command"] message = MQTTMessage() - message.dup = (header & 0x08) >> 3 + message.dup = ((header & 0x08) >> 3) != 0 message.qos = (header & 0x06) >> 1 - message.retain = (header & 0x01) + message.retain = (header & 0x01) != 0 pack_format = f"!H{len(self._in_packet['packet']) - 2}s" - (slen, packet) = struct.unpack(pack_format, self._in_packet['packet']) + (slen, packet) = struct.unpack(pack_format, self._in_packet["packet"]) pack_format = f"!{slen}s{len(packet) - slen}s" (topic, packet) = struct.unpack(pack_format, packet) if self._protocol != MQTTv5 and len(topic) == 0: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL # Handle topics with invalid UTF-8 # This replaces an invalid topic with a message and the hex # representation of the topic for logging. When the user attempts to # access message.topic in the callback, an exception will be raised. try: - print_topic = topic.decode('utf-8') + print_topic = topic.decode("utf-8") except UnicodeDecodeError: print_topic = f"TOPIC WITH INVALID UTF-8: {topic!r}" @@ -3355,29 +3799,37 @@ def _handle_publish(self): self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', properties=%s, ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, message.properties, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + message.properties, + len(message.payload), ) else: self._easy_log( MQTT_LOG_DEBUG, "Received PUBLISH (d%d, q%d, r%d, m%d), '%s', ... (%d bytes)", - message.dup, message.qos, message.retain, message.mid, - print_topic, len(message.payload) + message.dup, + message.qos, + message.retain, + message.mid, + print_topic, + len(message.payload), ) message.timestamp = time_func() if message.qos == 0: self._handle_on_message(message) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS elif message.qos == 1: self._handle_on_message(message) if self._manual_ack: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: return self._send_puback(message.mid) elif message.qos == 2: - rc = self._send_pubrec(message.mid) message.state = mqtt_ms_wait_for_pubrel @@ -3386,38 +3838,37 @@ def _handle_publish(self): return rc else: - return MQTT_ERR_PROTOCOL + return MQTTErrorCode.MQTT_ERR_PROTOCOL - def ack(self, mid: int, qos: int) -> int: + def ack(self, mid: int, qos: int) -> MQTTErrorCode: """ - send an acknowledgement for a given message id. (stored in message.mid ) - only useful in QoS=1 and auto_ack=False + send an acknowledgement for a given message id. (stored in message.mid ) + only useful in QoS=1 and auto_ack=False """ - if self._manual_ack : + if self._manual_ack: if qos == 1: return self._send_puback(mid) elif qos == 2: return self._send_pubcomp(mid) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def manual_ack_set(self, on): + def manual_ack_set(self, on: bool) -> None: """ - The paho library normally acknowledges messages as soon as they are delivered to the caller. - If manual_ack is turned on, then the caller MUST manually acknowledge every message once - application processing is complete. + The paho library normally acknowledges messages as soon as they are delivered to the caller. + If manual_ack is turned on, then the caller MUST manually acknowledge every message once + application processing is complete. """ self._manual_ack = on - - def _handle_pubrel(self): + def _handle_pubrel(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: - return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: - return MQTT_ERR_PROTOCOL + if self._in_packet["remaining_length"] < 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + elif self._in_packet["remaining_length"] != 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet']) + (mid,) = struct.unpack("!H", self._in_packet["packet"]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREL (Mid: %d)", mid) with self._in_message_mutex: @@ -3430,7 +3881,7 @@ def _handle_pubrel(self): if self._max_inflight_messages > 0: with self._out_message_mutex: rc = self._update_inflight() - if rc != MQTT_ERR_SUCCESS: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc # FIXME: this should only be done if the message is known @@ -3440,11 +3891,11 @@ def _handle_pubrel(self): # Choose to acknowledge this message (thus losing a message) but # avoid hanging. See #284. if self._manual_ack: - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS else: return self._send_pubcomp(mid) - def _update_inflight(self): + def _update_inflight(self) -> MQTTErrorCode: # Dont lock message_mutex here for m in self._out_messages.values(): if self._inflight_messages < self._max_inflight_messages: @@ -3456,35 +3907,34 @@ def _update_inflight(self): m.state = mqtt_ms_wait_for_pubrec rc = self._send_publish( m.mid, - m.topic.encode('utf-8'), + m.topic.encode("utf-8"), m.payload, m.qos, m.retain, m.dup, properties=m.properties, ) - if rc != 0: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc else: - return MQTT_ERR_SUCCESS - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_pubrec(self): + def _handle_pubrec(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: - return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: - return MQTT_ERR_PROTOCOL + if self._in_packet["remaining_length"] < 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + elif self._in_packet["remaining_length"] != 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(PUBREC >> 4) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(PUBREC >> 4) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received PUBREC (Mid: %d)", mid) with self._out_message_mutex: @@ -3494,25 +3944,29 @@ def _handle_pubrec(self): msg.timestamp = time_func() return self._send_pubrel(mid) - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_unsuback(self): + def _handle_unsuback(self) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 4: - return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: - return MQTT_ERR_PROTOCOL + if self._in_packet["remaining_length"] < 4: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + elif self._in_packet["remaining_length"] != 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - packet = self._in_packet['packet'][2:] + packet = self._in_packet["packet"][2:] properties = Properties(UNSUBACK >> 4) props, props_len = properties.unpack(packet) - reasoncodes = [] + reasoncodes_list = [] for c in packet[props_len:]: - reasoncodes.append(ReasonCodes(UNSUBACK >> 4, identifier=c)) - if len(reasoncodes) == 1: - reasoncodes = reasoncodes[0] + reasoncodes_list.append(ReasonCodes(UNSUBACK >> 4, identifier=c)) + + reasoncodes: typing.Union[ + ReasonCodes, typing.List[ReasonCodes] + ] = reasoncodes_list + if len(reasoncodes_list) == 1: + reasoncodes = reasoncodes_list[0] self._easy_log(MQTT_LOG_DEBUG, "Received UNSUBACK (Mid: %d)", mid) with self._callback_mutex: @@ -3523,18 +3977,24 @@ def _handle_unsuback(self): try: if self._protocol == MQTTv5: on_unsubscribe( - self, self._userdata, mid, properties, reasoncodes) + self, self._userdata, mid, properties, reasoncodes # type: ignore + ) else: - on_unsubscribe(self, self._userdata, mid) + on_unsubscribe(self, self._userdata, mid) # type: ignore except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_unsubscribe: %s', err) + MQTT_LOG_ERR, "Caught exception in on_unsubscribe: %s", err + ) if not self.suppress_exceptions: raise - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _do_on_disconnect(self, rc, properties=None): + def _do_on_disconnect( + self, + rc: typing.Union[MQTTErrorCode, ReasonCodes], + properties: typing.Optional[Properties] = None, + ) -> None: with self._callback_mutex: on_disconnect = self.on_disconnect @@ -3542,17 +4002,17 @@ def _do_on_disconnect(self, rc, properties=None): with self._in_callback_mutex: try: if self._protocol == MQTTv5: - on_disconnect( - self, self._userdata, rc, properties) + on_disconnect(self, self._userdata, rc, properties) # type: ignore else: - on_disconnect(self, self._userdata, rc) + on_disconnect(self, self._userdata, rc) # type: ignore except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_disconnect: %s', err) + MQTT_LOG_ERR, "Caught exception in on_disconnect: %s", err + ) if not self.suppress_exceptions: raise - def _do_on_publish(self, mid): + def _do_on_publish(self, mid: int) -> MQTTErrorCode: with self._callback_mutex: on_publish = self.on_publish @@ -3562,7 +4022,8 @@ def _do_on_publish(self, mid): on_publish(self, self._userdata, mid) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_publish: %s', err) + MQTT_LOG_ERR, "Caught exception in on_publish: %s", err + ) if not self.suppress_exceptions: raise @@ -3572,28 +4033,29 @@ def _do_on_publish(self, mid): self._inflight_messages -= 1 if self._max_inflight_messages > 0: rc = self._update_inflight() - if rc != MQTT_ERR_SUCCESS: + if rc != MQTTErrorCode.MQTT_ERR_SUCCESS: return rc - return MQTT_ERR_SUCCESS + return MQTTErrorCode.MQTT_ERR_SUCCESS - def _handle_pubackcomp(self, cmd): + def _handle_pubackcomp( + self, cmd: typing.Union["Literal['PUBACK']", "Literal['PUBCOMP']"] + ) -> MQTTErrorCode: if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] < 2: - return MQTT_ERR_PROTOCOL - elif self._in_packet['remaining_length'] != 2: - return MQTT_ERR_PROTOCOL + if self._in_packet["remaining_length"] < 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL + elif self._in_packet["remaining_length"] != 2: + return MQTTErrorCode.MQTT_ERR_PROTOCOL packet_type = PUBACK if cmd == "PUBACK" else PUBCOMP packet_type = packet_type >> 4 - mid, = struct.unpack("!H", self._in_packet['packet'][:2]) + (mid,) = struct.unpack("!H", self._in_packet["packet"][:2]) if self._protocol == MQTTv5: - if self._in_packet['remaining_length'] > 2: + if self._in_packet["remaining_length"] > 2: reasonCode = ReasonCodes(packet_type) - reasonCode.unpack(self._in_packet['packet'][2:]) - if self._in_packet['remaining_length'] > 3: + reasonCode.unpack(self._in_packet["packet"][2:]) + if self._in_packet["remaining_length"] > 3: properties = Properties(packet_type) - props, props_len = properties.unpack( - self._in_packet['packet'][3:]) + props, props_len = properties.unpack(self._in_packet["packet"][3:]) self._easy_log(MQTT_LOG_DEBUG, "Received %s (Mid: %d)", cmd, mid) with self._out_message_mutex: @@ -3602,10 +4064,9 @@ def _handle_pubackcomp(self, cmd): rc = self._do_on_publish(mid) return rc - return MQTT_ERR_SUCCESS - - def _handle_on_message(self, message): + return MQTTErrorCode.MQTT_ERR_SUCCESS + def _handle_on_message(self, message: MQTTMessage) -> None: try: topic = message.topic except UnicodeDecodeError: @@ -3629,9 +4090,9 @@ def _handle_on_message(self, message): except Exception as err: self._easy_log( MQTT_LOG_ERR, - 'Caught exception in user defined callback function %s: %s', + "Caught exception in user defined callback function %s: %s", callback.__name__, - err + err, ) if not self.suppress_exceptions: raise @@ -3642,12 +4103,12 @@ def _handle_on_message(self, message): on_message(self, self._userdata, message) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_message: %s', err) + MQTT_LOG_ERR, "Caught exception in on_message: %s", err + ) if not self.suppress_exceptions: raise - - def _handle_on_connect_fail(self): + def _handle_on_connect_fail(self) -> None: with self._callback_mutex: on_connect_fail = self.on_connect_fail @@ -3657,12 +4118,13 @@ def _handle_on_connect_fail(self): on_connect_fail(self, self._userdata) except Exception as err: self._easy_log( - MQTT_LOG_ERR, 'Caught exception in on_connect_fail: %s', err) + MQTT_LOG_ERR, "Caught exception in on_connect_fail: %s", err + ) - def _thread_main(self): + def _thread_main(self) -> None: self.loop_forever(retry_first_connection=True) - def _reconnect_wait(self): + def _reconnect_wait(self) -> None: # See reconnect_delay_set for details now = time_func() with self._reconnect_delay_mutex: @@ -3677,18 +4139,22 @@ def _reconnect_wait(self): target_time = now + self._reconnect_delay remaining = target_time - now - while (self._state != mqtt_cs_disconnecting - and not self._thread_terminate - and remaining > 0): - + while ( + self._state != mqtt_cs_disconnecting + and not self._thread_terminate + and remaining > 0 + ): time.sleep(min(remaining, 1)) remaining = target_time - time_func() @staticmethod - def _proxy_is_valid(p): - def check(t, a): - return (socks is not None and - t in set([socks.HTTP, socks.SOCKS4, socks.SOCKS5]) and a) + def _proxy_is_valid(p) -> bool: # type: ignore[no-untyped-def] + def check(t, a) -> bool: # type: ignore[no-untyped-def] + return ( + socks is not None + and t in set([socks.HTTP, socks.SOCKS4, socks.SOCKS5]) + and a + ) if isinstance(p, dict): return check(p.get("proxy_type"), p.get("proxy_addr")) @@ -3697,7 +4163,7 @@ def check(t, a): else: return False - def _get_proxy(self): + def _get_proxy(self) -> typing.Optional[typing.Dict[str, typing.Any]]: if socks is None: return None @@ -3708,8 +4174,10 @@ def _get_proxy(self): # Next, check for an mqtt_proxy environment variable as long as the host # we're trying to connect to isn't listed under the no_proxy environment # variable (matches built-in module urllib's behavior) - if not (hasattr(urllib.request, "proxy_bypass") and - urllib.request.proxy_bypass(self._host)): + if not ( + hasattr(urllib.request, "proxy_bypass") + and urllib.request.proxy_bypass(self._host) + ): env_proxies = urllib.request.getproxies() if "mqtt" in env_proxies: parts = urllib.parse.urlparse(env_proxies["mqtt"]) @@ -3717,14 +4185,14 @@ def _get_proxy(self): proxy = { "proxy_type": socks.HTTP, "proxy_addr": parts.hostname, - "proxy_port": parts.port + "proxy_port": parts.port, } return proxy elif parts.scheme == "socks": proxy = { "proxy_type": socks.SOCKS5, "proxy_addr": parts.hostname, - "proxy_port": parts.port + "proxy_port": parts.port, } return proxy @@ -3732,23 +4200,33 @@ def _get_proxy(self): # a default proxy socks_default = socks.get_default_proxy() if self._proxy_is_valid(socks_default): - proxy_keys = ("proxy_type", "proxy_addr", "proxy_port", - "proxy_rdns", "proxy_username", "proxy_password") + proxy_keys = ( + "proxy_type", + "proxy_addr", + "proxy_port", + "proxy_rdns", + "proxy_username", + "proxy_password", + ) return dict(zip(proxy_keys, socks_default)) # If we didn't find a proxy through any of the above methods, return # None to indicate that the connection should be handled normally return None - def _create_socket_connection(self): + def _create_socket_connection(self) -> _socket.socket: proxy = self._get_proxy() addr = (self._host, self._port) source = (self._bind_address, self._bind_port) if proxy: - return socks.create_connection(addr, timeout=self._connect_timeout, source_address=source, **proxy) + return socks.create_connection( + addr, timeout=self._connect_timeout, source_address=source, **proxy + ) else: - return socket.create_connection(addr, timeout=self._connect_timeout, source_address=source) + return socket.create_connection( + addr, timeout=self._connect_timeout, source_address=source + ) class WebsocketWrapper: @@ -3757,10 +4235,17 @@ class WebsocketWrapper: OPCODE_BINARY = 0x2 OPCODE_CONNCLOSE = 0x8 OPCODE_PING = 0x9 - OPCODE_PONG = 0xa - - def __init__(self, socket, host, port, is_ssl, path, extra_headers): - + OPCODE_PONG = 0xA + + def __init__( + self, + socket: typing.Union[socket.socket, "ssl.SSLSocket"], + host: str, + port: int, + is_ssl: bool, + path: str, + extra_headers: typing.Optional[WebSocketHeaders], + ): self.connected = False self._ssl = is_ssl @@ -3778,13 +4263,11 @@ def __init__(self, socket, host, port, is_ssl, path, extra_headers): self._do_handshake(extra_headers) - def __del__(self): - - self._sendbuffer = None - self._readbuffer = None - - def _do_handshake(self, extra_headers): + def __del__(self) -> None: + self._sendbuffer = bytearray() + self._readbuffer = bytearray() + def _do_handshake(self, extra_headers: typing.Optional[WebSocketHeaders]) -> None: sec_websocket_key = uuid.uuid4().bytes sec_websocket_key = base64.b64encode(sec_websocket_key) @@ -3805,11 +4288,13 @@ def _do_handshake(self, extra_headers): elif callable(extra_headers): websocket_headers = extra_headers(websocket_headers) - header = "\r\n".join([ - f"GET {self._path} HTTP/1.1", - "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), - "\r\n", - ]).encode("utf8") + header = "\r\n".join( + [ + f"GET {self._path} HTTP/1.1", + "\r\n".join(f"{i}: {j}" for i, j in websocket_headers.items()), + "\r\n", + ] + ).encode("utf8") self._socket.send(header) @@ -3826,29 +4311,38 @@ def _do_handshake(self, extra_headers): if byte == b"\n": if len(self._readbuffer) > 2: # check upgrade - if b"connection" in str(self._readbuffer).lower().encode('utf-8'): - if b"upgrade" not in str(self._readbuffer).lower().encode('utf-8'): + if b"connection" in str(self._readbuffer).lower().encode("utf-8"): + if b"upgrade" not in str(self._readbuffer).lower().encode( + "utf-8" + ): raise WebsocketConnectionError( - "WebSocket handshake error, connection not upgraded") + "WebSocket handshake error, connection not upgraded" + ) else: has_upgrade = True # check key hash - if b"sec-websocket-accept" in str(self._readbuffer).lower().encode('utf-8'): + if b"sec-websocket-accept" in str(self._readbuffer).lower().encode( + "utf-8" + ): GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - server_hash = self._readbuffer.decode( - 'utf-8').split(": ", 1)[1] - server_hash = server_hash.strip().encode('utf-8') + server_hash_str = self._readbuffer.decode("utf-8").split( + ": ", 1 + )[1] + server_hash = server_hash_str.strip().encode("utf-8") - client_hash = sec_websocket_key.decode('utf-8') + GUID + client_hash_key = sec_websocket_key.decode("utf-8") + GUID # Use of SHA-1 is OK here; it's according to the Websocket spec. - client_hash = hashlib.sha1(client_hash.encode('utf-8')) # noqa: S324 - client_hash = base64.b64encode(client_hash.digest()) + client_hash_digest = hashlib.sha1( # noqa: S324 + client_hash_key.encode("utf-8") + ) + client_hash = base64.b64encode(client_hash_digest.digest()) if server_hash != client_hash: raise WebsocketConnectionError( - "WebSocket handshake error, invalid secret key") + "WebSocket handshake error, invalid secret key" + ) else: has_secret = True else: @@ -3868,8 +4362,9 @@ def _do_handshake(self, extra_headers): self._readbuffer = bytearray() self.connected = True - def _create_frame(self, opcode, data, do_masking=1): - + def _create_frame( + self, opcode: int, data: bytearray, do_masking: int = 1 + ) -> bytearray: header = bytearray() length = len(data) @@ -3900,12 +4395,10 @@ def _create_frame(self, opcode, data, do_masking=1): return header + data - def _buffered_read(self, length): - + def _buffered_read(self, length: int) -> bytearray: # try to recv and store needed bytes wanted_bytes = length - (len(self._readbuffer) - self._readbuffer_head) if wanted_bytes > 0: - data = self._socket.recv(wanted_bytes) if not data: @@ -3917,16 +4410,14 @@ def _buffered_read(self, length): raise BlockingIOError self._readbuffer_head += length - return self._readbuffer[self._readbuffer_head - length:self._readbuffer_head] - - def _recv_impl(self, length): + return self._readbuffer[self._readbuffer_head - length : self._readbuffer_head] + def _recv_impl(self, length: int) -> bytes: # try to decode websocket payload part from data try: - self._readbuffer_head = 0 - result = None + result = b"" chunk_startindex = self._payload_head chunk_endindex = self._payload_head + length @@ -3934,22 +4425,20 @@ def _recv_impl(self, length): header1 = self._buffered_read(1) header2 = self._buffered_read(1) - opcode = (header1[0] & 0x0f) + opcode = header1[0] & 0x0F maskbit = (header2[0] & 0x80) == 0x80 - lengthbits = (header2[0] & 0x7f) + lengthbits = header2[0] & 0x7F payload_length = lengthbits mask_key = None # read length - if lengthbits == 0x7e: - + if lengthbits == 0x7E: value = self._buffered_read(2) - payload_length, = struct.unpack("!H", value) - - elif lengthbits == 0x7f: + (payload_length,) = struct.unpack("!H", value) + elif lengthbits == 0x7F: value = self._buffered_read(8) - payload_length, = struct.unpack("!Q", value) + (payload_length,) = struct.unpack("!Q", value) # read mask if maskbit: @@ -3965,7 +4454,7 @@ def _recv_impl(self, length): payload = self._buffered_read(readindex) # unmask only the needed part - if maskbit: + if mask_key is not None: for index in range(chunk_startindex, readindex): payload[index] ^= mask_key[index % 4] @@ -3982,33 +4471,33 @@ def _recv_impl(self, length): # respond to non-binary opcodes, their arrival is not guaranteed because of non-blocking sockets if opcode == WebsocketWrapper.OPCODE_CONNCLOSE: frame = self._create_frame( - WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0) + WebsocketWrapper.OPCODE_CONNCLOSE, payload, 0 + ) self._socket.send(frame) if opcode == WebsocketWrapper.OPCODE_PING: - frame = self._create_frame( - WebsocketWrapper.OPCODE_PONG, payload, 0) + frame = self._create_frame(WebsocketWrapper.OPCODE_PONG, payload, 0) self._socket.send(frame) # This isn't *proper* handling of continuation frames, but given # that we only support binary frames, it is *probably* good enough. - if (opcode == WebsocketWrapper.OPCODE_BINARY or opcode == WebsocketWrapper.OPCODE_CONTINUATION) \ - and payload_length > 0: + if ( + opcode == WebsocketWrapper.OPCODE_BINARY + or opcode == WebsocketWrapper.OPCODE_CONTINUATION + ) and payload_length > 0: return result else: raise BlockingIOError except ConnectionError: self.connected = False - return b'' - - def _send_impl(self, data): + return b"" + def _send_impl(self, data: bytes) -> int: # if previous frame was sent successfully if len(self._sendbuffer) == 0: # create websocket frame - frame = self._create_frame( - WebsocketWrapper.OPCODE_BINARY, bytearray(data)) + frame = self._create_frame(WebsocketWrapper.OPCODE_BINARY, bytearray(data)) self._sendbuffer.extend(frame) self._requested_size = len(data) @@ -4024,32 +4513,32 @@ def _send_impl(self, data): # couldn't send whole data, request the same data again with 0 as sent length return 0 - def recv(self, length): + def recv(self, length: int) -> bytes: return self._recv_impl(length) - def read(self, length): + def read(self, length: int) -> bytes: return self._recv_impl(length) - def send(self, data): + def send(self, data: bytes) -> int: return self._send_impl(data) - def write(self, data): + def write(self, data: bytes) -> int: return self._send_impl(data) - def close(self): + def close(self) -> None: self._socket.close() - def fileno(self): + def fileno(self) -> int: return self._socket.fileno() - def pending(self): + def pending(self) -> int: # Fix for bug #131: a SSL socket may still have data available # for reading without select() being aware of it. if self._ssl: - return self._socket.pending() + return self._socket.pending() # type: ignore[union-attr] else: # normal socket rely only on select() return 0 - def setblocking(self, flag): + def setblocking(self, flag: bool) -> None: self._socket.setblocking(flag) diff --git a/src/paho/mqtt/matcher.py b/src/paho/mqtt/matcher.py index b73c13ac..37940a30 100644 --- a/src/paho/mqtt/matcher.py +++ b/src/paho/mqtt/matcher.py @@ -7,7 +7,7 @@ class MQTTMatcher: some topic name.""" class Node: - __slots__ = '_children', '_content' + __slots__ = "_children", "_content" def __init__(self): self._children = {} @@ -20,7 +20,7 @@ def __setitem__(self, key, value): """Add a topic filter :key to the prefix tree and associate it to :value""" node = self._root - for sym in key.split('/'): + for sym in key.split("/"): node = node._children.setdefault(sym, self.Node()) node._content = value @@ -28,7 +28,7 @@ def __getitem__(self, key): """Retrieve the value associated with some topic filter :key""" try: node = self._root - for sym in key.split('/'): + for sym in key.split("/"): node = node._children[sym] if node._content is None: raise KeyError(key) @@ -41,9 +41,9 @@ def __delitem__(self, key): lst = [] try: parent, node = None, self._root - for k in key.split('/'): - parent, node = node, node._children[k] - lst.append((parent, k, node)) + for k in key.split("/"): + parent, node = node, node._children[k] + lst.append((parent, k, node)) # TODO node._content = None except KeyError as ke: @@ -51,14 +51,15 @@ def __delitem__(self, key): else: # cleanup for parent, k, node in reversed(lst): if node._children or node._content is not None: - break + break del parent._children[k] def iter_match(self, topic): """Return an iterator on all values associated with filters that match the :topic""" - lst = topic.split('/') - normal = not topic.startswith('$') + lst = topic.split("/") + normal = not topic.startswith("$") + def rec(node, i=0): if i == len(lst): if node._content is not None: @@ -68,11 +69,12 @@ def rec(node, i=0): if part in node._children: for content in rec(node._children[part], i + 1): yield content - if '+' in node._children and (normal or i > 0): - for content in rec(node._children['+'], i + 1): + if "+" in node._children and (normal or i > 0): + for content in rec(node._children["+"], i + 1): yield content - if '#' in node._children and (normal or i > 0): - content = node._children['#']._content + if "#" in node._children and (normal or i > 0): + content = node._children["#"]._content if content is not None: yield content + return rec(self._root) diff --git a/src/paho/mqtt/packettypes.py b/src/paho/mqtt/packettypes.py index 2fd6a1b5..33143e3f 100644 --- a/src/paho/mqtt/packettypes.py +++ b/src/paho/mqtt/packettypes.py @@ -30,14 +30,42 @@ class PacketTypes: indexes = range(1, 16) # Packet types - CONNECT, CONNACK, PUBLISH, PUBACK, PUBREC, PUBREL, \ - PUBCOMP, SUBSCRIBE, SUBACK, UNSUBSCRIBE, UNSUBACK, \ - PINGREQ, PINGRESP, DISCONNECT, AUTH = indexes + ( + CONNECT, + CONNACK, + PUBLISH, + PUBACK, + PUBREC, + PUBREL, + PUBCOMP, + SUBSCRIBE, + SUBACK, + UNSUBSCRIBE, + UNSUBACK, + PINGREQ, + PINGRESP, + DISCONNECT, + AUTH, + ) = indexes # Dummy packet type for properties use - will delay only applies to will WILLMESSAGE = 99 - Names = [ "reserved", \ - "Connect", "Connack", "Publish", "Puback", "Pubrec", "Pubrel", \ - "Pubcomp", "Subscribe", "Suback", "Unsubscribe", "Unsuback", \ - "Pingreq", "Pingresp", "Disconnect", "Auth"] + Names = [ + "reserved", + "Connect", + "Connack", + "Publish", + "Puback", + "Pubrec", + "Pubrel", + "Pubcomp", + "Subscribe", + "Suback", + "Unsubscribe", + "Unsuback", + "Pingreq", + "Pingresp", + "Disconnect", + "Auth", + ] diff --git a/src/paho/mqtt/properties.py b/src/paho/mqtt/properties.py index e5e19103..d906c432 100644 --- a/src/paho/mqtt/properties.py +++ b/src/paho/mqtt/properties.py @@ -64,17 +64,17 @@ def readUTF(buffer, maxlen): maxlen -= 2 if length > maxlen: raise MalformedPacket("Length delimited string too long") - buf = buffer[2:2+length].decode("utf-8") + buf = buffer[2 : 2 + length].decode("utf-8") # look for chars which are invalid for MQTT - for c in buf: # look for D800-DFFF in the UTF string + for c in buf: # look for D800-DFFF in the UTF string ord_c = ord(c) if ord_c >= 0xD800 and ord_c <= 0xDFFF: raise MalformedPacket("[MQTT-1.5.4-1] D800-DFFF found in UTF-8 data") - if ord_c == 0x00: # look for null in the UTF string + if ord_c == 0x00: # look for null in the UTF string raise MalformedPacket("[MQTT-1.5.4-2] Null found in UTF-8 data") if ord_c == 0xFEFF: raise MalformedPacket("[MQTT-1.5.4-3] U+FEFF in UTF-8 data") - return buf, length+2 + return buf, length + 2 def writeBytes(buffer): @@ -83,7 +83,7 @@ def writeBytes(buffer): def readBytes(buffer): length = readInt16(buffer) - return buffer[2:2+length], length+2 + return buffer[2 : 2 + length], length + 2 class VariableByteIntegers: # Variable Byte Integer @@ -96,12 +96,12 @@ class VariableByteIntegers: # Variable Byte Integer @staticmethod def encode(x): """ - Convert an integer 0 <= x <= 268435455 into multi-byte format. - Returns the buffer convered from the integer. + Convert an integer 0 <= x <= 268435455 into multi-byte format. + Returns the buffer converted from the integer. """ if not 0 <= x <= 268435455: raise ValueError(f"Value {x!r} must be in range 0-268435455") - buffer = b'' + buffer = b"" while 1: digit = x % 128 x //= 128 @@ -115,10 +115,10 @@ def encode(x): @staticmethod def decode(buffer): """ - Get the value of a multi-byte integer from a buffer - Return the value, and the number of bytes used. + Get the value of a multi-byte integer from a buffer + Return the value, and the number of bytes used. - [MQTT-1.5.5-1] the encoded value MUST use the minimum number of bytes necessary to represent the value + [MQTT-1.5.5-1] the encoded value MUST use the minimum number of bytes necessary to represent the value """ multiplier = 1 value = 0 @@ -155,8 +155,15 @@ class Properties: def __init__(self, packetType): self.packetType = packetType - self.types = ["Byte", "Two Byte Integer", "Four Byte Integer", "Variable Byte Integer", - "Binary Data", "UTF-8 Encoded String", "UTF-8 String Pair"] + self.types = [ + "Byte", + "Two Byte Integer", + "Four Byte Integer", + "Variable Byte Integer", + "Binary Data", + "UTF-8 Encoded String", + "UTF-8 String Pair", + ] self.names = { "Payload Format Indicator": 1, @@ -185,54 +192,106 @@ def __init__(self, packetType): "Maximum Packet Size": 39, "Wildcard Subscription Available": 40, "Subscription Identifier Available": 41, - "Shared Subscription Available": 42 + "Shared Subscription Available": 42, } self.properties = { # id: type, packets # payload format indicator - 1: (self.types.index("Byte"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 2: (self.types.index("Four Byte Integer"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 3: (self.types.index("UTF-8 Encoded String"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 8: (self.types.index("UTF-8 Encoded String"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 9: (self.types.index("Binary Data"), [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE]), - 11: (self.types.index("Variable Byte Integer"), - [PacketTypes.PUBLISH, PacketTypes.SUBSCRIBE]), - 17: (self.types.index("Four Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.DISCONNECT]), + 1: ( + self.types.index("Byte"), + [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE], + ), + 2: ( + self.types.index("Four Byte Integer"), + [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE], + ), + 3: ( + self.types.index("UTF-8 Encoded String"), + [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE], + ), + 8: ( + self.types.index("UTF-8 Encoded String"), + [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE], + ), + 9: ( + self.types.index("Binary Data"), + [PacketTypes.PUBLISH, PacketTypes.WILLMESSAGE], + ), + 11: ( + self.types.index("Variable Byte Integer"), + [PacketTypes.PUBLISH, PacketTypes.SUBSCRIBE], + ), + 17: ( + self.types.index("Four Byte Integer"), + [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.DISCONNECT], + ), 18: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNACK]), 19: (self.types.index("Two Byte Integer"), [PacketTypes.CONNACK]), - 21: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), - 22: (self.types.index("Binary Data"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH]), - 23: (self.types.index("Byte"), - [PacketTypes.CONNECT]), + 21: ( + self.types.index("UTF-8 Encoded String"), + [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH], + ), + 22: ( + self.types.index("Binary Data"), + [PacketTypes.CONNECT, PacketTypes.CONNACK, PacketTypes.AUTH], + ), + 23: (self.types.index("Byte"), [PacketTypes.CONNECT]), 24: (self.types.index("Four Byte Integer"), [PacketTypes.WILLMESSAGE]), 25: (self.types.index("Byte"), [PacketTypes.CONNECT]), 26: (self.types.index("UTF-8 Encoded String"), [PacketTypes.CONNACK]), - 28: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]), - 31: (self.types.index("UTF-8 Encoded String"), - [PacketTypes.CONNACK, PacketTypes.PUBACK, PacketTypes.PUBREC, - PacketTypes.PUBREL, PacketTypes.PUBCOMP, PacketTypes.SUBACK, - PacketTypes.UNSUBACK, PacketTypes.DISCONNECT, PacketTypes.AUTH]), - 33: (self.types.index("Two Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), - 34: (self.types.index("Two Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), + 28: ( + self.types.index("UTF-8 Encoded String"), + [PacketTypes.CONNACK, PacketTypes.DISCONNECT], + ), + 31: ( + self.types.index("UTF-8 Encoded String"), + [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + PacketTypes.AUTH, + ], + ), + 33: ( + self.types.index("Two Byte Integer"), + [PacketTypes.CONNECT, PacketTypes.CONNACK], + ), + 34: ( + self.types.index("Two Byte Integer"), + [PacketTypes.CONNECT, PacketTypes.CONNACK], + ), 35: (self.types.index("Two Byte Integer"), [PacketTypes.PUBLISH]), 36: (self.types.index("Byte"), [PacketTypes.CONNACK]), 37: (self.types.index("Byte"), [PacketTypes.CONNACK]), - 38: (self.types.index("UTF-8 String Pair"), - [PacketTypes.CONNECT, PacketTypes.CONNACK, - PacketTypes.PUBLISH, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.PUBREL, PacketTypes.PUBCOMP, - PacketTypes.SUBSCRIBE, PacketTypes.SUBACK, - PacketTypes.UNSUBSCRIBE, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT, PacketTypes.AUTH, PacketTypes.WILLMESSAGE]), - 39: (self.types.index("Four Byte Integer"), - [PacketTypes.CONNECT, PacketTypes.CONNACK]), + 38: ( + self.types.index("UTF-8 String Pair"), + [ + PacketTypes.CONNECT, + PacketTypes.CONNACK, + PacketTypes.PUBLISH, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.SUBSCRIBE, + PacketTypes.SUBACK, + PacketTypes.UNSUBSCRIBE, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + PacketTypes.AUTH, + PacketTypes.WILLMESSAGE, + ], + ), + 39: ( + self.types.index("Four Byte Integer"), + [PacketTypes.CONNECT, PacketTypes.CONNACK], + ), 40: (self.types.index("Byte"), [PacketTypes.CONNACK]), 41: (self.types.index("Byte"), [PacketTypes.CONNACK]), 42: (self.types.index("Byte"), [PacketTypes.CONNACK]), @@ -245,44 +304,50 @@ def getIdentFromName(self, compressedName): # return the identifier corresponding to the property name result = -1 for name in self.names.keys(): - if compressedName == name.replace(' ', ''): + if compressedName == name.replace(" ", ""): result = self.names[name] break return result def __setattr__(self, name, value): - name = name.replace(' ', '') + name = name.replace(" ", "") privateVars = ["packetType", "types", "names", "properties"] if name in privateVars: object.__setattr__(self, name, value) else: # the name could have spaces in, or not. Remove spaces before assignment - if name not in [aname.replace(' ', '') for aname in self.names.keys()]: - raise MQTTException( - f"Property name must be one of {self.names.keys()}") + if name not in [aname.replace(" ", "") for aname in self.names.keys()]: + raise MQTTException(f"Property name must be one of {self.names.keys()}") # check that this attribute applies to the packet type if self.packetType not in self.properties[self.getIdentFromName(name)][1]: - raise MQTTException(f"Property {name} does not apply to packet type {PacketTypes.Names[self.packetType]}") + raise MQTTException( + f"Property {name} does not apply to packet type {PacketTypes.Names[self.packetType]}" + ) # Check for forbidden values if not isinstance(value, list): - if name in ["ReceiveMaximum", "TopicAlias"] \ - and (value < 1 or value > 65535): - - raise MQTTException(f"{name} property value must be in the range 1-65535") - elif name in ["TopicAliasMaximum"] \ - and (value < 0 or value > 65535): - - raise MQTTException(f"{name} property value must be in the range 0-65535") - elif name in ["MaximumPacketSize", "SubscriptionIdentifier"] \ - and (value < 1 or value > 268435455): - - raise MQTTException(f"{name} property value must be in the range 1-268435455") - elif name in ["RequestResponseInformation", "RequestProblemInformation", "PayloadFormatIndicator"] \ - and (value != 0 and value != 1): - + if name in ["ReceiveMaximum", "TopicAlias"] and ( + value < 1 or value > 65535 + ): + raise MQTTException( + f"{name} property value must be in the range 1-65535" + ) + elif name in ["TopicAliasMaximum"] and (value < 0 or value > 65535): + raise MQTTException( + f"{name} property value must be in the range 0-65535" + ) + elif name in ["MaximumPacketSize", "SubscriptionIdentifier"] and ( + value < 1 or value > 268435455 + ): raise MQTTException( - f"{name} property value must be 0 or 1") + f"{name} property value must be in the range 1-268435455" + ) + elif name in [ + "RequestResponseInformation", + "RequestProblemInformation", + "PayloadFormatIndicator", + ] and (value != 0 and value != 1): + raise MQTTException(f"{name} property value must be 0 or 1") if self.allowsMultiple(name): if not isinstance(value, list): @@ -295,7 +360,7 @@ def __str__(self): buffer = "[" first = True for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): if not first: buffer += ", " @@ -307,10 +372,10 @@ def __str__(self): def json(self): data = {} for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): val = getattr(self, compressedName) - if compressedName == 'CorrelationData' and isinstance(val, bytes): + if compressedName == "CorrelationData" and isinstance(val, bytes): data[compressedName] = val.hex() else: data[compressedName] = val @@ -319,7 +384,7 @@ def json(self): def isEmpty(self): rc = True for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): rc = False break @@ -327,7 +392,7 @@ def isEmpty(self): def clear(self): for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): delattr(self, compressedName) @@ -354,17 +419,17 @@ def pack(self): # serialize properties into buffer for sending over network buffer = b"" for name in self.names.keys(): - compressedName = name.replace(' ', '') + compressedName = name.replace(" ", "") if hasattr(self, compressedName): identifier = self.getIdentFromName(compressedName) attr_type = self.properties[identifier][0] if self.allowsMultiple(compressedName): for prop in getattr(self, compressedName): - buffer += self.writeProperty(identifier, - attr_type, prop) + buffer += self.writeProperty(identifier, attr_type, prop) else: - buffer += self.writeProperty(identifier, attr_type, - getattr(self, compressedName)) + buffer += self.writeProperty( + identifier, attr_type, getattr(self, compressedName) + ) return VariableByteIntegers.encode(len(buffer)) + buffer def readProperty(self, buffer, type, propslen): @@ -406,18 +471,21 @@ def unpack(self, buffer): propslenleft = propslen while propslenleft > 0: # properties length is 0 if there are none identifier, VBIlen2 = VariableByteIntegers.decode( - buffer) # property identifier + buffer + ) # property identifier buffer = buffer[VBIlen2:] # strip the bytes used by the VBI propslenleft -= VBIlen2 attr_type = self.properties[identifier][0] - value, valuelen = self.readProperty( - buffer, attr_type, propslenleft) + value, valuelen = self.readProperty(buffer, attr_type, propslenleft) buffer = buffer[valuelen:] # strip the bytes used by the value propslenleft -= valuelen propname = self.getNameFromIdent(identifier) - compressedName = propname.replace(' ', '') - if not self.allowsMultiple(compressedName) and hasattr(self, compressedName): + compressedName = propname.replace(" ", "") + if not self.allowsMultiple(compressedName) and hasattr( + self, compressedName + ): raise MQTTException( - f"Property '{property}' must not exist more than once") + f"Property '{property}' must not exist more than once" + ) setattr(self, propname, value) return self, propslen + VBIlen diff --git a/src/paho/mqtt/publish.py b/src/paho/mqtt/publish.py index 38138585..989a27b8 100644 --- a/src/paho/mqtt/publish.py +++ b/src/paho/mqtt/publish.py @@ -20,13 +20,41 @@ """ import collections +import typing from collections.abc import Iterable from .. import mqtt from . import client as paho +if typing.TYPE_CHECKING: + try: + from typing import NotRequired, Required, TypedDict # type: ignore + except ImportError: + from typing_extensions import NotRequired, Required, TypedDict -def _do_publish(client): + class AuthParamater(TypedDict, total=False): + username: "Required[str]" + password: "NotRequired[str]" + + class TLSParamater(TypedDict, total=False): + ca_certs: "Required[str]" + certfile: "NotRequired[str]" + keyfile: "NotRequired[str]" + tls_version: "NotRequired[int]" + ciphers: "NotRequired[str]" + insecure: "NotRequired[bool]" + + class MessageDict(TypedDict, total=False): + topic: "Required[str]" + payload: "NotRequired[paho.PayloadType]" + qos: "NotRequired[int]" + retain: "NotRequired[bool]" + + MessageTuple = typing.Tuple[str, paho.PayloadType, int, bool] + MessagesList = typing.List[typing.Union[MessageDict, MessageTuple]] + + +def _do_publish(client: paho.Client): """Internal function""" message = client._userdata.popleft() @@ -36,12 +64,12 @@ def _do_publish(client): elif isinstance(message, (tuple, list)): client.publish(*message) else: - raise TypeError('message must be a dict, tuple, or list') + raise TypeError("message must be a dict, tuple, or list") def _on_connect(client, userdata, flags, rc): """Internal callback""" - #pylint: disable=invalid-name, unused-argument + # pylint: disable=invalid-name, unused-argument if rc == 0: if len(userdata) > 0: @@ -49,13 +77,19 @@ def _on_connect(client, userdata, flags, rc): else: raise mqtt.MQTTException(paho.connack_string(rc)) -def _on_connect_v5(client, userdata, flags, rc, properties): + +def _on_connect_v5( + client: paho.Client, userdata: "MessagesList", flags, rc, properties +): """Internal v5 callback""" _on_connect(client, userdata, flags, rc) -def _on_publish(client, userdata, mid): + +def _on_publish( + client: paho.Client, userdata: typing.Deque["MessagesList"], mid: int +) -> None: """Internal callback""" - #pylint: disable=unused-argument + # pylint: disable=unused-argument if len(userdata) == 0: client.disconnect() @@ -63,9 +97,19 @@ def _on_publish(client, userdata, mid): _do_publish(client) -def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, - will=None, auth=None, tls=None, protocol=paho.MQTTv311, - transport="tcp", proxy_args=None): +def multiple( + msgs: "MessagesList", + hostname: str = "localhost", + port: int = 1883, + client_id: str = "", + keepalive: int = 60, + will: typing.Optional["MessageDict"] = None, + auth: typing.Optional["AuthParamater"] = None, + tls: typing.Optional["TLSParamater"] = None, + protocol: int = paho.MQTTv311, + transport: str = "tcp", + proxy_args: typing.Optional[typing.Any] = None, +) -> None: """Publish multiple messages to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -129,37 +173,42 @@ def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, """ if not isinstance(msgs, Iterable): - raise TypeError('msgs must be an iterable') + raise TypeError("msgs must be an iterable") - - client = paho.Client(client_id=client_id, userdata=collections.deque(msgs), - protocol=protocol, transport=transport) + client = paho.Client( + client_id=client_id, + userdata=collections.deque(msgs), + protocol=protocol, + transport=transport, + ) client.on_publish = _on_publish if protocol == mqtt.client.MQTTv5: - client.on_connect = _on_connect_v5 + client.on_connect = _on_connect_v5 # type: ignore else: - client.on_connect = _on_connect + client.on_connect = _on_connect # type: ignore if proxy_args is not None: client.proxy_set(**proxy_args) if auth: - username = auth.get('username') + username = auth.get("username") if username: - password = auth.get('password') + password = auth.get("password") client.username_pw_set(username, password) else: - raise KeyError("The 'username' key was not found, this is " - "required for auth") + raise KeyError( + "The 'username' key was not found, this is " "required for auth" + ) if will is not None: client.will_set(**will) if tls is not None: if isinstance(tls, dict): - insecure = tls.pop('insecure', False) - client.tls_set(**tls) + insecure = tls.pop("insecure", False) + # mypy don't get the tls no longer contains the key insecure + client.tls_set(**tls) # type: ignore[misc] if insecure: # Must be set *after* the `client.tls_set()` call since it sets # up the SSL context that `client.tls_insecure_set` alters. @@ -172,9 +221,22 @@ def multiple(msgs, hostname="localhost", port=1883, client_id="", keepalive=60, client.loop_forever() -def single(topic, payload=None, qos=0, retain=False, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", proxy_args=None): +def single( + topic: str, + payload: paho.PayloadType = None, + qos: int = 0, + retain: bool = False, + hostname: str = "localhost", + port: int = 1883, + client_id: str = "", + keepalive: int = 60, + will: typing.Optional["MessageDict"] = None, + auth: typing.Optional["AuthParamater"] = None, + tls: typing.Optional["TLSParamater"] = None, + protocol: int = paho.MQTTv311, + transport: str = "tcp", + proxy_args: typing.Optional[typing.Any] = None, +) -> None: """Publish a single message to a broker, then disconnect cleanly. This function creates an MQTT client, connects to a broker and publishes a @@ -230,7 +292,23 @@ def single(topic, payload=None, qos=0, retain=False, hostname="localhost", proxy_args: a dictionary that will be given to the client. """ - msg = {'topic':topic, 'payload':payload, 'qos':qos, 'retain':retain} - - multiple([msg], hostname, port, client_id, keepalive, will, auth, tls, - protocol, transport, proxy_args) + msg: "MessageDict" = { + "topic": topic, + "payload": payload, + "qos": qos, + "retain": retain, + } + + multiple( + [msg], + hostname, + port, + client_id, + keepalive, + will, + auth, + tls, + protocol, + transport, + proxy_args, + ) diff --git a/src/paho/mqtt/reasoncodes.py b/src/paho/mqtt/reasoncodes.py index 69a313f7..9fb8ab58 100644 --- a/src/paho/mqtt/reasoncodes.py +++ b/src/paho/mqtt/reasoncodes.py @@ -43,80 +43,151 @@ def __init__(self, packetType, aName="Success", identifier=-1): self.packetType = packetType self.names = { - 0: {"Success": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.PUBREL, PacketTypes.PUBCOMP, - PacketTypes.UNSUBACK, PacketTypes.AUTH], + 0: { + "Success": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.PUBREL, + PacketTypes.PUBCOMP, + PacketTypes.UNSUBACK, + PacketTypes.AUTH, + ], "Normal disconnection": [PacketTypes.DISCONNECT], - "Granted QoS 0": [PacketTypes.SUBACK]}, + "Granted QoS 0": [PacketTypes.SUBACK], + }, 1: {"Granted QoS 1": [PacketTypes.SUBACK]}, 2: {"Granted QoS 2": [PacketTypes.SUBACK]}, 4: {"Disconnect with will message": [PacketTypes.DISCONNECT]}, - 16: {"No matching subscribers": - [PacketTypes.PUBACK, PacketTypes.PUBREC]}, + 16: {"No matching subscribers": [PacketTypes.PUBACK, PacketTypes.PUBREC]}, 17: {"No subscription found": [PacketTypes.UNSUBACK]}, 24: {"Continue authentication": [PacketTypes.AUTH]}, 25: {"Re-authenticate": [PacketTypes.AUTH]}, - 128: {"Unspecified error": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT], }, - 129: {"Malformed packet": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 130: {"Protocol error": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 131: {"Implementation specific error": [PacketTypes.CONNACK, - PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.SUBACK, - PacketTypes.UNSUBACK, PacketTypes.DISCONNECT], }, + 128: { + "Unspecified error": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + ], + }, + 129: {"Malformed packet": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 130: {"Protocol error": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 131: { + "Implementation specific error": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + ], + }, 132: {"Unsupported protocol version": [PacketTypes.CONNACK]}, 133: {"Client identifier not valid": [PacketTypes.CONNACK]}, 134: {"Bad user name or password": [PacketTypes.CONNACK]}, - 135: {"Not authorized": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.UNSUBACK, - PacketTypes.DISCONNECT], }, + 135: { + "Not authorized": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + ], + }, 136: {"Server unavailable": [PacketTypes.CONNACK]}, 137: {"Server busy": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, 138: {"Banned": [PacketTypes.CONNACK]}, 139: {"Server shutting down": [PacketTypes.DISCONNECT]}, - 140: {"Bad authentication method": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 140: { + "Bad authentication method": [ + PacketTypes.CONNACK, + PacketTypes.DISCONNECT, + ] + }, 141: {"Keep alive timeout": [PacketTypes.DISCONNECT]}, 142: {"Session taken over": [PacketTypes.DISCONNECT]}, - 143: {"Topic filter invalid": - [PacketTypes.SUBACK, PacketTypes.UNSUBACK, PacketTypes.DISCONNECT]}, - 144: {"Topic name invalid": - [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, - 145: {"Packet identifier in use": - [PacketTypes.PUBACK, PacketTypes.PUBREC, - PacketTypes.SUBACK, PacketTypes.UNSUBACK]}, - 146: {"Packet identifier not found": - [PacketTypes.PUBREL, PacketTypes.PUBCOMP]}, + 143: { + "Topic filter invalid": [ + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + PacketTypes.DISCONNECT, + ] + }, + 144: { + "Topic name invalid": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.DISCONNECT, + ] + }, + 145: { + "Packet identifier in use": [ + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.UNSUBACK, + ] + }, + 146: { + "Packet identifier not found": [PacketTypes.PUBREL, PacketTypes.PUBCOMP] + }, 147: {"Receive maximum exceeded": [PacketTypes.DISCONNECT]}, 148: {"Topic alias invalid": [PacketTypes.DISCONNECT]}, 149: {"Packet too large": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, 150: {"Message rate too high": [PacketTypes.DISCONNECT]}, - 151: {"Quota exceeded": [PacketTypes.CONNACK, PacketTypes.PUBACK, - PacketTypes.PUBREC, PacketTypes.SUBACK, PacketTypes.DISCONNECT], }, + 151: { + "Quota exceeded": [ + PacketTypes.CONNACK, + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.SUBACK, + PacketTypes.DISCONNECT, + ], + }, 152: {"Administrative action": [PacketTypes.DISCONNECT]}, - 153: {"Payload format invalid": - [PacketTypes.PUBACK, PacketTypes.PUBREC, PacketTypes.DISCONNECT]}, - 154: {"Retain not supported": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 155: {"QoS not supported": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 156: {"Use another server": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 157: {"Server moved": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 158: {"Shared subscription not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, - 159: {"Connection rate exceeded": - [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, - 160: {"Maximum connect time": - [PacketTypes.DISCONNECT]}, - 161: {"Subscription identifiers not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, - 162: {"Wildcard subscription not supported": - [PacketTypes.SUBACK, PacketTypes.DISCONNECT]}, + 153: { + "Payload format invalid": [ + PacketTypes.PUBACK, + PacketTypes.PUBREC, + PacketTypes.DISCONNECT, + ] + }, + 154: { + "Retain not supported": [PacketTypes.CONNACK, PacketTypes.DISCONNECT] + }, + 155: {"QoS not supported": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 156: {"Use another server": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 157: {"Server moved": [PacketTypes.CONNACK, PacketTypes.DISCONNECT]}, + 158: { + "Shared subscription not supported": [ + PacketTypes.SUBACK, + PacketTypes.DISCONNECT, + ] + }, + 159: { + "Connection rate exceeded": [ + PacketTypes.CONNACK, + PacketTypes.DISCONNECT, + ] + }, + 160: {"Maximum connect time": [PacketTypes.DISCONNECT]}, + 161: { + "Subscription identifiers not supported": [ + PacketTypes.SUBACK, + PacketTypes.DISCONNECT, + ] + }, + 162: { + "Wildcard subscription not supported": [ + PacketTypes.SUBACK, + PacketTypes.DISCONNECT, + ] + }, } if identifier == -1: if packetType == PacketTypes.DISCONNECT and aName == "Success": @@ -165,8 +236,7 @@ def unpack(self, buffer): return 1 def getName(self): - """Returns the reason code name corresponding to the numeric value which is set. - """ + """Returns the reason code name corresponding to the numeric value which is set.""" return self.__getName__(self.packetType, self.value) def __eq__(self, other): diff --git a/src/paho/mqtt/subscribe.py b/src/paho/mqtt/subscribe.py index 955dfa13..e9370e17 100644 --- a/src/paho/mqtt/subscribe.py +++ b/src/paho/mqtt/subscribe.py @@ -28,11 +28,12 @@ def _on_connect_v5(client, userdata, flags, rc, properties): if rc != 0: raise mqtt.MQTTException(paho.connack_string(rc)) - if isinstance(userdata['topics'], list): - for topic in userdata['topics']: - client.subscribe(topic, userdata['qos']) + if isinstance(userdata["topics"], list): + for topic in userdata["topics"]: + client.subscribe(topic, userdata["qos"]) else: - client.subscribe(userdata['topics'], userdata['qos']) + client.subscribe(userdata["topics"], userdata["qos"]) + def _on_connect(client, userdata, flags, rc): """Internal v5 callback""" @@ -41,35 +42,48 @@ def _on_connect(client, userdata, flags, rc): def _on_message_callback(client, userdata, message): """Internal callback""" - userdata['callback'](client, userdata['userdata'], message) + userdata["callback"](client, userdata["userdata"], message) def _on_message_simple(client, userdata, message): """Internal callback""" - if userdata['msg_count'] == 0: + if userdata["msg_count"] == 0: return # Don't process stale retained messages if 'retained' was false - if message.retain and not userdata['retained']: + if message.retain and not userdata["retained"]: return - userdata['msg_count'] = userdata['msg_count'] - 1 + userdata["msg_count"] = userdata["msg_count"] - 1 - if userdata['messages'] is None and userdata['msg_count'] == 0: - userdata['messages'] = message + if userdata["messages"] is None and userdata["msg_count"] == 0: + userdata["messages"] = message client.disconnect() return - userdata['messages'].append(message) - if userdata['msg_count'] == 0: + userdata["messages"].append(message) + if userdata["msg_count"] == 0: client.disconnect() -def callback(callback, topics, qos=0, userdata=None, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", - clean_session=True, proxy_args=None): +def callback( + callback, + topics, + qos=0, + userdata=None, + hostname="localhost", + port=1883, + client_id="", + keepalive=60, + will=None, + auth=None, + tls=None, + protocol=paho.MQTTv311, + transport="tcp", + clean_session=True, + proxy_args=None, +): """Subscribe to a list of topics and process them in a callback function. This function creates an MQTT client, connects to a broker and subscribes @@ -134,17 +148,22 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", """ if qos < 0 or qos > 2: - raise ValueError('qos must be in the range 0-2') + raise ValueError("qos must be in the range 0-2") callback_userdata = { - 'callback':callback, - 'topics':topics, - 'qos':qos, - 'userdata':userdata} - - client = paho.Client(client_id=client_id, userdata=callback_userdata, - protocol=protocol, transport=transport, - clean_session=clean_session) + "callback": callback, + "topics": topics, + "qos": qos, + "userdata": userdata, + } + + client = paho.Client( + client_id=client_id, + userdata=callback_userdata, + protocol=protocol, + transport=transport, + clean_session=clean_session, + ) client.on_message = _on_message_callback if protocol == mqtt.client.MQTTv5: client.on_connect = _on_connect_v5 @@ -155,20 +174,21 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", client.proxy_set(**proxy_args) if auth: - username = auth.get('username') + username = auth.get("username") if username: - password = auth.get('password') + password = auth.get("password") client.username_pw_set(username, password) else: - raise KeyError("The 'username' key was not found, this is " - "required for auth") + raise KeyError( + "The 'username' key was not found, this is " "required for auth" + ) if will is not None: client.will_set(**will) if tls is not None: if isinstance(tls, dict): - insecure = tls.pop('insecure', False) + insecure = tls.pop("insecure", False) client.tls_set(**tls) if insecure: # Must be set *after* the `client.tls_set()` call since it sets @@ -182,10 +202,23 @@ def callback(callback, topics, qos=0, userdata=None, hostname="localhost", client.loop_forever() -def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", - port=1883, client_id="", keepalive=60, will=None, auth=None, - tls=None, protocol=paho.MQTTv311, transport="tcp", - clean_session=True, proxy_args=None): +def simple( + topics, + qos=0, + msg_count=1, + retained=True, + hostname="localhost", + port=1883, + client_id="", + keepalive=60, + will=None, + auth=None, + tls=None, + protocol=paho.MQTTv311, + transport="tcp", + clean_session=True, + proxy_args=None, +): """Subscribe to a list of topics and return msg_count messages. This function creates an MQTT client, connects to a broker and subscribes @@ -251,14 +284,14 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", when it disconnects. If False, the client is a persistent client and subscription information and queued messages will be retained when the client disconnects. - Defaults to True. If protocoll is MQTTv50, clean_session + Defaults to True. If protocol is MQTTv50, clean_session is ignored. proxy_args: a dictionary that will be given to the client. """ if msg_count < 1: - raise ValueError('msg_count must be > 0') + raise ValueError("msg_count must be > 0") # Set ourselves up to return a single message if msg_count == 1, or a list # if > 1. @@ -271,10 +304,24 @@ def simple(topics, qos=0, msg_count=1, retained=True, hostname="localhost", if protocol == paho.MQTTv5: clean_session = None - userdata = {'retained':retained, 'msg_count':msg_count, 'messages':messages} - - callback(_on_message_simple, topics, qos, userdata, hostname, port, - client_id, keepalive, will, auth, tls, protocol, transport, - clean_session, proxy_args) - - return userdata['messages'] + userdata = {"retained": retained, "msg_count": msg_count, "messages": messages} + + callback( + _on_message_simple, + topics, + qos, + userdata, + hostname, + port, + client_id, + keepalive, + will, + auth, + tls, + protocol, + transport, + clean_session, + proxy_args, + ) + + return userdata["messages"] diff --git a/src/paho/mqtt/subscribeoptions.py b/src/paho/mqtt/subscribeoptions.py index f56973ce..c4a76ca3 100644 --- a/src/paho/mqtt/subscribeoptions.py +++ b/src/paho/mqtt/subscribeoptions.py @@ -17,7 +17,6 @@ """ - class MQTTException(Exception): pass @@ -38,10 +37,15 @@ class SubscribeOptions: """ # retain handling options - RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB, RETAIN_DO_NOT_SEND = range( - 0, 3) - - def __init__(self, qos=0, noLocal=False, retainAsPublished=False, retainHandling=RETAIN_SEND_ON_SUBSCRIBE): + RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB, RETAIN_DO_NOT_SEND = range(0, 3) + + def __init__( + self, + qos: int = 0, + noLocal: bool = False, + retainAsPublished: bool = False, + retainHandling: int = RETAIN_SEND_ON_SUBSCRIBE, + ): """ qos: 0, 1 or 2. 0 is the default. noLocal: True or False. False is the default and corresponds to MQTT v3.1.1 behavior. @@ -49,42 +53,52 @@ def __init__(self, qos=0, noLocal=False, retainAsPublished=False, retainHandling retainHandling: RETAIN_SEND_ON_SUBSCRIBE, RETAIN_SEND_IF_NEW_SUB or RETAIN_DO_NOT_SEND RETAIN_SEND_ON_SUBSCRIBE is the default and corresponds to MQTT v3.1.1 behavior. """ - object.__setattr__(self, "names", - ["QoS", "noLocal", "retainAsPublished", "retainHandling"]) + object.__setattr__( + self, "names", ["QoS", "noLocal", "retainAsPublished", "retainHandling"] + ) self.QoS = qos # bits 0,1 self.noLocal = noLocal # bit 2 self.retainAsPublished = retainAsPublished # bit 3 self.retainHandling = retainHandling # bits 4 and 5: 0, 1 or 2 if self.retainHandling not in (0, 1, 2): - raise AssertionError(f"Retain handling should be 0, 1 or 2, not {self.retainHandling}") + raise AssertionError( + f"Retain handling should be 0, 1 or 2, not {self.retainHandling}" + ) if self.QoS not in (0, 1, 2): raise AssertionError(f"QoS should be 0, 1 or 2, not {self.QoS}") def __setattr__(self, name, value): if name not in self.names: - raise MQTTException( - f"{name} Attribute name must be one of {self.names}") + raise MQTTException(f"{name} Attribute name must be one of {self.names}") object.__setattr__(self, name, value) def pack(self): if self.retainHandling not in (0, 1, 2): - raise AssertionError(f"Retain handling should be 0, 1 or 2, not {self.retainHandling}") + raise AssertionError( + f"Retain handling should be 0, 1 or 2, not {self.retainHandling}" + ) if self.QoS not in (0, 1, 2): raise AssertionError(f"QoS should be 0, 1 or 2, not {self.QoS}") noLocal = 1 if self.noLocal else 0 retainAsPublished = 1 if self.retainAsPublished else 0 - data = [(self.retainHandling << 4) | (retainAsPublished << 3) | - (noLocal << 2) | self.QoS] + data = [ + (self.retainHandling << 4) + | (retainAsPublished << 3) + | (noLocal << 2) + | self.QoS + ] return bytes(data) def unpack(self, buffer): b0 = buffer[0] - self.retainHandling = ((b0 >> 4) & 0x03) + self.retainHandling = (b0 >> 4) & 0x03 self.retainAsPublished = True if ((b0 >> 3) & 0x01) == 1 else False self.noLocal = True if ((b0 >> 2) & 0x01) == 1 else False - self.QoS = (b0 & 0x03) + self.QoS = b0 & 0x03 if self.retainHandling not in (0, 1, 2): - raise AssertionError(f"Retain handling should be 0, 1 or 2, not {self.retainHandling}") + raise AssertionError( + f"Retain handling should be 0, 1 or 2, not {self.retainHandling}" + ) if self.QoS not in (0, 1, 2): raise AssertionError(f"QoS should be 0, 1 or 2, not {self.QoS}") return 1 @@ -93,9 +107,17 @@ def __repr__(self): return str(self) def __str__(self): - return "{QoS="+str(self.QoS)+", noLocal="+str(self.noLocal) +\ - ", retainAsPublished="+str(self.retainAsPublished) +\ - ", retainHandling="+str(self.retainHandling)+"}" + return ( + "{QoS=" + + str(self.QoS) + + ", noLocal=" + + str(self.noLocal) + + ", retainAsPublished=" + + str(self.retainAsPublished) + + ", retainHandling=" + + str(self.retainHandling) + + "}" + ) def json(self): data = { diff --git a/tests/debug_helpers.py b/tests/debug_helpers.py index 54b96368..38f21ac9 100644 --- a/tests/debug_helpers.py +++ b/tests/debug_helpers.py @@ -8,7 +8,7 @@ def dump_packet(prefix: str, data: bytes) -> None: data = to_string(data) print(prefix, ": ", data, sep="") except struct.error: - data = binascii.b2a_hex(data).decode('utf8') + data = binascii.b2a_hex(data).decode("utf8") print(prefix, " (not decoded): 0x", data, sep="") @@ -23,7 +23,7 @@ def remaining_length(packet: bytes) -> Tuple[bytes, int]: rl += (byte & 127) * mult mult *= 128 if byte & 128 == 0: - packet = packet[i + 1:] + packet = packet[i + 1 :] break return (packet, rl) @@ -36,7 +36,7 @@ def to_hex_string(packet: bytes) -> str: s = "" while len(packet) > 0: packet0 = struct.unpack("!B", packet[0]) - s = s+hex(packet0[0]) + " " + s = s + hex(packet0[0]) + " " packet = packet[1:] return s @@ -46,7 +46,7 @@ def to_string(packet: bytes) -> str: if not packet: return "" - packet0 = struct.unpack("!B%ds" % (len(packet)-1), bytes(packet)) + packet0 = struct.unpack("!B%ds" % (len(packet) - 1), bytes(packet)) packet0 = packet0[0] cmd = packet0 & 0xF0 if cmd == 0x00: @@ -55,29 +55,31 @@ def to_string(packet: bytes) -> str: elif cmd == 0x10: # CONNECT (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 'sBBH' + str(len(packet) - slen - 4) + 's' - (protocol, proto_ver, flags, keepalive, packet) = struct.unpack(pack_format, packet) - kind = ("clean-session" if flags & 2 else "durable") + pack_format = "!" + str(slen) + "sBBH" + str(len(packet) - slen - 4) + "s" + (protocol, proto_ver, flags, keepalive, packet) = struct.unpack( + pack_format, packet + ) + kind = "clean-session" if flags & 2 else "durable" s = f"CONNECT, proto={protocol}{proto_ver}, keepalive={keepalive}, {kind}" - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (client_id, packet) = struct.unpack(pack_format, packet) s = s + ", id=" + str(client_id) if flags & 4: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (will_topic, packet) = struct.unpack(pack_format, packet) s = s + ", will-topic=" + str(will_topic) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (will_message, packet) = struct.unpack(pack_format, packet) s = s + ", will-message=" + will_message @@ -85,16 +87,16 @@ def to_string(packet: bytes) -> str: s = s + ", will-retain=" + str((flags & 32) >> 5) if flags & 128: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (username, packet) = struct.unpack(pack_format, packet) s = s + ", username=" + str(username) if flags & 64: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (slen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(slen) + 's' + str(len(packet) - slen) + 's' + pack_format = "!" + str(slen) + "s" + str(len(packet) - slen) + "s" (password, packet) = struct.unpack(pack_format, packet) s = s + ", password=" + str(password) @@ -105,11 +107,20 @@ def to_string(packet: bytes) -> str: elif cmd == 0x20: # CONNACK if len(packet) == 4: - (cmd, rl, resv, rc) = struct.unpack('!BBBB', packet) - return "CONNACK, rl="+str(rl)+", res="+str(resv)+", rc="+str(rc) + (cmd, rl, resv, rc) = struct.unpack("!BBBB", packet) + return "CONNACK, rl=" + str(rl) + ", res=" + str(resv) + ", rc=" + str(rc) elif len(packet) == 5: - (cmd, rl, flags, reason_code, proplen) = struct.unpack('!BBBBB', packet) - return "CONNACK, rl="+str(rl)+", flags="+str(flags)+", rc="+str(reason_code)+", proplen="+str(proplen) + (cmd, rl, flags, reason_code, proplen) = struct.unpack("!BBBBB", packet) + return ( + "CONNACK, rl=" + + str(rl) + + ", flags=" + + str(flags) + + ", rc=" + + str(reason_code) + + ", proplen=" + + str(proplen) + ) else: return "CONNACK, (not decoded)" @@ -117,15 +128,26 @@ def to_string(packet: bytes) -> str: # PUBLISH dup = (packet0 & 0x08) >> 3 qos = (packet0 & 0x06) >> 1 - retain = (packet0 & 0x01) + retain = packet0 & 0x01 (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 's' + str(len(packet) - tlen) + 's' + pack_format = "!" + str(tlen) + "s" + str(len(packet) - tlen) + "s" (topic, packet) = struct.unpack(pack_format, packet) - s = "PUBLISH, rl=" + str(rl) + ", topic=" + str(topic) + ", qos=" + str(qos) + ", retain=" + str(retain) + ", dup=" + str(dup) + s = ( + "PUBLISH, rl=" + + str(rl) + + ", topic=" + + str(topic) + + ", qos=" + + str(qos) + + ", retain=" + + str(retain) + + ", dup=" + + str(dup) + ) if qos > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = s + ", mid=" + str(mid) @@ -134,90 +156,111 @@ def to_string(packet: bytes) -> str: elif cmd == 0x40: # PUBACK if len(packet) == 5: - (cmd, rl, mid, reason_code) = struct.unpack('!BBHB', packet) - return "PUBACK, rl="+str(rl)+", mid="+str(mid)+", reason_code="+str(reason_code) + (cmd, rl, mid, reason_code) = struct.unpack("!BBHB", packet) + return ( + "PUBACK, rl=" + + str(rl) + + ", mid=" + + str(mid) + + ", reason_code=" + + str(reason_code) + ) else: - (cmd, rl, mid) = struct.unpack('!BBH', packet) - return "PUBACK, rl="+str(rl)+", mid="+str(mid) + (cmd, rl, mid) = struct.unpack("!BBH", packet) + return "PUBACK, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x50: # PUBREC if len(packet) == 5: - (cmd, rl, mid, reason_code) = struct.unpack('!BBHB', packet) - return "PUBREC, rl="+str(rl)+", mid="+str(mid)+", reason_code="+str(reason_code) + (cmd, rl, mid, reason_code) = struct.unpack("!BBHB", packet) + return ( + "PUBREC, rl=" + + str(rl) + + ", mid=" + + str(mid) + + ", reason_code=" + + str(reason_code) + ) else: - (cmd, rl, mid) = struct.unpack('!BBH', packet) - return "PUBREC, rl="+str(rl)+", mid="+str(mid) + (cmd, rl, mid) = struct.unpack("!BBH", packet) + return "PUBREC, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x60: # PUBREL dup = (packet0 & 0x08) >> 3 - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "PUBREL, rl=" + str(rl) + ", mid=" + str(mid) + ", dup=" + str(dup) elif cmd == 0x70: # PUBCOMP - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "PUBCOMP, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0x80: # SUBSCRIBE (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = "SUBSCRIBE, rl=" + str(rl) + ", mid=" + str(mid) topic_index = 0 while len(packet) > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 'sB' + str(len(packet) - tlen - 1) + 's' + pack_format = "!" + str(tlen) + "sB" + str(len(packet) - tlen - 1) + "s" (topic, qos, packet) = struct.unpack(pack_format, packet) s = s + ", topic" + str(topic_index) + "=" + str(topic) + "," + str(qos) return s elif cmd == 0x90: # SUBACK (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) pack_format = "!" + "B" * len(packet) granted_qos = struct.unpack(pack_format, packet) - s = "SUBACK, rl=" + str(rl) + ", mid=" + str(mid) + ", granted_qos=" + str(granted_qos[0]) + s = ( + "SUBACK, rl=" + + str(rl) + + ", mid=" + + str(mid) + + ", granted_qos=" + + str(granted_qos[0]) + ) for i in range(1, len(granted_qos) - 1): s = s + ", " + str(granted_qos[i]) return s elif cmd == 0xA0: # UNSUBSCRIBE (packet, rl) = remaining_length(packet) - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (mid, packet) = struct.unpack(pack_format, packet) s = "UNSUBSCRIBE, rl=" + str(rl) + ", mid=" + str(mid) topic_index = 0 while len(packet) > 0: - pack_format = "!H" + str(len(packet) - 2) + 's' + pack_format = "!H" + str(len(packet) - 2) + "s" (tlen, packet) = struct.unpack(pack_format, packet) - pack_format = "!" + str(tlen) + 's' + str(len(packet) - tlen) + 's' + pack_format = "!" + str(tlen) + "s" + str(len(packet) - tlen) + "s" (topic, packet) = struct.unpack(pack_format, packet) s = s + ", topic" + str(topic_index) + "=" + str(topic) return s elif cmd == 0xB0: # UNSUBACK - (cmd, rl, mid) = struct.unpack('!BBH', packet) + (cmd, rl, mid) = struct.unpack("!BBH", packet) return "UNSUBACK, rl=" + str(rl) + ", mid=" + str(mid) elif cmd == 0xC0: # PINGREQ - (cmd, rl) = struct.unpack('!BB', packet) + (cmd, rl) = struct.unpack("!BB", packet) return "PINGREQ, rl=" + str(rl) elif cmd == 0xD0: # PINGRESP - (cmd, rl) = struct.unpack('!BB', packet) + (cmd, rl) = struct.unpack("!BB", packet) return "PINGRESP, rl=" + str(rl) elif cmd == 0xE0: # DISCONNECT if len(packet) == 3: - (cmd, rl, reason_code) = struct.unpack('!BBB', packet) - return "DISCONNECT, rl="+str(rl)+", reason_code="+str(reason_code) + (cmd, rl, reason_code) = struct.unpack("!BBB", packet) + return "DISCONNECT, rl=" + str(rl) + ", reason_code=" + str(reason_code) else: - (cmd, rl) = struct.unpack('!BB', packet) - return "DISCONNECT, rl="+str(rl) + (cmd, rl) = struct.unpack("!BB", packet) + return "DISCONNECT, rl=" + str(rl) elif cmd == 0xF0: # AUTH - (cmd, rl) = struct.unpack('!BB', packet) - return "AUTH, rl="+str(rl) + (cmd, rl) = struct.unpack("!BB", packet) + return "AUTH, rl=" + str(rl) raise ValueError(f"Unknown packet type {cmd}") diff --git a/tests/lib/clients/01-asyncio.py b/tests/lib/clients/01-asyncio.py index eeab4433..cfd99a6b 100644 --- a/tests/lib/clients/01-asyncio.py +++ b/tests/lib/clients/01-asyncio.py @@ -5,7 +5,7 @@ from tests.paho_test import get_test_server_port -client_id = 'asyncio-test' +client_id = "asyncio-test" class AsyncioHelper: @@ -80,10 +80,11 @@ def on_disconnect(client, userdata, rc): _aioh = AsyncioHelper(loop, client) - client.connect('localhost', get_test_server_port(), 60) + client.connect("localhost", get_test_server_port(), 60) client.socket().setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 2048) await disconnected -if __name__ == '__main__': + +if __name__ == "__main__": asyncio.run(main()) diff --git a/tests/lib/clients/01-unpwd-unicode-set.py b/tests/lib/clients/01-unpwd-unicode-set.py index b0d12f9d..1f49e1fe 100644 --- a/tests/lib/clients/01-unpwd-unicode-set.py +++ b/tests/lib/clients/01-unpwd-unicode-set.py @@ -1,4 +1,3 @@ - import paho.mqtt.client as mqtt from tests.paho_test import get_test_server_port, loop_until_keyboard_interrupt diff --git a/tests/lib/clients/03-publish-b2c-qos1.py b/tests/lib/clients/03-publish-b2c-qos1.py index 9efc928d..880ad2ff 100644 --- a/tests/lib/clients/03-publish-b2c-qos1.py +++ b/tests/lib/clients/03-publish-b2c-qos1.py @@ -10,7 +10,7 @@ def on_message(mqttc, obj, msg): assert msg.topic == "pub/qos1/receive", f"Invalid topic: ({msg.topic})" assert msg.payload == expected_payload, f"Invalid payload: ({msg.payload})" assert msg.qos == 1, f"Invalid qos: ({msg.qos})" - assert msg.retain is not False, f"Invalid retain: ({msg.retain})" + assert not msg.retain, f"Invalid retain: ({msg.retain})" def on_connect(mqttc, obj, flags, rc): diff --git a/tests/lib/clients/03-publish-fill-inflight.py b/tests/lib/clients/03-publish-fill-inflight.py index f954dd13..ea427a9f 100644 --- a/tests/lib/clients/03-publish-fill-inflight.py +++ b/tests/lib/clients/03-publish-fill-inflight.py @@ -22,10 +22,12 @@ def on_connect(mqttc, obj, flags, rc): for i in range(12): mqttc.publish("topic", expected_payload(i), qos=1) + def on_disconnect(mqttc, rc, properties): logging.info("disconnected") mqttc.reconnect() + logging.basicConfig(level=logging.DEBUG) logging.info(str(mqtt)) mqttc = mqtt.Client("publish-qos1-test") diff --git a/tests/lib/conftest.py b/tests/lib/conftest.py index bb2f4278..f8681bf1 100644 --- a/tests/lib/conftest.py +++ b/tests/lib/conftest.py @@ -53,17 +53,24 @@ def starter(name: str, expected_returncode: int = 0) -> None: PAHO_SSL_PATH=str(ssl_path), PYTHONPATH=f"{tests_path}{os.pathsep}{os.environ.get('PYTHONPATH', '')}", ) - assert 'PAHO_SERVER_PORT' in env, "PAHO_SERVER_PORT must be set in the environment when starting a client" + assert ( + "PAHO_SERVER_PORT" in env + ), "PAHO_SERVER_PORT must be set in the environment when starting a client" # TODO: it would be nice to run this under `coverage` too! - proc = subprocess.Popen([ # noqa: S603 - sys.executable, - str(client_path), - ], env=env) + proc = subprocess.Popen( + [ # noqa: S603 + sys.executable, + str(client_path), + ], + env=env, + ) def fin(): stop_process(proc) if proc.returncode != expected_returncode: - raise RuntimeError(f"Client {name} exited with code {proc.returncode}, expected {expected_returncode}") + raise RuntimeError( + f"Client {name} exited with code {proc.returncode}, expected {expected_returncode}" + ) request.addfinalizer(fin) return proc diff --git a/tests/lib/test_01_no_clean_session.py b/tests/lib/test_01_no_clean_session.py index 7f00e544..a436feab 100644 --- a/tests/lib/test_01_no_clean_session.py +++ b/tests/lib/test_01_no_clean_session.py @@ -6,7 +6,9 @@ import tests.paho_test as paho_test -connect_packet = paho_test.gen_connect("01-no-clean-session", clean_session=False, keepalive=60) +connect_packet = paho_test.gen_connect( + "01-no-clean-session", clean_session=False, keepalive=60 +) def test_01_no_clean_session(server_socket, start_client): diff --git a/tests/lib/test_01_reconnect_on_failure.py b/tests/lib/test_01_reconnect_on_failure.py index 8deb6539..832f5ee2 100644 --- a/tests/lib/test_01_reconnect_on_failure.py +++ b/tests/lib/test_01_reconnect_on_failure.py @@ -7,8 +7,7 @@ connack_packet_ok = paho_test.gen_connack(rc=0) connack_packet_failure = paho_test.gen_connack(rc=1) # CONNACK_REFUSED_PROTOCOL_VERSION -publish_packet = paho_test.gen_publish( - "reconnect/test", qos=0, payload="message") +publish_packet = paho_test.gen_publish("reconnect/test", qos=0, payload="message") @pytest.mark.parametrize("ok_code", [False, True]) diff --git a/tests/lib/test_01_unpwd_empty_password_set.py b/tests/lib/test_01_unpwd_empty_password_set.py index 225d72bc..8fd4351a 100644 --- a/tests/lib/test_01_unpwd_empty_password_set.py +++ b/tests/lib/test_01_unpwd_empty_password_set.py @@ -7,7 +7,8 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="uname", password="") + "01-unpwd-set", keepalive=60, username="uname", password="" +) def test_01_unpwd_empty_password_set(server_socket, start_client): diff --git a/tests/lib/test_01_unpwd_empty_set.py b/tests/lib/test_01_unpwd_empty_set.py index 8c51c22a..7a40145f 100644 --- a/tests/lib/test_01_unpwd_empty_set.py +++ b/tests/lib/test_01_unpwd_empty_set.py @@ -7,7 +7,8 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="", password='') + "01-unpwd-set", keepalive=60, username="", password="" +) def test_01_unpwd_empty_set(server_socket, start_client): diff --git a/tests/lib/test_01_unpwd_set.py b/tests/lib/test_01_unpwd_set.py index 38834ab9..bf0b5a65 100644 --- a/tests/lib/test_01_unpwd_set.py +++ b/tests/lib/test_01_unpwd_set.py @@ -7,7 +7,8 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "01-unpwd-set", keepalive=60, username="uname", password=";'[08gn=#") + "01-unpwd-set", keepalive=60, username="uname", password=";'[08gn=#" +) def test_01_unpwd_set(server_socket, start_client): diff --git a/tests/lib/test_01_will_set.py b/tests/lib/test_01_will_set.py index 55b55a25..34467b99 100644 --- a/tests/lib/test_01_will_set.py +++ b/tests/lib/test_01_will_set.py @@ -9,8 +9,13 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "01-will-set", keepalive=60, will_topic="topic/on/unexpected/disconnect", - will_qos=1, will_retain=True, will_payload="will message") + "01-will-set", + keepalive=60, + will_topic="topic/on/unexpected/disconnect", + will_qos=1, + will_retain=True, + will_payload="will message", +) def test_01_will_set(server_socket, start_client): diff --git a/tests/lib/test_01_will_unpwd_set.py b/tests/lib/test_01_will_unpwd_set.py index 95c0517f..a38b1b67 100644 --- a/tests/lib/test_01_will_unpwd_set.py +++ b/tests/lib/test_01_will_unpwd_set.py @@ -10,8 +10,12 @@ connect_packet = paho_test.gen_connect( "01-will-unpwd-set", - keepalive=60, username="oibvvwqw", password="#'^2hg9a&nm38*us", - will_topic="will-topic", will_qos=2, will_payload="will message", + keepalive=60, + username="oibvvwqw", + password="#'^2hg9a&nm38*us", + will_topic="will-topic", + will_qos=2, + will_payload="will message", ) diff --git a/tests/lib/test_03_publish_b2c_qos1.py b/tests/lib/test_03_publish_b2c_qos1.py index 4666900d..d6f3efc7 100644 --- a/tests/lib/test_03_publish_b2c_qos1.py +++ b/tests/lib/test_03_publish_b2c_qos1.py @@ -18,7 +18,8 @@ mid = 123 publish_packet = paho_test.gen_publish( - "pub/qos1/receive", qos=1, mid=mid, payload="message") + "pub/qos1/receive", qos=1, mid=mid, payload="message" +) puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_03_publish_c2b_qos1_disconnect.py b/tests/lib/test_03_publish_c2b_qos1_disconnect.py index 10daca6a..deb21b03 100644 --- a/tests/lib/test_03_publish_c2b_qos1_disconnect.py +++ b/tests/lib/test_03_publish_c2b_qos1_disconnect.py @@ -4,7 +4,9 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-qos1-test", keepalive=60, clean_session=False, + "publish-qos1-test", + keepalive=60, + clean_session=False, ) connack_packet = paho_test.gen_connack(rc=0) @@ -12,9 +14,11 @@ mid = 1 publish_packet = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message") + "pub/qos1/test", qos=1, mid=mid, payload="message" +) publish_packet_dup = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message", dup=True) + "pub/qos1/test", qos=1, mid=mid, payload="message", dup=True +) puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_03_publish_c2b_qos2_disconnect.py b/tests/lib/test_03_publish_c2b_qos2_disconnect.py index 15b1d496..2d416f11 100644 --- a/tests/lib/test_03_publish_c2b_qos2_disconnect.py +++ b/tests/lib/test_03_publish_c2b_qos2_disconnect.py @@ -4,7 +4,9 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-qos2-test", keepalive=60, clean_session=False, + "publish-qos2-test", + keepalive=60, + clean_session=False, ) connack_packet = paho_test.gen_connack(rc=0) @@ -12,9 +14,11 @@ mid = 1 publish_packet = paho_test.gen_publish( - "pub/qos2/test", qos=2, mid=mid, payload="message") + "pub/qos2/test", qos=2, mid=mid, payload="message" +) publish_dup_packet = paho_test.gen_publish( - "pub/qos2/test", qos=2, mid=mid, payload="message", dup=True) + "pub/qos2/test", qos=2, mid=mid, payload="message", dup=True +) pubrec_packet = paho_test.gen_pubrec(mid) pubrel_packet = paho_test.gen_pubrel(mid) pubcomp_packet = paho_test.gen_pubcomp(mid) diff --git a/tests/lib/test_03_publish_fill_inflight.py b/tests/lib/test_03_publish_fill_inflight.py index 697c896a..1a65e856 100644 --- a/tests/lib/test_03_publish_fill_inflight.py +++ b/tests/lib/test_03_publish_fill_inflight.py @@ -22,6 +22,7 @@ def expected_payload(i: int) -> bytes: return f"message{i}" + connect_packet = paho_test.gen_connect("publish-qos1-test", keepalive=60) connack_packet = paho_test.gen_connack(rc=0) @@ -29,7 +30,10 @@ def expected_payload(i: int) -> bytes: first_connection_publishs = [ paho_test.gen_publish( - "topic", qos=1, mid=i+1, payload=expected_payload(i), + "topic", + qos=1, + mid=i + 1, + payload=expected_payload(i), ) for i in range(10) ] @@ -39,14 +43,15 @@ def expected_payload(i: int) -> bytes: # Currently on reconnection client will do two wrong thing: # * it sent more than max_inflight packet # * it re-send message both with mid = old_mid + 12 AND with mid = old_mid & dup=1 - "topic", qos=1, mid=i+13, payload=expected_payload(i), + "topic", + qos=1, + mid=i + 13, + payload=expected_payload(i), ) for i in range(12) ] -second_connection_pubacks = [ - paho_test.gen_puback(i+13) - for i in range(12) -] +second_connection_pubacks = [paho_test.gen_puback(i + 13) for i in range(12)] + @pytest.mark.xfail def test_03_publish_fill_inflight(server_socket, start_client): @@ -87,4 +92,3 @@ def test_03_publish_fill_inflight(server_socket, start_client): paho_test.expect_packet(conn, "publish", second_connection_publishs[11]) paho_test.expect_no_packet(conn, 0.5) - diff --git a/tests/lib/test_03_publish_helper_qos0.py b/tests/lib/test_03_publish_helper_qos0.py index b1c57d90..dd4ba400 100644 --- a/tests/lib/test_03_publish_helper_qos0.py +++ b/tests/lib/test_03_publish_helper_qos0.py @@ -14,13 +14,12 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-helper-qos0-test", keepalive=60, + "publish-helper-qos0-test", + keepalive=60, ) connack_packet = paho_test.gen_connack(rc=0) -publish_packet = paho_test.gen_publish( - "pub/qos0/test", qos=0, payload="message" -) +publish_packet = paho_test.gen_publish("pub/qos0/test", qos=0, payload="message") disconnect_packet = paho_test.gen_disconnect() diff --git a/tests/lib/test_03_publish_helper_qos1_disconnect.py b/tests/lib/test_03_publish_helper_qos1_disconnect.py index f73462c3..ca203ef9 100644 --- a/tests/lib/test_03_publish_helper_qos1_disconnect.py +++ b/tests/lib/test_03_publish_helper_qos1_disconnect.py @@ -6,7 +6,8 @@ import tests.paho_test as paho_test connect_packet = paho_test.gen_connect( - "publish-helper-qos1-disconnect-test", keepalive=60, + "publish-helper-qos1-disconnect-test", + keepalive=60, ) connack_packet = paho_test.gen_connack(rc=0) @@ -15,7 +16,10 @@ "pub/qos1/test", qos=1, mid=mid, payload="message" ) publish_packet_dup = paho_test.gen_publish( - "pub/qos1/test", qos=1, mid=mid, payload="message", + "pub/qos1/test", + qos=1, + mid=mid, + payload="message", dup=True, ) puback_packet = paho_test.gen_puback(mid) diff --git a/tests/lib/test_04_retain_qos0.py b/tests/lib/test_04_retain_qos0.py index dee6b099..d0e43ac6 100644 --- a/tests/lib/test_04_retain_qos0.py +++ b/tests/lib/test_04_retain_qos0.py @@ -8,7 +8,8 @@ connack_packet = paho_test.gen_connack(rc=0) publish_packet = paho_test.gen_publish( - "retain/qos0/test", qos=0, payload="retained message", retain=True) + "retain/qos0/test", qos=0, payload="retained message", retain=True +) def test_04_retain_qos0(server_socket, start_client): diff --git a/tests/lib/test_08_ssl_bad_cacert.py b/tests/lib/test_08_ssl_bad_cacert.py index 14d48cb5..fc2918ce 100644 --- a/tests/lib/test_08_ssl_bad_cacert.py +++ b/tests/lib/test_08_ssl_bad_cacert.py @@ -8,4 +8,4 @@ def test_08_ssl_bad_cacert(): with pytest.raises(IOError): mqttc = mqtt.Client("08-ssl-bad-cacert") - mqttc.tls_set("this/file/doesnt/exist") + mqttc.tls_set("this/file/does not/exist") diff --git a/tests/mqtt5_props.py b/tests/mqtt5_props.py index f9be0d66..ed46bc4d 100644 --- a/tests/mqtt5_props.py +++ b/tests/mqtt5_props.py @@ -28,33 +28,42 @@ PROP_SUBSCRIPTION_ID_AVAILABLE = 41 PROP_SHARED_SUB_AVAILABLE = 42 + def gen_byte_prop(identifier, byte): - prop = struct.pack('BB', identifier, byte) + prop = struct.pack("BB", identifier, byte) return prop + def gen_uint16_prop(identifier, word): - prop = struct.pack('!BH', identifier, word) + prop = struct.pack("!BH", identifier, word) return prop + def gen_uint32_prop(identifier, word): - prop = struct.pack('!BI', identifier, word) + prop = struct.pack("!BI", identifier, word) return prop + def gen_string_prop(identifier, s): s = s.encode("utf-8") - prop = struct.pack(f'!BH{len(s)}s', identifier, len(s), s) + prop = struct.pack(f"!BH{len(s)}s", identifier, len(s), s) return prop + def gen_string_pair_prop(identifier, s1, s2): s1 = s1.encode("utf-8") s2 = s2.encode("utf-8") - prop = struct.pack(f'!BH{len(s1)}sH{len(s2)}s', identifier, len(s1), s1, len(s2), s2) + prop = struct.pack( + f"!BH{len(s1)}sH{len(s2)}s", identifier, len(s1), s1, len(s2), s2 + ) return prop + def gen_varint_prop(identifier, val): v = pack_varint(val) return struct.pack(f"!B{len(v)}s", identifier, v) + def pack_varint(varint): s = b"" while True: @@ -68,9 +77,9 @@ def pack_varint(varint): if varint == 0: return s + def prop_finalise(props): if props is None: return pack_varint(0) else: return pack_varint(len(props)) + props - diff --git a/tests/paho_test.py b/tests/paho_test.py index 4c77fb30..9cd4539b 100644 --- a/tests/paho_test.py +++ b/tests/paho_test.py @@ -20,7 +20,7 @@ def bind_to_any_free_port(sock) -> int: Bind a socket to an available port on localhost, and return the port number. """ - sock.bind(('localhost', 0)) + sock.bind(("localhost", 0)) return sock.getsockname()[1] @@ -58,7 +58,7 @@ def expect_packet(sock, name, expected): packet_recvd = b"" try: while len(packet_recvd) < rlen: - data = sock.recv(rlen-len(packet_recvd)) + data = sock.recv(rlen - len(packet_recvd)) if len(data) == 0: break packet_recvd += data @@ -70,8 +70,7 @@ def expect_packet(sock, name, expected): def expect_no_packet(sock, delay=1): - """ expect that nothing is received within given delay - """ + """expect that nothing is received within given delay""" sock.settimeout(delay) try: previous_timeout = sock.gettimeout() @@ -97,17 +96,32 @@ def packet_matches(name, recvd, expected): return True -def gen_connect(client_id, clean_session=True, keepalive=60, username=None, password=None, will_topic=None, will_qos=0, will_retain=False, will_payload=b"", proto_ver=4, connect_reserved=False, properties=b"", will_properties=b"", session_expiry=-1): - if (proto_ver&0x7F) == 3 or proto_ver == 0: +def gen_connect( + client_id, + clean_session=True, + keepalive=60, + username=None, + password=None, + will_topic=None, + will_qos=0, + will_retain=False, + will_payload=b"", + proto_ver=4, + connect_reserved=False, + properties=b"", + will_properties=b"", + session_expiry=-1, +): + if (proto_ver & 0x7F) == 3 or proto_ver == 0: remaining_length = 12 - elif (proto_ver&0x7F) == 4 or proto_ver == 5: + elif (proto_ver & 0x7F) == 4 or proto_ver == 5: remaining_length = 10 else: raise ValueError if client_id is not None: client_id = client_id.encode("utf-8") - remaining_length = remaining_length + 2+len(client_id) + remaining_length = remaining_length + 2 + len(client_id) else: remaining_length = remaining_length + 2 @@ -121,17 +135,23 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass if proto_ver == 5: if properties == b"": - properties += mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_RECEIVE_MAXIMUM, 20) + properties += mqtt5_props.gen_uint16_prop( + mqtt5_props.PROP_RECEIVE_MAXIMUM, 20 + ) if session_expiry != -1: - properties += mqtt5_props.gen_uint32_prop(mqtt5_props.PROP_SESSION_EXPIRY_INTERVAL, session_expiry) + properties += mqtt5_props.gen_uint32_prop( + mqtt5_props.PROP_SESSION_EXPIRY_INTERVAL, session_expiry + ) properties = mqtt5_props.prop_finalise(properties) remaining_length += len(properties) if will_topic is not None: - will_topic = will_topic.encode('utf-8') - remaining_length = remaining_length + 2 + len(will_topic) + 2 + len(will_payload) + will_topic = will_topic.encode("utf-8") + remaining_length = ( + remaining_length + 2 + len(will_topic) + 2 + len(will_payload) + ) connect_flags = connect_flags | 0x04 | ((will_qos & 0x03) << 3) if will_retain: connect_flags = connect_flags | 32 @@ -140,64 +160,96 @@ def gen_connect(client_id, clean_session=True, keepalive=60, username=None, pass remaining_length += len(will_properties) if username is not None: - username = username.encode('utf-8') + username = username.encode("utf-8") remaining_length = remaining_length + 2 + len(username) connect_flags = connect_flags | 0x80 if password is not None: - password = password.encode('utf-8') + password = password.encode("utf-8") connect_flags = connect_flags | 0x40 remaining_length = remaining_length + 2 + len(password) rl = pack_remaining_length(remaining_length) packet = struct.pack("!B" + str(len(rl)) + "s", 0x10, rl) - if (proto_ver&0x7F) == 3 or proto_ver == 0: - packet = packet + struct.pack("!H6sBBH", len(b"MQIsdp"), b"MQIsdp", proto_ver, connect_flags, keepalive) - elif (proto_ver&0x7F) == 4 or proto_ver == 5: - packet = packet + struct.pack("!H4sBBH", len(b"MQTT"), b"MQTT", proto_ver, connect_flags, keepalive) + if (proto_ver & 0x7F) == 3 or proto_ver == 0: + packet = packet + struct.pack( + "!H6sBBH", len(b"MQIsdp"), b"MQIsdp", proto_ver, connect_flags, keepalive + ) + elif (proto_ver & 0x7F) == 4 or proto_ver == 5: + packet = packet + struct.pack( + "!H4sBBH", len(b"MQTT"), b"MQTT", proto_ver, connect_flags, keepalive + ) if proto_ver == 5: packet += properties if client_id is not None: - packet = packet + struct.pack("!H" + str(len(client_id)) + "s", len(client_id), bytes(client_id)) + packet = packet + struct.pack( + "!H" + str(len(client_id)) + "s", len(client_id), bytes(client_id) + ) else: packet = packet + struct.pack("!H", 0) if will_topic is not None: packet += will_properties - packet = packet + struct.pack("!H" + str(len(will_topic)) + "s", len(will_topic), will_topic) + packet = packet + struct.pack( + "!H" + str(len(will_topic)) + "s", len(will_topic), will_topic + ) if len(will_payload) > 0: - packet = packet + struct.pack("!H" + str(len(will_payload)) + "s", len(will_payload), will_payload.encode('utf8')) + packet = packet + struct.pack( + "!H" + str(len(will_payload)) + "s", + len(will_payload), + will_payload.encode("utf8"), + ) else: packet = packet + struct.pack("!H", 0) if username is not None: - packet = packet + struct.pack("!H" + str(len(username)) + "s", len(username), username) + packet = packet + struct.pack( + "!H" + str(len(username)) + "s", len(username), username + ) if password is not None: - packet = packet + struct.pack("!H" + str(len(password)) + "s", len(password), password) + packet = packet + struct.pack( + "!H" + str(len(password)) + "s", len(password), password + ) return packet + def gen_connack(flags=0, rc=0, proto_ver=4, properties=b"", property_helper=True): if proto_ver == 5: if property_helper: if properties is not None: - properties = mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_TOPIC_ALIAS_MAXIMUM, 10) \ - + properties + mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_RECEIVE_MAXIMUM, 20) + properties = ( + mqtt5_props.gen_uint16_prop( + mqtt5_props.PROP_TOPIC_ALIAS_MAXIMUM, 10 + ) + + properties + + mqtt5_props.gen_uint16_prop(mqtt5_props.PROP_RECEIVE_MAXIMUM, 20) + ) else: properties = b"" properties = mqtt5_props.prop_finalise(properties) - packet = struct.pack('!BBBB', 32, 2+len(properties), flags, rc) + properties + packet = struct.pack("!BBBB", 32, 2 + len(properties), flags, rc) + properties else: - packet = struct.pack('!BBBB', 32, 2, flags, rc) + packet = struct.pack("!BBBB", 32, 2, flags, rc) return packet -def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ver=4, properties=b""): + +def gen_publish( + topic, + qos, + payload=None, + retain=False, + dup=False, + mid=0, + proto_ver=4, + properties=b"", +): if isinstance(topic, str): topic = topic.encode("utf-8") - rl = 2+len(topic) - pack_format = "H"+str(len(topic))+"s" + rl = 2 + len(topic) + pack_format = "H" + str(len(topic)) + "s" if qos > 0: rl = rl + 2 pack_format = pack_format + "H" @@ -206,7 +258,7 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ properties = mqtt5_props.prop_finalise(properties) rl += len(properties) # This will break if len(properties) > 127 - pack_format = pack_format + "%ds"%(len(properties)) + pack_format = pack_format + "%ds" % (len(properties)) if payload is not None: payload = payload.encode("utf-8") @@ -225,14 +277,47 @@ def gen_publish(topic, qos, payload=None, retain=False, dup=False, mid=0, proto_ if proto_ver == 5: if qos > 0: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, mid, properties, payload) + return struct.pack( + "!B" + str(len(rlpacked)) + "s" + pack_format, + cmd, + rlpacked, + len(topic), + topic, + mid, + properties, + payload, + ) else: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, properties, payload) + return struct.pack( + "!B" + str(len(rlpacked)) + "s" + pack_format, + cmd, + rlpacked, + len(topic), + topic, + properties, + payload, + ) else: if qos > 0: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, mid, payload) + return struct.pack( + "!B" + str(len(rlpacked)) + "s" + pack_format, + cmd, + rlpacked, + len(topic), + topic, + mid, + payload, + ) else: - return struct.pack("!B" + str(len(rlpacked))+"s" + pack_format, cmd, rlpacked, len(topic), topic, payload) + return struct.pack( + "!B" + str(len(rlpacked)) + "s" + pack_format, + cmd, + rlpacked, + len(topic), + topic, + payload, + ) + def _gen_command_with_mid(cmd, mid, proto_ver=4, reason_code=-1, properties=None): if proto_ver == 5 and (reason_code != -1 or properties is not None): @@ -240,29 +325,35 @@ def _gen_command_with_mid(cmd, mid, proto_ver=4, reason_code=-1, properties=None reason_code = 0 if properties is None: - return struct.pack('!BBHB', cmd, 3, mid, reason_code) + return struct.pack("!BBHB", cmd, 3, mid, reason_code) elif properties == "": - return struct.pack('!BBHBB', cmd, 4, mid, reason_code, 0) + return struct.pack("!BBHBB", cmd, 4, mid, reason_code, 0) else: properties = mqtt5_props.prop_finalise(properties) - pack_format = "!BBHB"+str(len(properties))+"s" - return struct.pack(pack_format, cmd, 2+1+len(properties), mid, reason_code, properties) + pack_format = "!BBHB" + str(len(properties)) + "s" + return struct.pack( + pack_format, cmd, 2 + 1 + len(properties), mid, reason_code, properties + ) else: - return struct.pack('!BBH', cmd, 2, mid) + return struct.pack("!BBH", cmd, 2, mid) + def gen_puback(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(64, mid, proto_ver, reason_code, properties) + def gen_pubrec(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(80, mid, proto_ver, reason_code, properties) + def gen_pubrel(mid, dup=False, proto_ver=4, reason_code=-1, properties=None): if dup: - cmd = 96+8+2 + cmd = 96 + 8 + 2 else: - cmd = 96+2 + cmd = 96 + 2 return _gen_command_with_mid(cmd, mid, proto_ver, reason_code, properties) + def gen_pubcomp(mid, proto_ver=4, reason_code=-1, properties=None): return _gen_command_with_mid(112, mid, proto_ver, reason_code, properties) @@ -272,54 +363,64 @@ def gen_subscribe(mid, topic, qos, cmd=130, proto_ver=4, properties=b""): packet = struct.pack("!B", cmd) if proto_ver == 5: if properties == b"": - packet += pack_remaining_length(2+1+2+len(topic)+1) - pack_format = "!HBH"+str(len(topic))+"sB" + packet += pack_remaining_length(2 + 1 + 2 + len(topic) + 1) + pack_format = "!HBH" + str(len(topic)) + "sB" return packet + struct.pack(pack_format, mid, 0, len(topic), topic, qos) else: properties = mqtt5_props.prop_finalise(properties) - packet += pack_remaining_length(2+1+2+len(topic)+len(properties)) - pack_format = "!H"+str(len(properties))+"s"+"H"+str(len(topic))+"sB" - return packet + struct.pack(pack_format, mid, properties, len(topic), topic, qos) + packet += pack_remaining_length(2 + 1 + 2 + len(topic) + len(properties)) + pack_format = ( + "!H" + str(len(properties)) + "s" + "H" + str(len(topic)) + "sB" + ) + return packet + struct.pack( + pack_format, mid, properties, len(topic), topic, qos + ) else: - packet += pack_remaining_length(2+2+len(topic)+1) - pack_format = "!HH"+str(len(topic))+"sB" + packet += pack_remaining_length(2 + 2 + len(topic) + 1) + pack_format = "!HH" + str(len(topic)) + "sB" return packet + struct.pack(pack_format, mid, len(topic), topic, qos) def gen_suback(mid, qos, proto_ver=4): if proto_ver == 5: - return struct.pack('!BBHBB', 144, 2+1+1, mid, 0, qos) + return struct.pack("!BBHBB", 144, 2 + 1 + 1, mid, 0, qos) else: - return struct.pack('!BBHB', 144, 2+1, mid, qos) + return struct.pack("!BBHB", 144, 2 + 1, mid, qos) + def gen_unsubscribe(mid, topic, cmd=162, proto_ver=4, properties=b""): topic = topic.encode("utf-8") if proto_ver == 5: if properties == b"": - pack_format = "!BBHBH"+str(len(topic))+"s" - return struct.pack(pack_format, cmd, 2+2+len(topic)+1, mid, 0, len(topic), topic) + pack_format = "!BBHBH" + str(len(topic)) + "s" + return struct.pack( + pack_format, cmd, 2 + 2 + len(topic) + 1, mid, 0, len(topic), topic + ) else: properties = mqtt5_props.prop_finalise(properties) packet = struct.pack("!B", cmd) - l = 2+2+len(topic)+1+len(properties) # noqa: E741 + l = 2 + 2 + len(topic) + 1 + len(properties) # noqa: E741 packet += pack_remaining_length(l) - pack_format = "!HB"+str(len(properties))+"sH"+str(len(topic))+"s" - packet += struct.pack(pack_format, mid, len(properties), properties, len(topic), topic) + pack_format = "!HB" + str(len(properties)) + "sH" + str(len(topic)) + "s" + packet += struct.pack( + pack_format, mid, len(properties), properties, len(topic), topic + ) return packet else: - pack_format = "!BBHH"+str(len(topic))+"s" - return struct.pack(pack_format, cmd, 2+2+len(topic), mid, len(topic), topic) + pack_format = "!BBHH" + str(len(topic)) + "s" + return struct.pack(pack_format, cmd, 2 + 2 + len(topic), mid, len(topic), topic) + def gen_unsubscribe_multiple(mid, topics, proto_ver=4): packet = b"" remaining_length = 0 for t in topics: t = t.encode("utf-8") - remaining_length += 2+len(t) - packet += struct.pack("!H"+str(len(t))+"s", len(t), t) + remaining_length += 2 + len(t) + packet += struct.pack("!H" + str(len(t)) + "s", len(t), t) if proto_ver == 5: - remaining_length += 2+1 + remaining_length += 2 + 1 return struct.pack("!BBHB", 162, remaining_length, mid, 0) + packet else: @@ -327,44 +428,51 @@ def gen_unsubscribe_multiple(mid, topics, proto_ver=4): return struct.pack("!BBH", 162, remaining_length, mid) + packet + def gen_unsuback(mid, reason_code=0, proto_ver=4): if proto_ver == 5: if isinstance(reason_code, list): reason_code_count = len(reason_code) - p = struct.pack('!BBHB', 176, 3+reason_code_count, mid, 0) + p = struct.pack("!BBHB", 176, 3 + reason_code_count, mid, 0) for r in reason_code: - p += struct.pack('B', r) + p += struct.pack("B", r) return p else: - return struct.pack('!BBHBB', 176, 4, mid, 0, reason_code) + return struct.pack("!BBHBB", 176, 4, mid, 0, reason_code) else: - return struct.pack('!BBH', 176, 2, mid) + return struct.pack("!BBH", 176, 2, mid) + def gen_pingreq(): - return struct.pack('!BB', 192, 0) + return struct.pack("!BB", 192, 0) + def gen_pingresp(): - return struct.pack('!BB', 208, 0) + return struct.pack("!BB", 208, 0) def _gen_short(cmd, reason_code=-1, proto_ver=5, properties=None): if proto_ver == 5 and (reason_code != -1 or properties is not None): if reason_code == -1: - reason_code = 0 + reason_code = 0 if properties is None: - return struct.pack('!BBB', cmd, 1, reason_code) + return struct.pack("!BBB", cmd, 1, reason_code) elif properties == "": - return struct.pack('!BBBB', cmd, 2, reason_code, 0) + return struct.pack("!BBBB", cmd, 2, reason_code, 0) else: properties = mqtt5_props.prop_finalise(properties) - return struct.pack("!BBB", cmd, 1+len(properties), reason_code) + properties + return ( + struct.pack("!BBB", cmd, 1 + len(properties), reason_code) + properties + ) else: - return struct.pack('!BB', cmd, 0) + return struct.pack("!BB", cmd, 0) + def gen_disconnect(reason_code=-1, proto_ver=4, properties=None): return _gen_short(0xE0, reason_code, proto_ver, properties) + def gen_auth(reason_code=-1, properties=None): return _gen_short(0xF0, reason_code, 5, properties) @@ -421,4 +529,4 @@ def get_test_server_port() -> int: """ Get the port number for the test server. """ - return int(os.environ['PAHO_SERVER_PORT']) + return int(os.environ["PAHO_SERVER_PORT"]) diff --git a/tests/test_client.py b/tests/test_client.py index adfe6014..dd07b181 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,18 +10,20 @@ from tests.testsupport.broker import fake_broker # noqa: F401 -@pytest.mark.parametrize("proto_ver", [ - (client.MQTTv31), - (client.MQTTv311), -]) +@pytest.mark.parametrize( + "proto_ver", + [ + (client.MQTTv31), + (client.MQTTv311), + ], +) class Test_connect: """ Tests on connect/disconnect behaviour of the client """ def test_01_con_discon_success(self, proto_ver, fake_broker): - mqttc = client.Client( - "01-con-discon-success", protocol=proto_ver) + mqttc = client.Client("01-con-discon-success", protocol=proto_ver) def on_connect(mqttc, obj, flags, rc): assert rc == 0 @@ -36,8 +38,8 @@ def on_connect(mqttc, obj, flags, rc): fake_broker.start() connect_packet = paho_test.gen_connect( - "01-con-discon-success", keepalive=60, - proto_ver=proto_ver) + "01-con-discon-success", keepalive=60, proto_ver=proto_ver + ) packet_in = fake_broker.receive_packet(1000) assert packet_in # Check connection was not closed assert packet_in == connect_packet @@ -59,8 +61,7 @@ def on_connect(mqttc, obj, flags, rc): assert not packet_in # Check connection is closed def test_01_con_failure_rc(self, proto_ver, fake_broker): - mqttc = client.Client( - "01-con-failure-rc", protocol=proto_ver) + mqttc = client.Client("01-con-failure-rc", protocol=proto_ver) def on_connect(mqttc, obj, flags, rc): assert rc == 1 @@ -74,8 +75,8 @@ def on_connect(mqttc, obj, flags, rc): fake_broker.start() connect_packet = paho_test.gen_connect( - "01-con-failure-rc", keepalive=60, - proto_ver=proto_ver) + "01-con-failure-rc", keepalive=60, proto_ver=proto_ver + ) packet_in = fake_broker.receive_packet(1000) assert packet_in # Check connection was not closed assert packet_in == connect_packet @@ -93,7 +94,6 @@ def on_connect(mqttc, obj, flags, rc): class TestPublishBroker2Client: - def test_invalid_utf8_topic(self, fake_broker): mqttc = client.Client("client-id") @@ -140,7 +140,7 @@ def test_valid_utf8_topic_recv(self, fake_broker): mqttc = client.Client("client-id") # It should be non-ascii multi-bytes character - topic = unicodedata.lookup('SNOWMAN') + topic = unicodedata.lookup("SNOWMAN") def on_message(client, userdata, msg): assert msg.topic == topic @@ -164,9 +164,7 @@ def on_message(client, userdata, msg): assert count # Check connection was not closed assert count == len(connack_packet) - publish_packet = paho_test.gen_publish( - topic.encode('utf-8'), qos=0 - ) + publish_packet = paho_test.gen_publish(topic.encode("utf-8"), qos=0) count = fake_broker.send_packet(publish_packet) assert count # Check connection was not closed assert count == len(publish_packet) @@ -186,7 +184,7 @@ def test_valid_utf8_topic_publish(self, fake_broker): mqttc = client.Client("client-id") # It should be non-ascii multi-bytes character - topic = unicodedata.lookup('SNOWMAN') + topic = unicodedata.lookup("SNOWMAN") mqttc.connect_async("localhost", fake_broker.port) mqttc.loop_start() @@ -208,9 +206,7 @@ def test_valid_utf8_topic_publish(self, fake_broker): # Small sleep needed to avoid connection reset. time.sleep(0.3) - publish_packet = paho_test.gen_publish( - topic.encode('utf-8'), qos=0 - ) + publish_packet = paho_test.gen_publish(topic.encode("utf-8"), qos=0) packet_in = fake_broker.receive_packet(len(publish_packet)) assert packet_in # Check connection was not closed assert packet_in == publish_packet @@ -231,27 +227,27 @@ def test_valid_utf8_topic_publish(self, fake_broker): def test_message_callback(self, fake_broker): mqttc = client.Client("client-id") userdata = { - 'on_message': 0, - 'callback1': 0, - 'callback2': 0, + "on_message": 0, + "callback1": 0, + "callback2": 0, } mqttc.user_data_set(userdata) def on_message(client, userdata, msg): - assert msg.topic == 'topic/value' - userdata['on_message'] += 1 + assert msg.topic == "topic/value" + userdata["on_message"] += 1 def callback1(client, userdata, msg): - assert msg.topic == 'topic/callback/1' - userdata['callback1'] += 1 + assert msg.topic == "topic/callback/1" + userdata["callback1"] += 1 def callback2(client, userdata, msg): - assert msg.topic in ('topic/callback/3', 'topic/callback/1') - userdata['callback2'] += 1 + assert msg.topic in ("topic/callback/3", "topic/callback/1") + userdata["callback2"] += 1 mqttc.on_message = on_message - mqttc.message_callback_add('topic/callback/1', callback1) - mqttc.message_callback_add('topic/callback/+', callback2) + mqttc.message_callback_add("topic/callback/1", callback1) + mqttc.message_callback_add("topic/callback/+", callback2) mqttc.connect_async("localhost", fake_broker.port) mqttc.loop_start() @@ -284,7 +280,6 @@ def callback2(client, userdata, msg): assert count # Check connection was not closed assert count == len(publish_packet) - puback_packet = paho_test.gen_puback(mid=1) packet_in = fake_broker.receive_packet(len(puback_packet)) assert packet_in # Check connection was not closed @@ -313,6 +308,47 @@ def callback2(client, userdata, msg): packet_in = fake_broker.receive_packet(1) assert not packet_in # Check connection is closed - assert userdata['on_message'] == 1 - assert userdata['callback1'] == 1 - assert userdata['callback2'] == 2 + assert userdata["on_message"] == 1 + assert userdata["callback1"] == 1 + assert userdata["callback2"] == 2 + + +class Test_compatibility: + """ + Few test for backward compatibility + """ + + def test_change_error_code_to_enum(self): + """Make sure code don't break after MQTTErrorCode enum introduction""" + rc_ok = client.MQTTErrorCode.MQTT_ERR_SUCCESS + rc_again = client.MQTTErrorCode.MQTT_ERR_AGAIN + rc_err = client.MQTTErrorCode.MQTT_ERR_NOMEM + + # Access using old name still works + assert rc_ok == client.MQTT_ERR_SUCCESS + + # User might compare to 0 to check for success + assert rc_ok == 0 + assert not rc_err == 0 + assert not rc_again == 0 + assert not rc_ok != 0 + assert rc_err != 0 + assert rc_again != 0 + + # User might compare to specific code + assert rc_again == -1 + assert rc_err == 1 + + # User might just use "if rc:" + assert not rc_ok + assert rc_err + assert rc_again + + # User might do inequality with 0 (like "if rc > 0") + assert not (rc_ok > 0) + assert rc_err > 0 + assert rc_again < 0 + + # This might probably not be done: User might use rc as number in + # operation + assert rc_ok + 1 == 1 diff --git a/tests/test_matcher.py b/tests/test_matcher.py index e2dc02a4..d8145229 100644 --- a/tests/test_matcher.py +++ b/tests/test_matcher.py @@ -7,30 +7,35 @@ class Test_client_function: Tests on topic_matches_sub function in the client module """ - @pytest.mark.parametrize("sub,topic", [ - ("foo/bar", "foo/bar"), - ("foo/+", "foo/bar"), - ("foo/+/baz", "foo/bar/baz"), - ("foo/+/#", "foo/bar/baz"), - ("A/B/+/#", "A/B/B/C"), - ("#", "foo/bar/baz"), - ("#", "/foo/bar"), - ("/#", "/foo/bar"), - ("$SYS/bar", "$SYS/bar"), - ]) + @pytest.mark.parametrize( + "sub,topic", + [ + ("foo/bar", "foo/bar"), + ("foo/+", "foo/bar"), + ("foo/+/baz", "foo/bar/baz"), + ("foo/+/#", "foo/bar/baz"), + ("A/B/+/#", "A/B/B/C"), + ("#", "foo/bar/baz"), + ("#", "/foo/bar"), + ("/#", "/foo/bar"), + ("$SYS/bar", "$SYS/bar"), + ], + ) def test_matching(self, sub, topic): assert client.topic_matches_sub(sub, topic) - - @pytest.mark.parametrize("sub,topic", [ - ("test/6/#", "test/3"), - ("foo/bar", "foo"), - ("foo/+", "foo/bar/baz"), - ("foo/+/baz", "foo/bar/bar"), - ("foo/+/#", "fo2/bar/baz"), - ("/#", "foo/bar"), - ("#", "$SYS/bar"), - ("$BOB/bar", "$SYS/bar"), - ]) + @pytest.mark.parametrize( + "sub,topic", + [ + ("test/6/#", "test/3"), + ("foo/bar", "foo"), + ("foo/+", "foo/bar/baz"), + ("foo/+/baz", "foo/bar/bar"), + ("foo/+/#", "fo2/bar/baz"), + ("/#", "foo/bar"), + ("#", "$SYS/bar"), + ("$BOB/bar", "$SYS/bar"), + ], + ) def test_not_matching(self, sub, topic): assert not client.topic_matches_sub(sub, topic) diff --git a/tests/test_mqttv5.py b/tests/test_mqttv5.py index 9fd61f2d..e510046d 100644 --- a/tests/test_mqttv5.py +++ b/tests/test_mqttv5.py @@ -31,7 +31,6 @@ class Callbacks: - def __init__(self): self.messages = [] self.publisheds = [] @@ -42,16 +41,27 @@ def __init__(self): self.conn_failures = [] def __str__(self): - return str(self.messages) + str(self.messagedicts) + str(self.publisheds) + \ - str(self.subscribeds) + \ - str(self.unsubscribeds) + str(self.disconnects) + return ( + str(self.messages) + + str(self.messagedicts) + + str(self.publisheds) + + str(self.subscribeds) + + str(self.unsubscribeds) + + str(self.disconnects) + ) def clear(self): self.__init__() def on_connect(self, client, userdata, flags, reasonCode, properties): - self.connecteds.append({"userdata": userdata, "flags": flags, - "reasonCode": reasonCode, "properties": properties}) + self.connecteds.append( + { + "userdata": userdata, + "flags": flags, + "reasonCode": reasonCode, + "properties": properties, + } + ) def on_connect_fail(self, client, userdata): self.conn_failures.append({"userdata": userdata}) @@ -71,8 +81,7 @@ def wait_connected(self): return self.wait(self.connecteds) def on_disconnect(self, client, userdata, reasonCode, properties=None): - self.disconnecteds.append( - {"reasonCode": reasonCode, "properties": properties}) + self.disconnecteds.append({"reasonCode": reasonCode, "properties": properties}) def wait_disconnected(self): return self.wait(self.disconnecteds) @@ -87,15 +96,27 @@ def wait_published(self): return self.wait(self.publisheds) def on_subscribe(self, client, userdata, mid, reasonCodes, properties): - self.subscribeds.append({"mid": mid, "userdata": userdata, - "properties": properties, "reasonCodes": reasonCodes}) + self.subscribeds.append( + { + "mid": mid, + "userdata": userdata, + "properties": properties, + "reasonCodes": reasonCodes, + } + ) def wait_subscribed(self): return self.wait(self.subscribeds) def unsubscribed(self, client, userdata, mid, properties, reasonCodes): - self.unsubscribeds.append({"mid": mid, "userdata": userdata, - "properties": properties, "reasonCodes": reasonCodes}) + self.unsubscribeds.append( + { + "mid": mid, + "userdata": userdata, + "properties": properties, + "reasonCodes": reasonCodes, + } + ) def wait_unsubscribed(self): return self.wait(self.unsubscribeds) @@ -116,8 +137,9 @@ def register(self, client): def cleanRetained(port): callback = Callbacks() - curclient = paho.mqtt.client.Client(b"clean retained", - protocol=paho.mqtt.client.MQTTv5) + curclient = paho.mqtt.client.Client( + b"clean retained", protocol=paho.mqtt.client.MQTTv5 + ) curclient.loop_start() callback.register(curclient) curclient.connect(host="localhost", port=port) @@ -130,7 +152,7 @@ def cleanRetained(port): curclient.publish(message["message"].topic, b"", 0, retain=True) curclient.disconnect() curclient.loop_stop() - time.sleep(.1) + time.sleep(0.1) def cleanup(port): @@ -139,13 +161,14 @@ def cleanup(port): clientids = ("aclient", "bclient") for clientid in clientids: - curclient = paho.mqtt.client.Client(clientid.encode( - "utf-8"), protocol=paho.mqtt.client.MQTTv5) + curclient = paho.mqtt.client.Client( + clientid.encode("utf-8"), protocol=paho.mqtt.client.MQTTv5 + ) curclient.loop_start() curclient.connect(host="localhost", port=port, clean_start=True) - time.sleep(.1) + time.sleep(0.1) curclient.disconnect() - time.sleep(.1) + time.sleep(0.1) curclient.loop_stop() # clean retained messages @@ -154,7 +177,6 @@ def cleanup(port): class Test(unittest.TestCase): - @classmethod def setUpClass(cls): global callback, callback2, aclient, bclient @@ -180,15 +202,17 @@ def setUpClass(cls): # Wait a bit for TCP server to bind to an address time.sleep(0.5) # Hack to find the port used by the test broker... - cls._test_broker_port = mqtt.brokers.listeners.TCPListeners.server.socket.getsockname()[1] + cls._test_broker_port = ( + mqtt.brokers.listeners.TCPListeners.server.socket.getsockname()[1] + ) setData() cleanup(cls._test_broker_port) callback = Callbacks() callback2 = Callbacks() - #aclient = mqtt_client.Client(b"\xEF\xBB\xBF" + "myclientid".encode("utf-8")) - #aclient = mqtt_client.Client("myclientid".encode("utf-8")) + # aclient = mqtt_client.Client(b"\xEF\xBB\xBF" + "myclientid".encode("utf-8")) + # aclient = mqtt_client.Client("myclientid".encode("utf-8")) aclient = paho.mqtt.client.Client(b"aclient", protocol=paho.mqtt.client.MQTTv5) callback.register(aclient) @@ -199,13 +223,14 @@ def setUpClass(cls): def tearDownClass(cls): # Another hack to stop the test broker... we rely on fact that it use a sockserver.TCPServer import mqtt.brokers + mqtt.brokers.listeners.TCPListeners.server.shutdown() cls._test_broker.join(5) def waitfor(self, queue, depth, limit): total = 0 while len(queue) < depth and total < limit: - interval = .5 + interval = 0.5 total += interval time.sleep(interval) @@ -224,7 +249,7 @@ def test_basic(self): aclient.publish(topics[0], b"qos 2", 2) i = 0 while len(callback.messages) < 3 and i < 10: - time.sleep(.2) + time.sleep(0.2) i += 1 self.assertEqual(len(callback.messages), 3) aclient.disconnect() @@ -244,7 +269,6 @@ def test_connect_fail(self): fclient.loop_stop() def test_retained_message(self): - publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.UserProperty = ("a", "2") publish_properties.UserProperty = ("c", "3") @@ -254,12 +278,15 @@ def test_retained_message(self): aclient.connect(host="localhost", port=self._test_broker_port) aclient.loop_start() response = callback.wait_connected() - aclient.publish(topics[1], b"qos 0", 0, - retain=True, properties=publish_properties) - aclient.publish(topics[2], b"qos 1", 1, - retain=True, properties=publish_properties) - aclient.publish(topics[3], b"qos 2", 2, - retain=True, properties=publish_properties) + aclient.publish( + topics[1], b"qos 0", 0, retain=True, properties=publish_properties + ) + aclient.publish( + topics[2], b"qos 1", 1, retain=True, properties=publish_properties + ) + aclient.publish( + topics[3], b"qos 2", 2, retain=True, properties=publish_properties + ) # wait until those messages are published time.sleep(1) aclient.subscribe(wildtopics[5], options=SubscribeOptions(qos=2)) @@ -272,14 +299,17 @@ def test_retained_message(self): self.assertEqual(len(callback.messages), 3) userprops = callback.messages[0]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) userprops = callback.messages[1]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) userprops = callback.messages[2]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) qoss = [callback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -296,8 +326,7 @@ def test_will_message(self): will_properties.UserProperty = ("a", "2") will_properties.UserProperty = ("c", "3") - aclient.will_set(topics[2], payload=b"will message", - properties=will_properties) + aclient.will_set(topics[2], payload=b"will message", properties=will_properties) aclient.connect(host="localhost", port=self._test_broker_port, keepalive=2) aclient.loop_start() @@ -328,11 +357,12 @@ def test_zero_length_clientid(self): callback0.register(client0) client0.loop_start() # should not be rejected - client0.connect(host="localhost", port=self._test_broker_port, clean_start=False) + client0.connect( + host="localhost", port=self._test_broker_port, clean_start=False + ) response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertTrue( - len(response["properties"].AssignedClientIdentifier) > 0) + self.assertTrue(len(response["properties"].AssignedClientIdentifier) > 0) client0.disconnect() client0.loop_stop() @@ -342,21 +372,18 @@ def test_zero_length_clientid(self): client0.connect(host="localhost", port=self._test_broker_port) # should work response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertTrue( - len(response["properties"].AssignedClientIdentifier) > 0) + self.assertTrue(len(response["properties"].AssignedClientIdentifier) > 0) client0.disconnect() client0.loop_stop() # when we supply a client id, we should not get one assigned - client0 = paho.mqtt.client.Client( - "client0", protocol=paho.mqtt.client.MQTTv5) + client0 = paho.mqtt.client.Client("client0", protocol=paho.mqtt.client.MQTTv5) callback0.register(client0) client0.loop_start() client0.connect(host="localhost", port=self._test_broker_port) # should work response = callback0.wait_connected() self.assertEqual(response["reasonCode"].getName(), "Success") - self.assertFalse( - hasattr(response["properties"], "AssignedClientIdentifier")) + self.assertFalse(hasattr(response["properties"], "AssignedClientIdentifier")) client0.disconnect() client0.loop_stop() @@ -366,13 +393,14 @@ def test_offline_message_queueing(self): ocallback = Callbacks() clientid = b"offline message queueing" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) connect_properties = Properties(PacketTypes.CONNECT) connect_properties.SessionExpiryInterval = 99999 oclient.loop_start() - oclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + oclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) ocallback.wait_connected() oclient.subscribe(wildtopics[5], qos=2) ocallback.wait_subscribed() @@ -389,20 +417,22 @@ def test_offline_message_queueing(self): bclient.disconnect() bclient.loop_stop() - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() - oclient.connect(host="localhost", port=self._test_broker_port, clean_start=False) + oclient.connect( + host="localhost", port=self._test_broker_port, clean_start=False + ) ocallback.wait_connected() time.sleep(2) oclient.disconnect() oclient.loop_stop() - self.assertTrue(len(ocallback.messages) in [ - 2, 3], len(ocallback.messages)) - logging.info("This server %s queueing QoS 0 messages for offline clients" % - ("is" if len(ocallback.messages) == 3 else "is not")) + self.assertTrue(len(ocallback.messages) in [2, 3], len(ocallback.messages)) + logging.info( + "This server %s queueing QoS 0 messages for offline clients" + % ("is" if len(ocallback.messages) == 3 else "is not") + ) def test_overlapping_subscriptions(self): # overlapping subscriptions. When there is more than one matching subscription for the same client for a topic, @@ -411,15 +441,18 @@ def test_overlapping_subscriptions(self): ocallback = Callbacks() clientid = b"overlapping subscriptions" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() oclient.connect(host="localhost", port=self._test_broker_port) ocallback.wait_connected() - oclient.subscribe([(wildtopics[6], SubscribeOptions(qos=2)), - (wildtopics[0], SubscribeOptions(qos=1))]) + oclient.subscribe( + [ + (wildtopics[6], SubscribeOptions(qos=2)), + (wildtopics[0], SubscribeOptions(qos=1)), + ] + ) ocallback.wait_subscribed() oclient.publish(topics[3], b"overlapping topic filters", 2) ocallback.wait_published() @@ -427,14 +460,28 @@ def test_overlapping_subscriptions(self): self.assertTrue(len(ocallback.messages) in [1, 2], ocallback.messages) if len(ocallback.messages) == 1: logging.info( - "This server is publishing one message for all matching overlapping subscriptions, not one for each.") + "This server is publishing one message for all matching overlapping subscriptions, not one for each." + ) self.assertEqual( - ocallback.messages[0]["message"].qos, 2, ocallback.messages[0]["message"].qos) + ocallback.messages[0]["message"].qos, + 2, + ocallback.messages[0]["message"].qos, + ) else: logging.info( - "This server is publishing one message per each matching overlapping subscription.") - self.assertTrue((ocallback.messages[0]["message"].qos == 2 and ocallback.messages[1]["message"].qos == 1) or - (ocallback.messages[0]["message"].qos == 1 and ocallback.messages[1]["message"].qos == 2), callback.messages) + "This server is publishing one message per each matching overlapping subscription." + ) + self.assertTrue( + ( + ocallback.messages[0]["message"].qos == 2 + and ocallback.messages[1]["message"].qos == 1 + ) + or ( + ocallback.messages[0]["message"].qos == 1 + and ocallback.messages[1]["message"].qos == 2 + ), + callback.messages, + ) oclient.disconnect() oclient.loop_stop() ocallback.clear() @@ -446,8 +493,7 @@ def test_subscribe_failure(self): ocallback = Callbacks() clientid = b"subscribe failure" - oclient = paho.mqtt.client.Client( - clientid, protocol=paho.mqtt.client.MQTTv5) + oclient = paho.mqtt.client.Client(clientid, protocol=paho.mqtt.client.MQTTv5) ocallback.register(oclient) oclient.loop_start() oclient.connect(host="localhost", port=self._test_broker_port) @@ -455,8 +501,11 @@ def test_subscribe_failure(self): oclient.subscribe(nosubscribe_topics[0], qos=2) response = ocallback.wait_subscribed() - self.assertEqual(response["reasonCodes"][0].getName(), "Unspecified error", - f"return code should be 0x80 {response['reasonCodes'][0].getName()}") + self.assertEqual( + response["reasonCodes"][0].getName(), + "Unspecified error", + f"return code should be 0x80 {response['reasonCodes'][0].getName()}", + ) oclient.disconnect() oclient.loop_stop() @@ -493,8 +542,9 @@ def test_unsubscribe(self): def new_client(self, clientid): callback = Callbacks() - client = paho.mqtt.client.Client(clientid.encode( - "utf-8"), protocol=paho.mqtt.client.MQTTv5) + client = paho.mqtt.client.Client( + clientid.encode("utf-8"), protocol=paho.mqtt.client.MQTTv5 + ) callback.register(client) client.loop_start() return client, callback @@ -509,7 +559,9 @@ def test_session_expiry(self): eclient, ecallback = self.new_client(clientid) - eclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + eclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = ecallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -522,8 +574,12 @@ def test_session_expiry(self): fclient, fcallback = self.new_client(clientid) # session should immediately expire - fclient.connect_async(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect_async( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -534,7 +590,9 @@ def test_session_expiry(self): eclient, ecallback = self.new_client(clientid) - eclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + eclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = ecallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -547,8 +605,12 @@ def test_session_expiry(self): time.sleep(2) # session should still exist fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], True) @@ -559,8 +621,12 @@ def test_session_expiry(self): time.sleep(6) # session should not exist fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -571,7 +637,8 @@ def test_session_expiry(self): eclient, ecallback = self.new_client(clientid) connect_properties.SessionExpiryInterval = 1 connack = eclient.connect( - host="localhost", port=self._test_broker_port, properties=connect_properties) + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = ecallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -586,8 +653,12 @@ def test_session_expiry(self): time.sleep(3) # session should still exist as we changed the expiry interval on disconnect fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], True) @@ -598,8 +669,12 @@ def test_session_expiry(self): # session should immediately expire fclient, fcallback = self.new_client(clientid) - fclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + fclient.connect( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) connack = fcallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) @@ -623,29 +698,29 @@ def test_user_properties(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.UserProperty = ("a", "2") publish_properties.UserProperty = ("c", "3") - uclient.publish(topics[0], b"", 0, retain=False, - properties=publish_properties) - uclient.publish(topics[0], b"", 1, retain=False, - properties=publish_properties) - uclient.publish(topics[0], b"", 2, retain=False, - properties=publish_properties) + uclient.publish(topics[0], b"", 0, retain=False, properties=publish_properties) + uclient.publish(topics[0], b"", 1, retain=False, properties=publish_properties) + uclient.publish(topics[0], b"", 2, retain=False, properties=publish_properties) count = 0 while len(ucallback.messages) < 3 and count < 50: - time.sleep(.1) + time.sleep(0.1) count += 1 uclient.disconnect() ucallback.wait_disconnected() uclient.loop_stop() self.assertEqual(len(ucallback.messages), 3, ucallback.messages) userprops = ucallback.messages[0]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) userprops = ucallback.messages[1]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) userprops = ucallback.messages[2]["message"].properties.UserProperty - self.assertTrue(userprops in [[("a", "2"), ("c", "3")], [ - ("c", "3"), ("a", "2")]], userprops) + self.assertTrue( + userprops in [[("a", "2"), ("c", "3")], [("c", "3"), ("a", "2")]], userprops + ) qoss = [ucallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -662,18 +737,21 @@ def test_payload_format(self): publish_properties.PayloadFormatIndicator = 1 publish_properties.ContentType = "My name" info = pclient.publish( - topics[0], b"qos 0", 0, retain=False, properties=publish_properties) + topics[0], b"qos 0", 0, retain=False, properties=publish_properties + ) info.wait_for_publish() info = pclient.publish( - topics[0], b"qos 1", 1, retain=False, properties=publish_properties) + topics[0], b"qos 1", 1, retain=False, properties=publish_properties + ) info.wait_for_publish() info = pclient.publish( - topics[0], b"qos 2", 2, retain=False, properties=publish_properties) + topics[0], b"qos 2", 2, retain=False, properties=publish_properties + ) info.wait_for_publish() count = 0 while len(pcallback.messages) < 3 and count < 50: - time.sleep(.1) + time.sleep(0.1) count += 1 pclient.disconnect() pcallback.wait_disconnected() @@ -682,16 +760,13 @@ def test_payload_format(self): self.assertEqual(len(pcallback.messages), 3, pcallback.messages) props = pcallback.messages[0]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) props = pcallback.messages[1]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) props = pcallback.messages[2]["message"].properties self.assertEqual(props.ContentType, "My name", props.ContentType) - self.assertEqual(props.PayloadFormatIndicator, - 1, props.PayloadFormatIndicator) + self.assertEqual(props.PayloadFormatIndicator, 1, props.PayloadFormatIndicator) qoss = [pcallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) @@ -703,7 +778,9 @@ def test_message_expiry(self): lbclient, lbcallback = self.new_client(f"{clientid} b") lbclient.loop_start() - lbclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + lbclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) lbcallback.wait_connected() lbclient.subscribe(topics[0], qos=2) lbcallback.wait_subscribed() @@ -718,30 +795,48 @@ def test_message_expiry(self): laclient.connect(host="localhost", port=self._test_broker_port) publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.MessageExpiryInterval = 1 - laclient.publish(topics[0], b"qos 1 - expire", 1, - retain=False, properties=publish_properties) - laclient.publish(topics[0], b"qos 2 - expire", 2, - retain=False, properties=publish_properties) + laclient.publish( + topics[0], b"qos 1 - expire", 1, retain=False, properties=publish_properties + ) + laclient.publish( + topics[0], b"qos 2 - expire", 2, retain=False, properties=publish_properties + ) publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.MessageExpiryInterval = 6 - laclient.publish(topics[0], b"qos 1 - don't expire", - 1, retain=False, properties=publish_properties) - laclient.publish(topics[0], b"qos 2 - don't expire", - 2, retain=False, properties=publish_properties) + laclient.publish( + topics[0], + b"qos 1 - don't expire", + 1, + retain=False, + properties=publish_properties, + ) + laclient.publish( + topics[0], + b"qos 2 - don't expire", + 2, + retain=False, + properties=publish_properties, + ) time.sleep(3) lbclient, lbcallback = self.new_client(f"{clientid} b") lbclient.loop_start() - lbclient.connect(host="localhost", port=self._test_broker_port, clean_start=False) + lbclient.connect( + host="localhost", port=self._test_broker_port, clean_start=False + ) lbcallback.wait_connected() self.waitfor(lbcallback.messages, 1, 3) time.sleep(1) self.assertEqual(len(lbcallback.messages), 2, lbcallback.messages) - self.assertTrue(lbcallback.messages[0]["message"].properties.MessageExpiryInterval < 6, - lbcallback.messages[0]["message"].properties.MessageExpiryInterval) - self.assertTrue(lbcallback.messages[1]["message"].properties.MessageExpiryInterval < 6, - lbcallback.messages[1]["message"].properties.MessageExpiryInterval) + self.assertTrue( + lbcallback.messages[0]["message"].properties.MessageExpiryInterval < 6, + lbcallback.messages[0]["message"].properties.MessageExpiryInterval, + ) + self.assertTrue( + lbcallback.messages[1]["message"].properties.MessageExpiryInterval < 6, + lbcallback.messages[1]["message"].properties.MessageExpiryInterval, + ) laclient.disconnect() lacallback.wait_disconnected() laclient.loop_stop() @@ -752,22 +847,20 @@ def test_message_expiry(self): def test_subscribe_options(self): # noLocal - clientid = 'subscribe options - noLocal' + clientid = "subscribe options - noLocal" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() laclient.loop_start() - laclient.subscribe( - topics[0], options=SubscribeOptions(qos=2, noLocal=True)) + laclient.subscribe(topics[0], options=SubscribeOptions(qos=2, noLocal=True)) lacallback.wait_subscribed() lbclient, lbcallback = self.new_client(f"{clientid} b") lbclient.connect(host="localhost", port=self._test_broker_port) lbcallback.wait_connected() lbclient.loop_start() - lbclient.subscribe( - topics[0], options=SubscribeOptions(qos=2, noLocal=True)) + lbclient.subscribe(topics[0], options=SubscribeOptions(qos=2, noLocal=True)) lbcallback.wait_subscribed() laclient.publish(topics[0], b"noLocal test", 1, retain=False) @@ -784,18 +877,17 @@ def test_subscribe_options(self): lbclient.loop_stop() # retainAsPublished - clientid = 'subscribe options - retain as published' + clientid = "subscribe options - retain as published" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() - laclient.subscribe(topics[0], options=SubscribeOptions( - qos=2, retainAsPublished=True)) + laclient.subscribe( + topics[0], options=SubscribeOptions(qos=2, retainAsPublished=True) + ) lacallback.wait_subscribed() self.waitfor(lacallback.subscribeds, 1, 3) - laclient.publish( - topics[0], b"retain as published false", 1, retain=False) - laclient.publish( - topics[0], b"retain as published true", 1, retain=True) + laclient.publish(topics[0], b"retain as published false", 1, retain=False) + laclient.publish(topics[0], b"retain as published true", 1, retain=True) self.waitfor(lacallback.messages, 2, 3) time.sleep(1) @@ -808,7 +900,7 @@ def test_subscribe_options(self): self.assertEqual(lacallback.messages[1]["message"].retain, True) # retainHandling - clientid = 'subscribe options - retain handling' + clientid = "subscribe options - retain handling" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) lacallback.wait_connected() @@ -818,15 +910,13 @@ def test_subscribe_options(self): time.sleep(1) # retain handling 1 only gives us retained messages on a new subscription - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() time.sleep(1) self.assertEqual(len(lacallback.messages), 0) @@ -839,15 +929,13 @@ def test_subscribe_options(self): lacallback.wait_unsubscribed() # check that we really did remove that subscription - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=1)) lacallback.wait_subscribed() time.sleep(1) self.assertEqual(len(lacallback.messages), 0) @@ -860,12 +948,10 @@ def test_subscribe_options(self): lacallback.wait_unsubscribed() lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 0) - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=2)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 0) @@ -873,15 +959,13 @@ def test_subscribe_options(self): laclient.unsubscribe(wildtopics[5]) lacallback.wait_unsubscribed() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) lacallback.wait_subscribed() self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] self.assertTrue(1 in qoss and 2 in qoss and 0 in qoss, qoss) lacallback.clear() - laclient.subscribe( - wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) + laclient.subscribe(wildtopics[5], options=SubscribeOptions(2, retainHandling=0)) time.sleep(1) self.assertEqual(len(lacallback.messages), 3) qoss = [lacallback.messages[i]["message"].qos for i in range(3)] @@ -893,7 +977,7 @@ def test_subscribe_options(self): cleanRetained(self._test_broker_port) def test_subscription_identifiers(self): - clientid = 'subscription identifiers' + clientid = "subscription identifiers" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) @@ -922,8 +1006,11 @@ def test_subscription_identifiers(self): self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) - self.assertEqual(lacallback.messages[0]["message"].properties.SubscriptionIdentifier[0], - 456789, lacallback.messages[0]["message"].properties.SubscriptionIdentifier) + self.assertEqual( + lacallback.messages[0]["message"].properties.SubscriptionIdentifier[0], + 456789, + lacallback.messages[0]["message"].properties.SubscriptionIdentifier, + ) laclient.disconnect() lacallback.wait_disconnected() laclient.loop_stop() @@ -932,14 +1019,15 @@ def test_subscription_identifiers(self): self.assertEqual(len(lbcallback.messages), 1, lbcallback.messages) expected_subsids = set([2, 3]) received_subsids = set( - lbcallback.messages[0]["message"].properties.SubscriptionIdentifier) + lbcallback.messages[0]["message"].properties.SubscriptionIdentifier + ) self.assertEqual(received_subsids, expected_subsids, received_subsids) lbclient.disconnect() lbcallback.wait_disconnected() lbclient.loop_stop() def test_request_response(self): - clientid = 'request response' + clientid = "request response" laclient, lacallback = self.new_client(f"{clientid} a") laclient.connect(host="localhost", port=self._test_broker_port) @@ -951,31 +1039,38 @@ def test_request_response(self): lbcallback.wait_connected() lbclient.loop_start() - laclient.subscribe( - topics[0], options=SubscribeOptions(2, noLocal=True)) + laclient.subscribe(topics[0], options=SubscribeOptions(2, noLocal=True)) lacallback.wait_subscribed() - lbclient.subscribe( - topics[0], options=SubscribeOptions(2, noLocal=True)) + lbclient.subscribe(topics[0], options=SubscribeOptions(2, noLocal=True)) lbcallback.wait_subscribed() publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.ResponseTopic = topics[0] publish_properties.CorrelationData = b"334" # client a is the requester - laclient.publish(topics[0], b"request", 1, - properties=publish_properties) + laclient.publish(topics[0], b"request", 1, properties=publish_properties) # client b is the responder self.waitfor(lbcallback.messages, 1, 3) self.assertEqual(len(lbcallback.messages), 1, lbcallback.messages) - self.assertEqual(lbcallback.messages[0]["message"].properties.ResponseTopic, topics[0], - lbcallback.messages[0]["message"].properties) - self.assertEqual(lbcallback.messages[0]["message"].properties.CorrelationData, b"334", - lbcallback.messages[0]["message"].properties) - - lbclient.publish(lbcallback.messages[0]["message"].properties.ResponseTopic, b"response", 1, - properties=lbcallback.messages[0]["message"].properties) + self.assertEqual( + lbcallback.messages[0]["message"].properties.ResponseTopic, + topics[0], + lbcallback.messages[0]["message"].properties, + ) + self.assertEqual( + lbcallback.messages[0]["message"].properties.CorrelationData, + b"334", + lbcallback.messages[0]["message"].properties, + ) + + lbclient.publish( + lbcallback.messages[0]["message"].properties.ResponseTopic, + b"response", + 1, + properties=lbcallback.messages[0]["message"].properties, + ) # client a gets the response self.waitfor(lacallback.messages, 1, 3) @@ -989,13 +1084,15 @@ def test_request_response(self): lbclient.loop_stop() def test_client_topic_alias(self): - clientid = 'client topic alias' + clientid = "client topic alias" connect_properties = Properties(PacketTypes.CONNECT) connect_properties.TopicAliasMaximum = 0 # server topic aliases not allowed connect_properties.SessionExpiryInterval = 99999 laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = lacallback.wait_connected() clientTopicAliasMaximum = 0 if hasattr(connack["properties"], "TopicAliasMaximum"): @@ -1012,13 +1109,11 @@ def test_client_topic_alias(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.TopicAlias = 1 - laclient.publish(topics[0], b"topic alias 1", - 1, properties=publish_properties) + laclient.publish(topics[0], b"topic alias 1", 1, properties=publish_properties) self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) - laclient.publish("", b"topic alias 2", 1, - properties=publish_properties) + laclient.publish("", b"topic alias 2", 1, properties=publish_properties) self.waitfor(lacallback.messages, 2, 3) self.assertEqual(len(lacallback.messages), 2, lacallback.messages) @@ -1028,8 +1123,12 @@ def test_client_topic_alias(self): # check aliases have been deleted laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, clean_start=False, - properties=connect_properties) + laclient.connect( + host="localhost", + port=self._test_broker_port, + clean_start=False, + properties=connect_properties, + ) laclient.publish(topics[0], b"topic alias 3", 1) self.waitfor(lacallback.messages, 1, 3) @@ -1037,22 +1136,23 @@ def test_client_topic_alias(self): publish_properties = Properties(PacketTypes.PUBLISH) publish_properties.TopicAlias = 1 - laclient.publish("", b"topic alias 4", 1, - properties=publish_properties) + laclient.publish("", b"topic alias 4", 1, properties=publish_properties) # should get back a disconnect with Topic alias invalid lacallback.wait_disconnected() laclient.loop_stop() def test_server_topic_alias(self): - clientid = 'server topic alias' + clientid = "server topic alias" serverTopicAliasMaximum = 1 # server topic alias allowed connect_properties = Properties(PacketTypes.CONNECT) connect_properties.TopicAliasMaximum = serverTopicAliasMaximum laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) lacallback.wait_connected() laclient.loop_start() @@ -1068,19 +1168,23 @@ def test_server_topic_alias(self): laclient.loop_stop() # first message should set the topic alias - self.assertTrue(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) + self.assertTrue( + hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), + lacallback.messages[0]["message"].properties, + ) topicalias = lacallback.messages[0]["message"].properties.TopicAlias self.assertTrue(topicalias > 0) self.assertEqual(lacallback.messages[0]["message"].topic, topics[0]) self.assertEqual( - lacallback.messages[1]["message"].properties.TopicAlias, topicalias) + lacallback.messages[1]["message"].properties.TopicAlias, topicalias + ) self.assertEqual(lacallback.messages[1]["message"].topic, "") self.assertEqual( - lacallback.messages[2]["message"].properties.TopicAlias, topicalias) + lacallback.messages[2]["message"].properties.TopicAlias, topicalias + ) self.assertEqual(lacallback.messages[2]["message"].topic, "") serverTopicAliasMaximum = 0 # no server topic alias allowed @@ -1088,7 +1192,9 @@ def test_server_topic_alias(self): # connect_properties.TopicAliasMaximum = serverTopicAliasMaximum # default is 0 laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) lacallback.wait_connected() laclient.loop_start() @@ -1104,19 +1210,27 @@ def test_server_topic_alias(self): laclient.loop_stop() # No topic aliases - self.assertFalse(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) + self.assertFalse( + hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), + lacallback.messages[0]["message"].properties, + ) + self.assertFalse( + hasattr(lacallback.messages[1]["message"].properties, "TopicAlias"), + lacallback.messages[1]["message"].properties, + ) + self.assertFalse( + hasattr(lacallback.messages[2]["message"].properties, "TopicAlias"), + lacallback.messages[2]["message"].properties, + ) serverTopicAliasMaximum = 0 # no server topic alias allowed connect_properties = Properties(PacketTypes.CONNECT) connect_properties.TopicAliasMaximum = serverTopicAliasMaximum # default is 0 laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) lacallback.wait_connected() laclient.loop_start() @@ -1132,15 +1246,21 @@ def test_server_topic_alias(self): laclient.loop_stop() # No topic aliases - self.assertFalse(hasattr( - lacallback.messages[0]["message"].properties, "TopicAlias"), lacallback.messages[0]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[1]["message"].properties, "TopicAlias"), lacallback.messages[1]["message"].properties) - self.assertFalse(hasattr( - lacallback.messages[2]["message"].properties, "TopicAlias"), lacallback.messages[2]["message"].properties) + self.assertFalse( + hasattr(lacallback.messages[0]["message"].properties, "TopicAlias"), + lacallback.messages[0]["message"].properties, + ) + self.assertFalse( + hasattr(lacallback.messages[1]["message"].properties, "TopicAlias"), + lacallback.messages[1]["message"].properties, + ) + self.assertFalse( + hasattr(lacallback.messages[2]["message"].properties, "TopicAlias"), + lacallback.messages[2]["message"].properties, + ) def test_maximum_packet_size(self): - clientid = 'maximum packet size' + clientid = "maximum packet size" # 1. server max packet size laclient, lacallback = self.new_client(f"{clientid} a") @@ -1148,20 +1268,22 @@ def test_maximum_packet_size(self): connack = lacallback.wait_connected() laclient.loop_start() - serverMaximumPacketSize = 2**28-1 + serverMaximumPacketSize = 2**28 - 1 if hasattr(connack["properties"], "MaximumPacketSize"): serverMaximumPacketSize = connack["properties"].MaximumPacketSize if serverMaximumPacketSize < 65535: # publish bigger packet than server can accept - payload = b"."*serverMaximumPacketSize + payload = b"." * serverMaximumPacketSize laclient.publish(topics[0], payload, 0) # should get back a disconnect with packet size too big response = lacallback.wait_disconnected() - self.assertEqual(len(lacallback.disconnecteds), - 0, lacallback.disconnecteds) - self.assertEqual(response["reasonCode"].getName(), - "Packet too large", response["reasonCode"].getName()) + self.assertEqual(len(lacallback.disconnecteds), 0, lacallback.disconnecteds) + self.assertEqual( + response["reasonCode"].getName(), + "Packet too large", + response["reasonCode"].getName(), + ) else: laclient.disconnect() lacallback.wait_disconnected() @@ -1173,11 +1295,13 @@ def test_maximum_packet_size(self): connect_properties.MaximumPacketSize = maximumPacketSize laclient, lacallback = self.new_client(f"{clientid} a") - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = lacallback.wait_connected() laclient.loop_start() - serverMaximumPacketSize = 2**28-1 + serverMaximumPacketSize = 2**28 - 1 if hasattr(connack["properties"], "MaximumPacketSize"): serverMaximumPacketSize = connack["properties"].MaximumPacketSize @@ -1185,13 +1309,13 @@ def test_maximum_packet_size(self): response = lacallback.wait_subscribed() # send a small enough packet, should get this one back - payload = b"."*(int(maximumPacketSize/2)) + payload = b"." * (int(maximumPacketSize / 2)) laclient.publish(topics[0], payload, 0) self.waitfor(lacallback.messages, 1, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) # send a packet too big to receive - payload = b"."*maximumPacketSize + payload = b"." * maximumPacketSize laclient.publish(topics[0], payload, 1) self.waitfor(lacallback.messages, 2, 3) self.assertEqual(len(lacallback.messages), 1, lacallback.messages) @@ -1220,7 +1344,7 @@ def test_server_keep_alive(self): def test_will_delay(self): # the will message should be received earlier than the session expiry - clientid = 'will delay' + clientid = "will delay" will_properties = Properties(PacketTypes.WILLMESSAGE) connect_properties = Properties(PacketTypes.CONNECT) @@ -1232,15 +1356,22 @@ def test_will_delay(self): laclient, lacallback = self.new_client(f"{clientid} a") laclient.will_set( - topics[0], payload=b"test_will_delay will message", properties=will_properties) - laclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + topics[0], + payload=b"test_will_delay will message", + properties=will_properties, + ) + laclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = lacallback.wait_connected() self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) laclient.loop_start() lbclient, lbcallback = self.new_client(f"{clientid} b") - lbclient.connect(host="localhost", port=self._test_broker_port, properties=connect_properties) + lbclient.connect( + host="localhost", port=self._test_broker_port, properties=connect_properties + ) connack = lbcallback.wait_connected() lbclient.loop_start() # subscribe to will message topic @@ -1252,19 +1383,20 @@ def test_will_delay(self): laclient.socket().close() start = time.time() while lbcallback.messages == []: - time.sleep(.1) + time.sleep(0.1) duration = time.time() - start self.assertAlmostEqual(duration, 4, delta=1) self.assertEqual(lbcallback.messages[0]["message"].topic, topics[0]) self.assertEqual( - lbcallback.messages[0]["message"].payload, b"test_will_delay will message") + lbcallback.messages[0]["message"].payload, b"test_will_delay will message" + ) lbclient.disconnect() lbcallback.wait_disconnected() lbclient.loop_stop() def test_shared_subscriptions(self): - clientid = 'shared subscriptions' + clientid = "shared subscriptions" shared_sub_topic = f"$share/sharename/{topic_prefix}x" shared_pub_topic = f"{topic_prefix}x" @@ -1278,7 +1410,8 @@ def test_shared_subscriptions(self): self.assertEqual(connack["flags"]["session present"], False) laclient.subscribe( - [(shared_sub_topic, SubscribeOptions(2)), (topics[0], SubscribeOptions(2))]) + [(shared_sub_topic, SubscribeOptions(2)), (topics[0], SubscribeOptions(2))] + ) lacallback.wait_subscribed() lbclient, lbcallback = self.new_client(f"{clientid} b") @@ -1289,8 +1422,7 @@ def test_shared_subscriptions(self): self.assertEqual(connack["reasonCode"].getName(), "Success") self.assertEqual(connack["flags"]["session present"], False) - lbclient.subscribe( - [(shared_sub_topic, SubscribeOptions(2)), (topics[0], 2)]) + lbclient.subscribe([(shared_sub_topic, SubscribeOptions(2)), (topics[0], 2)]) lbcallback.wait_subscribed() lacallback.clear() @@ -1300,8 +1432,10 @@ def test_shared_subscriptions(self): for i in range(count): lbclient.publish(topics[0], f"message {i}", 0) j = 0 - while len(lacallback.messages) + len(lbcallback.messages) < 2*count and j < 20: - time.sleep(.1) + while ( + len(lacallback.messages) + len(lbcallback.messages) < 2 * count and j < 20 + ): + time.sleep(0.1) j += 1 time.sleep(1) self.assertEqual(len(lacallback.messages), count) @@ -1314,12 +1448,11 @@ def test_shared_subscriptions(self): lbclient.publish(shared_pub_topic, f"message {i}", 0) j = 0 while len(lacallback.messages) + len(lbcallback.messages) < count and j < 20: - time.sleep(.1) + time.sleep(0.1) j += 1 time.sleep(1) # Each message should only be received once - self.assertEqual(len(lacallback.messages) + - len(lbcallback.messages), count) + self.assertEqual(len(lacallback.messages) + len(lbcallback.messages), count) laclient.disconnect() lacallback.wait_disconnected() diff --git a/tests/test_websocket_integration.py b/tests/test_websocket_integration.py index 80872f9e..2d923e73 100644 --- a/tests/test_websocket_integration.py +++ b/tests/test_websocket_integration.py @@ -14,65 +14,75 @@ @pytest.fixture def init_response_headers(): # "Normal" websocket response from server - response_headers = OrderedDict([ - ("Upgrade", "websocket"), - ("Connection", "Upgrade"), - ("Sec-WebSocket-Accept", "testwebsocketkey"), - ("Sec-WebSocket-Protocol", "chat"), - ]) + response_headers = OrderedDict( + [ + ("Upgrade", "websocket"), + ("Connection", "Upgrade"), + ("Sec-WebSocket-Accept", "testwebsocketkey"), + ("Sec-WebSocket-Protocol", "chat"), + ] + ) return response_headers def get_websocket_response(response_headers): - """ Takes headers and constructs HTTP response + """Takes headers and constructs HTTP response 'HTTP/1.1 101 Switching Protocols' is the headers for the response, as expected in client.py """ - response = "\r\n".join([ - "HTTP/1.1 101 Switching Protocols", - "\r\n".join(f"{i}: {j}" for i, j in response_headers.items()), - "\r\n", - ]).encode("utf8") + response = "\r\n".join( + [ + "HTTP/1.1 101 Switching Protocols", + "\r\n".join(f"{i}: {j}" for i, j in response_headers.items()), + "\r\n", + ] + ).encode("utf8") return response -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestInvalidWebsocketResponse: def test_unexpected_response(self, proto_ver, proto_name, fake_websocket_broker): - """ Server responds with a valid code, but it's not what the client expected """ + """Server responds with a valid code, but it's not what the client expected""" mqttc = client.Client( - "test_unexpected_response", - protocol=proto_ver, - transport="websockets" - ) + "test_unexpected_response", protocol=proto_ver, transport="websockets" + ) class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): # Respond with data passed in to serve() _self.request.sendall(b"200 OK") - with fake_websocket_broker.serve(WebsocketHandler), pytest.raises(WebsocketConnectionError) as exc: + with fake_websocket_broker.serve(WebsocketHandler), pytest.raises( + WebsocketConnectionError + ) as exc: mqttc.connect("localhost", fake_websocket_broker.port, keepalive=10) assert str(exc.value) == "WebSocket handshake error" -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestBadWebsocketHeaders: - """ Testing for basic functionality in checking for headers """ + """Testing for basic functionality in checking for headers""" def _get_basic_handler(self, response_headers): - """ Get a basic BaseRequestHandler which returns the information in + """Get a basic BaseRequestHandler which returns the information in self._response_headers """ @@ -81,64 +91,69 @@ def _get_basic_handler(self, response_headers): class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): self.data = _self.request.recv(1024).strip() - print('Received', self.data.decode('utf8')) + print("Received", self.data.decode("utf8")) # Respond with data passed in to serve() _self.request.sendall(response) return WebsocketHandler - def test_no_upgrade(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't respond with 'connection: upgrade' """ + def test_no_upgrade( + self, proto_ver, proto_name, fake_websocket_broker, init_response_headers + ): + """Server doesn't respond with 'connection: upgrade'""" mqttc = client.Client( - "test_no_upgrade", - protocol=proto_ver, - transport="websockets" - ) + "test_no_upgrade", protocol=proto_ver, transport="websockets" + ) init_response_headers["Connection"] = "bad" response = self._get_basic_handler(init_response_headers) - with fake_websocket_broker.serve(response), pytest.raises(WebsocketConnectionError) as exc: + with fake_websocket_broker.serve(response), pytest.raises( + WebsocketConnectionError + ) as exc: mqttc.connect("localhost", fake_websocket_broker.port, keepalive=10) assert str(exc.value) == "WebSocket handshake error, connection not upgraded" - def test_bad_secret_key(self, proto_ver, proto_name, fake_websocket_broker, - init_response_headers): - """ Server doesn't give anything after connection: upgrade """ + def test_bad_secret_key( + self, proto_ver, proto_name, fake_websocket_broker, init_response_headers + ): + """Server doesn't give anything after connection: upgrade""" mqttc = client.Client( - "test_bad_secret_key", - protocol=proto_ver, - transport="websockets" - ) + "test_bad_secret_key", protocol=proto_ver, transport="websockets" + ) response = self._get_basic_handler(init_response_headers) - with fake_websocket_broker.serve(response), pytest.raises(WebsocketConnectionError) as exc: + with fake_websocket_broker.serve(response), pytest.raises( + WebsocketConnectionError + ) as exc: mqttc.connect("localhost", fake_websocket_broker.port, keepalive=10) assert str(exc.value) == "WebSocket handshake error, invalid secret key" -@pytest.mark.parametrize("proto_ver,proto_name", [ - (client.MQTTv31, "MQIsdp"), - (client.MQTTv311, "MQTT"), -]) +@pytest.mark.parametrize( + "proto_ver,proto_name", + [ + (client.MQTTv31, "MQIsdp"), + (client.MQTTv311, "MQTT"), + ], +) class TestValidHeaders: - """ Testing for functionality in request/response headers """ + """Testing for functionality in request/response headers""" def _get_callback_handler(self, response_headers, check_request=None): - """ Get a basic BaseRequestHandler which returns the information in + """Get a basic BaseRequestHandler which returns the information in self._response_headers """ class WebsocketHandler(socketserver.BaseRequestHandler): def handle(_self): self.data = _self.request.recv(1024).strip() - print('Received', self.data.decode('utf8')) + print("Received", self.data.decode("utf8")) decoded = self.data.decode("utf8") @@ -147,7 +162,9 @@ def handle(_self): # Create server hash GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" - key = re.search("sec-websocket-key: ([A-Za-z0-9+/=]*)", decoded, re.IGNORECASE).group(1) + key = re.search( + "sec-websocket-key: ([A-Za-z0-9+/=]*)", decoded, re.IGNORECASE + ).group(1) to_hash = f"{key:s}{GUID:s}" hashed = hashlib.sha1(to_hash.encode("utf8")) # noqa: S324 @@ -162,16 +179,14 @@ def handle(_self): return WebsocketHandler - def test_successful_connection(self, proto_ver, proto_name, - fake_websocket_broker, - init_response_headers): - """ Connect successfully, on correct path """ + def test_successful_connection( + self, proto_ver, proto_name, fake_websocket_broker, init_response_headers + ): + """Connect successfully, on correct path""" mqttc = client.Client( - "test_successful_connection", - protocol=proto_ver, - transport="websockets" - ) + "test_successful_connection", protocol=proto_ver, transport="websockets" + ) response = self._get_callback_handler(init_response_headers) @@ -180,20 +195,26 @@ def test_successful_connection(self, proto_ver, proto_name, mqttc.disconnect() - @pytest.mark.parametrize("mqtt_path", [ - "/mqtt" - "/special", - None, - ]) - def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, - mqtt_path, init_response_headers): - """ Make sure it can connect on user specified paths """ + @pytest.mark.parametrize( + "mqtt_path", + [ + "/mqtt" "/special", + None, + ], + ) + def test_correct_path( + self, + proto_ver, + proto_name, + fake_websocket_broker, + mqtt_path, + init_response_headers, + ): + """Make sure it can connect on user specified paths""" mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) + "test_correct_path", protocol=proto_ver, transport="websockets" + ) mqttc.ws_set_options( path=mqtt_path, @@ -202,7 +223,10 @@ def test_correct_path(self, proto_ver, proto_name, fake_websocket_broker, def check_path_correct(decoded): # Make sure it connects to the right path if mqtt_path: - assert re.search(f"GET {mqtt_path} HTTP/1.1", decoded, re.IGNORECASE) is not None + assert ( + re.search(f"GET {mqtt_path} HTTP/1.1", decoded, re.IGNORECASE) + is not None + ) response = self._get_callback_handler( init_response_headers, @@ -214,21 +238,28 @@ def check_path_correct(decoded): mqttc.disconnect() - @pytest.mark.parametrize("auth_headers", [ - {"Authorization": "test123"}, - {"Authorization": "test123", "auth2": "abcdef"}, - # Won't be checked, but make sure it still works even if the user passes it - None, - ]) - def test_correct_auth(self, proto_ver, proto_name, fake_websocket_broker, - auth_headers, init_response_headers): - """ Make sure it sends the right auth headers """ + @pytest.mark.parametrize( + "auth_headers", + [ + {"Authorization": "test123"}, + {"Authorization": "test123", "auth2": "abcdef"}, + # Won't be checked, but make sure it still works even if the user passes it + None, + ], + ) + def test_correct_auth( + self, + proto_ver, + proto_name, + fake_websocket_broker, + auth_headers, + init_response_headers, + ): + """Make sure it sends the right auth headers""" mqttc = client.Client( - "test_correct_path", - protocol=proto_ver, - transport="websockets" - ) + "test_correct_path", protocol=proto_ver, transport="websockets" + ) mqttc.ws_set_options( headers=auth_headers, diff --git a/tests/test_websockets.py b/tests/test_websockets.py index f2605a3a..5fc267a9 100644 --- a/tests/test_websockets.py +++ b/tests/test_websockets.py @@ -6,10 +6,10 @@ class TestHeaders: - """ Make sure headers are used correctly """ + """Make sure headers are used correctly""" def test_normal_headers(self): - """ Normal headers as specified in RFC 6455 """ + """Normal headers as specified in RFC 6455""" response = [ "HTTP/1.1 101 Switching Protocols", @@ -54,15 +54,18 @@ def fakerecv(*args): # error assert str(exc.value) == "WebSocket handshake error, invalid secret key" - expected_sent = [i.format(**wargs) for i in [ - "GET {path:s} HTTP/1.1", - "Host: {host:s}", - "Upgrade: websocket", - "Connection: Upgrade", - "Sec-Websocket-Protocol: mqtt", - "Sec-Websocket-Version: 13", - "Origin: https://{host:s}:{port:d}", - ]] + expected_sent = [ + i.format(**wargs) + for i in [ + "GET {path:s} HTTP/1.1", + "Host: {host:s}", + "Upgrade: websocket", + "Connection: Upgrade", + "Sec-Websocket-Protocol: mqtt", + "Sec-Websocket-Version: 13", + "Origin: https://{host:s}:{port:d}", + ] + ] # Only sends the header once assert mocksock.send.call_count == 1 diff --git a/tests/testsupport/broker.py b/tests/testsupport/broker.py index 6cf72c7f..4f4cf7a4 100644 --- a/tests/testsupport/broker.py +++ b/tests/testsupport/broker.py @@ -22,7 +22,7 @@ def __init__(self): def start(self): if self._sock is None: - raise ValueError('Socket is not open') + raise ValueError("Socket is not open") (conn, address) = self._sock.accept() conn.settimeout(10) @@ -39,14 +39,14 @@ def finish(self): def receive_packet(self, num_bytes): if self._conn is None: - raise ValueError('Connection is not open') + raise ValueError("Connection is not open") packet_in = self._conn.recv(num_bytes) return packet_in def send_packet(self, packet_out): if self._conn is None: - raise ValueError('Connection is not open') + raise ValueError("Connection is not open") count = self._conn.send(packet_out) return count diff --git a/tox.ini b/tox.ini index 48dea4c9..e5fbf849 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,6 @@ envlist = py{37,38,39,310,311,312} whitelist_externals = echo make deps = -rrequirements.txt - ruff==0.1.8 allowlist_externals = echo make @@ -16,16 +15,13 @@ env = [testenv:lint] deps = - -e . - black - codespell + -e .[proxy] + dnspython mypy pre-commit safety commands = # The "-" in front of command tells tox to ignore errors pre-commit run --all-files - - black --check src - - codespell - - mypy --ignore-missing-imports src + mypy src safety check