From 0281f3fa6bb25f210244e6ce96dea7bd3a0b5480 Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:23:58 +0000 Subject: [PATCH 1/5] Add MQTT Sink --- pyproject.toml | 3 +- quixstreams/sinks/community/mqtt.py | 134 ++++++++++++++++++ tests/requirements.txt | 1 + .../test_community/test_mqtt_sink.py | 79 +++++++++++ 4 files changed, 216 insertions(+), 1 deletion(-) create mode 100644 quixstreams/sinks/community/mqtt.py create mode 100644 tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py diff --git a/pyproject.toml b/pyproject.toml index 1f64ba568..1729723f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,8 @@ all = [ "psycopg2-binary>=2.9.9,<3", "boto3>=1.35.65,<2.0", "boto3-stubs>=1.35.65,<2.0", - "redis[hiredis]>=5.2.0,<6" + "redis[hiredis]>=5.2.0,<6", + "paho-mqtt==2.1.0" ] avro = ["fastavro>=1.8,<2.0"] diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py new file mode 100644 index 000000000..ff624ac2a --- /dev/null +++ b/quixstreams/sinks/community/mqtt.py @@ -0,0 +1,134 @@ +from quixstreams.sinks.base.sink import BaseSink +from quixstreams.sinks.base.exceptions import SinkBackpressureError +from typing import List, Tuple, Any +from quixstreams.models.types import HeaderValue +from datetime import datetime +import json + +try: + import paho.mqtt.client as paho + from paho import mqtt +except ImportError as exc: + raise ImportError( + 'Package "paho-mqtt" is missing: ' + "run pip install quixstreams[paho-mqtt] to fix it" + ) from exc + +class MQTTSink(BaseSink): + """ + A sink that publishes messages to an MQTT broker. + """ + + def __init__(self, + mqtt_client_id: str, + mqtt_server: str, + mqtt_port: int, + mqtt_topic_root: str, + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1): + """ + Initialize the MQTTSink. + + :param mqtt_client_id: MQTT client identifier. + :param mqtt_server: MQTT broker server address. + :param mqtt_port: MQTT broker server port. + :param mqtt_topic_root: Root topic to publish messages to. + :param mqtt_username: Username for MQTT broker authentication. Defaults to None + :param mqtt_password: Password for MQTT broker authentication. Defaults to None + :param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 + :param tls_enabled: Whether to use TLS encryption. Defaults to True + :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 + """ + + super().__init__() + + self.mqtt_version = mqtt_version + self.mqtt_username = mqtt_username + self.mqtt_password = mqtt_password + self.mqtt_topic_root = mqtt_topic_root + self.tls_enabled = tls_enabled + self.qos = qos + + self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2, + client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version()) + + if self.tls_enabled: + self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now + + self.mqtt_client.reconnect_delay_set(5, 60) + self._configure_authentication() + self.mqtt_client.on_connect = self._mqtt_on_connect_cb + self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb + self.mqtt_client.connect(mqtt_server, int(mqtt_port)) + + # setting callbacks for different events to see if it works, print the message etc. + def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, properties: paho.Properties): + if reason_code == 0: + print("CONNECTED!") # required for Quix to know this has connected + else: + print(f"ERROR ({reason_code.value}). {reason_code.getName()}") + + def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, properties: paho.Properties): + print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!") + + def _mqtt_protocol_version(self): + if self.mqtt_version == "3.1": + return paho.MQTTv31 + elif self.mqtt_version == "3.1.1": + return paho.MQTTv311 + elif self.mqtt_version == "5": + return paho.MQTTv5 + else: + raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}") + + def _configure_authentication(self): + if self.mqtt_username: + self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) + + def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]): + if isinstance(data, bytes): + data = data.decode('utf-8') # Decode bytes to string using utf-8 + + json_data = json.dumps(data) + message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding + # publish to MQTT + self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos) + + + def add(self, + topic: str, + partition: int, + offset: int, + key: bytes, + value: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + **kwargs: Any): + self._publish_to_mqtt(value, key, timestamp, headers) + + def _construct_topic(self, key): + if key: + key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + return f"{self.mqtt_topic_root}/{key_str}" + else: + return self.mqtt_topic_root + + def on_paused(self, topic: str, partition: int): + # not used + pass + + def flush(self, topic: str, partition: str): + # not used + pass + + def cleanup(self): + self.mqtt_client.loop_stop() + self.mqtt_client.disconnect() + + def __del__(self): + self.cleanup() \ No newline at end of file diff --git a/tests/requirements.txt b/tests/requirements.txt index 8032d747b..ba046b08d 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -7,3 +7,4 @@ protobuf>=5.27.2 influxdb3-python>=0.7.0,<1.0 pyiceberg[pyarrow,glue]>=0.7,<0.8 redis[hiredis]>=5.2.0,<6 +paho-mqtt==2.1.0 diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py new file mode 100644 index 000000000..6f9e80d25 --- /dev/null +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -0,0 +1,79 @@ +from unittest.mock import MagicMock, patch +import pytest +from datetime import datetime +from quixstreams.sinks.community.mqtt import MQTTSink + +@pytest.fixture() +def mqtt_sink_factory(): + def factory( + mqtt_client_id: str = "test_client", + mqtt_server: str = "localhost", + mqtt_port: int = 1883, + mqtt_topic_root: str = "test/topic", + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ) -> MQTTSink: + with patch('paho.mqtt.client.Client') as MockClient: + mock_mqtt_client = MockClient.return_value + sink = MQTTSink( + mqtt_client_id=mqtt_client_id, + mqtt_server=mqtt_server, + mqtt_port=mqtt_port, + mqtt_topic_root=mqtt_topic_root, + mqtt_username=mqtt_username, + mqtt_password=mqtt_password, + mqtt_version=mqtt_version, + tls_enabled=tls_enabled, + qos=qos + ) + sink.mqtt_client = mock_mqtt_client + return sink, mock_mqtt_client + + return factory + +class TestMQTTSink: + def test_mqtt_connect(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + mock_mqtt_client.connect.assert_called_once_with("localhost", 1883) + + def test_mqtt_tls_enabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True) + mock_mqtt_client.tls_set.assert_called_once() + + def test_mqtt_tls_disabled(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False) + mock_mqtt_client.tls_set.assert_not_called() + + def test_mqtt_publish(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + data = "test_data" + key = b"test_key" + timestamp = datetime.now() + headers = [] + + sink.add( + topic="test-topic", + partition=0, + offset=1, + key=key, + value=data.encode('utf-8'), + timestamp=timestamp, + headers=headers + ) + + mock_mqtt_client.publish.assert_called_once_with( + "test/topic/test_key", payload='"test_data"', qos=1 + ) + + def test_mqtt_authentication(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass") + mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") + + def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): + sink, mock_mqtt_client = mqtt_sink_factory() + sink.cleanup() # Explicitly call cleanup + mock_mqtt_client.loop_stop.assert_called_once() + mock_mqtt_client.disconnect.assert_called_once() \ No newline at end of file From 5c7bb61417cdb8c419c9d14638c2a59d32493f40 Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:31:34 +0000 Subject: [PATCH 2/5] Add new line --- .../test_sinks/test_community/test_mqtt_sink.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py index 6f9e80d25..4efe1d694 100644 --- a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -76,4 +76,4 @@ def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() sink.cleanup() # Explicitly call cleanup mock_mqtt_client.loop_stop.assert_called_once() - mock_mqtt_client.disconnect.assert_called_once() \ No newline at end of file + mock_mqtt_client.disconnect.assert_called_once() From 52e71f26751f412c64dab4c9acf9d32607ea37de Mon Sep 17 00:00:00 2001 From: SteveRosam Date: Wed, 27 Nov 2024 16:33:36 +0000 Subject: [PATCH 3/5] EoF New Line --- quixstreams/sinks/community/mqtt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index ff624ac2a..b4d305651 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -131,4 +131,4 @@ def cleanup(self): self.mqtt_client.disconnect() def __del__(self): - self.cleanup() \ No newline at end of file + self.cleanup() From 669d38a8eb133ac8a096b73f145a83d4c1c5e65d Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 28 Nov 2024 10:08:32 +0100 Subject: [PATCH 4/5] run linters --- quixstreams/sinks/community/mqtt.py | 124 +++++++++++------- .../test_community/test_mqtt_sink.py | 20 ++- 2 files changed, 93 insertions(+), 51 deletions(-) diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index b4d305651..51ab840a7 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -1,9 +1,9 @@ -from quixstreams.sinks.base.sink import BaseSink -from quixstreams.sinks.base.exceptions import SinkBackpressureError -from typing import List, Tuple, Any -from quixstreams.models.types import HeaderValue -from datetime import datetime import json +from datetime import datetime +from typing import Any, List, Tuple + +from quixstreams.models.types import HeaderValue +from quixstreams.sinks.base.sink import BaseSink try: import paho.mqtt.client as paho @@ -14,21 +14,24 @@ "run pip install quixstreams[paho-mqtt] to fix it" ) from exc + class MQTTSink(BaseSink): """ A sink that publishes messages to an MQTT broker. """ - def __init__(self, - mqtt_client_id: str, - mqtt_server: str, - mqtt_port: int, - mqtt_topic_root: str, - mqtt_username: str = None, - mqtt_password: str = None, - mqtt_version: str = "3.1.1", - tls_enabled: bool = True, - qos: int = 1): + def __init__( + self, + mqtt_client_id: str, + mqtt_server: str, + mqtt_port: int, + mqtt_topic_root: str, + mqtt_username: str = None, + mqtt_password: str = None, + mqtt_version: str = "3.1.1", + tls_enabled: bool = True, + qos: int = 1, + ): """ Initialize the MQTTSink. @@ -42,9 +45,9 @@ def __init__(self, :param tls_enabled: Whether to use TLS encryption. Defaults to True :param qos: Quality of Service level (0, 1, or 2). Defaults to 1 """ - + super().__init__() - + self.mqtt_version = mqtt_version self.mqtt_username = mqtt_username self.mqtt_password = mqtt_password @@ -52,11 +55,17 @@ def __init__(self, self.tls_enabled = tls_enabled self.qos = qos - self.mqtt_client = paho.Client(callback_api_version=paho.CallbackAPIVersion.VERSION2, - client_id = mqtt_client_id, userdata = None, protocol = self._mqtt_protocol_version()) + self.mqtt_client = paho.Client( + callback_api_version=paho.CallbackAPIVersion.VERSION2, + client_id=mqtt_client_id, + userdata=None, + protocol=self._mqtt_protocol_version(), + ) if self.tls_enabled: - self.mqtt_client.tls_set(tls_version = mqtt.client.ssl.PROTOCOL_TLS) # we'll be using tls now + self.mqtt_client.tls_set( + tls_version=mqtt.client.ssl.PROTOCOL_TLS + ) # we'll be using tls now self.mqtt_client.reconnect_delay_set(5, 60) self._configure_authentication() @@ -65,17 +74,31 @@ def __init__(self, self.mqtt_client.connect(mqtt_server, int(mqtt_port)) # setting callbacks for different events to see if it works, print the message etc. - def _mqtt_on_connect_cb(self, client: paho.Client, userdata: any, connect_flags: paho.ConnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): + def _mqtt_on_connect_cb( + self, + client: paho.Client, + userdata: any, + connect_flags: paho.ConnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): if reason_code == 0: - print("CONNECTED!") # required for Quix to know this has connected + print("CONNECTED!") # required for Quix to know this has connected else: print(f"ERROR ({reason_code.value}). {reason_code.getName()}") - def _mqtt_on_disconnect_cb(self, client: paho.Client, userdata: any, disconnect_flags: paho.DisconnectFlags, - reason_code: paho.ReasonCode, properties: paho.Properties): - print(f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!") - + def _mqtt_on_disconnect_cb( + self, + client: paho.Client, + userdata: any, + disconnect_flags: paho.DisconnectFlags, + reason_code: paho.ReasonCode, + properties: paho.Properties, + ): + print( + f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" + ) + def _mqtt_protocol_version(self): if self.mqtt_version == "3.1": return paho.MQTTv31 @@ -90,30 +113,43 @@ def _configure_authentication(self): if self.mqtt_username: self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) - def _publish_to_mqtt(self, data: str, key: bytes, timestamp: datetime, headers: List[Tuple[str, HeaderValue]]): + def _publish_to_mqtt( + self, + data: str, + key: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + ): if isinstance(data, bytes): - data = data.decode('utf-8') # Decode bytes to string using utf-8 + data = data.decode("utf-8") # Decode bytes to string using utf-8 json_data = json.dumps(data) - message_key_string = key.decode('utf-8') # Convert to string using utf-8 encoding + message_key_string = key.decode( + "utf-8" + ) # Convert to string using utf-8 encoding # publish to MQTT - self.mqtt_client.publish(self.mqtt_topic_root + "/" + message_key_string, payload = json_data, qos = self.qos) - - - def add(self, - topic: str, - partition: int, - offset: int, - key: bytes, - value: bytes, - timestamp: datetime, - headers: List[Tuple[str, HeaderValue]], - **kwargs: Any): + self.mqtt_client.publish( + self.mqtt_topic_root + "/" + message_key_string, + payload=json_data, + qos=self.qos, + ) + + def add( + self, + topic: str, + partition: int, + offset: int, + key: bytes, + value: bytes, + timestamp: datetime, + headers: List[Tuple[str, HeaderValue]], + **kwargs: Any, + ): self._publish_to_mqtt(value, key, timestamp, headers) def _construct_topic(self, key): if key: - key_str = key.decode('utf-8') if isinstance(key, bytes) else str(key) + key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) return f"{self.mqtt_topic_root}/{key_str}" else: return self.mqtt_topic_root @@ -121,7 +157,7 @@ def _construct_topic(self, key): def on_paused(self, topic: str, partition: int): # not used pass - + def flush(self, topic: str, partition: str): # not used pass diff --git a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py index 4efe1d694..05b6b332b 100644 --- a/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py +++ b/tests/test_quixstreams/test_sinks/test_community/test_mqtt_sink.py @@ -1,8 +1,11 @@ -from unittest.mock import MagicMock, patch -import pytest from datetime import datetime +from unittest.mock import patch + +import pytest + from quixstreams.sinks.community.mqtt import MQTTSink + @pytest.fixture() def mqtt_sink_factory(): def factory( @@ -16,7 +19,7 @@ def factory( tls_enabled: bool = True, qos: int = 1, ) -> MQTTSink: - with patch('paho.mqtt.client.Client') as MockClient: + with patch("paho.mqtt.client.Client") as MockClient: mock_mqtt_client = MockClient.return_value sink = MQTTSink( mqtt_client_id=mqtt_client_id, @@ -27,13 +30,14 @@ def factory( mqtt_password=mqtt_password, mqtt_version=mqtt_version, tls_enabled=tls_enabled, - qos=qos + qos=qos, ) sink.mqtt_client = mock_mqtt_client return sink, mock_mqtt_client return factory + class TestMQTTSink: def test_mqtt_connect(self, mqtt_sink_factory): sink, mock_mqtt_client = mqtt_sink_factory() @@ -59,9 +63,9 @@ def test_mqtt_publish(self, mqtt_sink_factory): partition=0, offset=1, key=key, - value=data.encode('utf-8'), + value=data.encode("utf-8"), timestamp=timestamp, - headers=headers + headers=headers, ) mock_mqtt_client.publish.assert_called_once_with( @@ -69,7 +73,9 @@ def test_mqtt_publish(self, mqtt_sink_factory): ) def test_mqtt_authentication(self, mqtt_sink_factory): - sink, mock_mqtt_client = mqtt_sink_factory(mqtt_username="user", mqtt_password="pass") + sink, mock_mqtt_client = mqtt_sink_factory( + mqtt_username="user", mqtt_password="pass" + ) mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): From 02ead3d44140701d14cde3be026bf78c959bd3e1 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 28 Nov 2024 10:27:01 +0100 Subject: [PATCH 5/5] requirements --- conda/post-link.sh | 3 ++- pyproject.toml | 3 ++- quixstreams/sinks/community/mqtt.py | 3 +-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/conda/post-link.sh b/conda/post-link.sh index aaf0608f9..24a9a62bb 100644 --- a/conda/post-link.sh +++ b/conda/post-link.sh @@ -5,4 +5,5 @@ $PREFIX/bin/pip install \ 'protobuf>=5.27.2,<6.0' \ 'influxdb3-python>=0.7,<1.0' \ 'pyiceberg[pyarrow,glue]>=0.7,<0.8' \ -'redis[hiredis]>=5.2.0,<6' +'redis[hiredis]>=5.2.0,<6' \ +'paho-mqtt>=2.1.0,<3' diff --git a/pyproject.toml b/pyproject.toml index 1729723f9..cbcdcb782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ all = [ "boto3>=1.35.65,<2.0", "boto3-stubs>=1.35.65,<2.0", "redis[hiredis]>=5.2.0,<6", - "paho-mqtt==2.1.0" + "paho-mqtt>=2.1.0,<3" ] avro = ["fastavro>=1.8,<2.0"] @@ -51,6 +51,7 @@ pubsub = ["google-cloud-pubsub>=2.23.1,<3"] postgresql = ["psycopg2-binary>=2.9.9,<3"] kinesis = ["boto3>=1.35.65,<2.0", "boto3-stubs[kinesis]>=1.35.65,<2.0"] redis=["redis[hiredis]>=5.2.0,<6"] +mqtt=["paho-mqtt>=2.1.0,<3"] [tool.setuptools.packages.find] include = ["quixstreams*"] diff --git a/quixstreams/sinks/community/mqtt.py b/quixstreams/sinks/community/mqtt.py index 51ab840a7..51a2284a0 100644 --- a/quixstreams/sinks/community/mqtt.py +++ b/quixstreams/sinks/community/mqtt.py @@ -10,8 +10,7 @@ from paho import mqtt except ImportError as exc: raise ImportError( - 'Package "paho-mqtt" is missing: ' - "run pip install quixstreams[paho-mqtt] to fix it" + 'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it" ) from exc