-
Notifications
You must be signed in to change notification settings - Fork 76
Add MQTT Sink #659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add MQTT Sink #659
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,169 @@ | ||||||||||||||
import json | ||||||||||||||
from datetime import datetime | ||||||||||||||
from typing import Any, List, Tuple | ||||||||||||||
|
||||||||||||||
from quixstreams.models.types import HeaderValue | ||||||||||||||
from quixstreams.sinks.base.sink import BaseSink | ||||||||||||||
|
||||||||||||||
try: | ||||||||||||||
import paho.mqtt.client as paho | ||||||||||||||
from paho import mqtt | ||||||||||||||
except ImportError as exc: | ||||||||||||||
raise ImportError( | ||||||||||||||
'Package "paho-mqtt" is missing: ' "run pip install quixstreams[mqtt] to fix it" | ||||||||||||||
) from exc | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class MQTTSink(BaseSink): | ||||||||||||||
""" | ||||||||||||||
A sink that publishes messages to an MQTT broker. | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
def __init__( | ||||||||||||||
self, | ||||||||||||||
mqtt_client_id: str, | ||||||||||||||
mqtt_server: str, | ||||||||||||||
mqtt_port: int, | ||||||||||||||
mqtt_topic_root: str, | ||||||||||||||
mqtt_username: str = None, | ||||||||||||||
mqtt_password: str = None, | ||||||||||||||
mqtt_version: str = "3.1.1", | ||||||||||||||
tls_enabled: bool = True, | ||||||||||||||
qos: int = 1, | ||||||||||||||
): | ||||||||||||||
""" | ||||||||||||||
Initialize the MQTTSink. | ||||||||||||||
|
||||||||||||||
:param mqtt_client_id: MQTT client identifier. | ||||||||||||||
:param mqtt_server: MQTT broker server address. | ||||||||||||||
:param mqtt_port: MQTT broker server port. | ||||||||||||||
:param mqtt_topic_root: Root topic to publish messages to. | ||||||||||||||
:param mqtt_username: Username for MQTT broker authentication. Defaults to None | ||||||||||||||
:param mqtt_password: Password for MQTT broker authentication. Defaults to None | ||||||||||||||
:param mqtt_version: MQTT protocol version ("3.1", "3.1.1", or "5"). Defaults to 3.1.1 | ||||||||||||||
:param tls_enabled: Whether to use TLS encryption. Defaults to True | ||||||||||||||
:param qos: Quality of Service level (0, 1, or 2). Defaults to 1 | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
super().__init__() | ||||||||||||||
|
||||||||||||||
self.mqtt_version = mqtt_version | ||||||||||||||
self.mqtt_username = mqtt_username | ||||||||||||||
self.mqtt_password = mqtt_password | ||||||||||||||
self.mqtt_topic_root = mqtt_topic_root | ||||||||||||||
self.tls_enabled = tls_enabled | ||||||||||||||
self.qos = qos | ||||||||||||||
|
||||||||||||||
self.mqtt_client = paho.Client( | ||||||||||||||
callback_api_version=paho.CallbackAPIVersion.VERSION2, | ||||||||||||||
client_id=mqtt_client_id, | ||||||||||||||
userdata=None, | ||||||||||||||
protocol=self._mqtt_protocol_version(), | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
if self.tls_enabled: | ||||||||||||||
self.mqtt_client.tls_set( | ||||||||||||||
tls_version=mqtt.client.ssl.PROTOCOL_TLS | ||||||||||||||
) # we'll be using tls now | ||||||||||||||
|
||||||||||||||
self.mqtt_client.reconnect_delay_set(5, 60) | ||||||||||||||
self._configure_authentication() | ||||||||||||||
self.mqtt_client.on_connect = self._mqtt_on_connect_cb | ||||||||||||||
self.mqtt_client.on_disconnect = self._mqtt_on_disconnect_cb | ||||||||||||||
self.mqtt_client.connect(mqtt_server, int(mqtt_port)) | ||||||||||||||
|
||||||||||||||
# setting callbacks for different events to see if it works, print the message etc. | ||||||||||||||
def _mqtt_on_connect_cb( | ||||||||||||||
self, | ||||||||||||||
client: paho.Client, | ||||||||||||||
userdata: any, | ||||||||||||||
connect_flags: paho.ConnectFlags, | ||||||||||||||
reason_code: paho.ReasonCode, | ||||||||||||||
properties: paho.Properties, | ||||||||||||||
): | ||||||||||||||
if reason_code == 0: | ||||||||||||||
print("CONNECTED!") # required for Quix to know this has connected | ||||||||||||||
else: | ||||||||||||||
print(f"ERROR ({reason_code.value}). {reason_code.getName()}") | ||||||||||||||
|
||||||||||||||
def _mqtt_on_disconnect_cb( | ||||||||||||||
self, | ||||||||||||||
client: paho.Client, | ||||||||||||||
userdata: any, | ||||||||||||||
disconnect_flags: paho.DisconnectFlags, | ||||||||||||||
reason_code: paho.ReasonCode, | ||||||||||||||
properties: paho.Properties, | ||||||||||||||
): | ||||||||||||||
print( | ||||||||||||||
f"DISCONNECTED! Reason code ({reason_code.value}) {reason_code.getName()}!" | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def _mqtt_protocol_version(self): | ||||||||||||||
if self.mqtt_version == "3.1": | ||||||||||||||
return paho.MQTTv31 | ||||||||||||||
elif self.mqtt_version == "3.1.1": | ||||||||||||||
return paho.MQTTv311 | ||||||||||||||
elif self.mqtt_version == "5": | ||||||||||||||
return paho.MQTTv5 | ||||||||||||||
else: | ||||||||||||||
raise ValueError(f"Unsupported MQTT version: {self.mqtt_version}") | ||||||||||||||
|
||||||||||||||
def _configure_authentication(self): | ||||||||||||||
if self.mqtt_username: | ||||||||||||||
self.mqtt_client.username_pw_set(self.mqtt_username, self.mqtt_password) | ||||||||||||||
|
||||||||||||||
def _publish_to_mqtt( | ||||||||||||||
self, | ||||||||||||||
data: str, | ||||||||||||||
key: bytes, | ||||||||||||||
timestamp: datetime, | ||||||||||||||
headers: List[Tuple[str, HeaderValue]], | ||||||||||||||
): | ||||||||||||||
if isinstance(data, bytes): | ||||||||||||||
data = data.decode("utf-8") # Decode bytes to string using utf-8 | ||||||||||||||
|
||||||||||||||
json_data = json.dumps(data) | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some sinks has concept of value_serializer: Callable[[Any], str] = json.dumps,
key_serializer: Callable[[Any], str] = bytes.decode, in its constructors. That is very handy, specially when you are dealing with non json data. I don't see a use case for |
||||||||||||||
message_key_string = key.decode( | ||||||||||||||
"utf-8" | ||||||||||||||
) # Convert to string using utf-8 encoding | ||||||||||||||
# publish to MQTT | ||||||||||||||
self.mqtt_client.publish( | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It would be good to add sync primitive here. Client.publish returns MQTTMessageInfo which has
Suggested change
see quix-streams sink flush docs |
||||||||||||||
self.mqtt_topic_root + "/" + message_key_string, | ||||||||||||||
payload=json_data, | ||||||||||||||
qos=self.qos, | ||||||||||||||
Comment on lines
+131
to
+133
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. According to docs, it would be great to have a control over retain: bool | Callable[[Any], bool] = False,
properties: paho.mqtt.properties.Properties | Callable[[Any], paho.mqtt.properties.Properties] | None = None, |
||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def add( | ||||||||||||||
self, | ||||||||||||||
topic: str, | ||||||||||||||
partition: int, | ||||||||||||||
offset: int, | ||||||||||||||
key: bytes, | ||||||||||||||
value: bytes, | ||||||||||||||
timestamp: datetime, | ||||||||||||||
headers: List[Tuple[str, HeaderValue]], | ||||||||||||||
**kwargs: Any, | ||||||||||||||
): | ||||||||||||||
self._publish_to_mqtt(value, key, timestamp, headers) | ||||||||||||||
|
||||||||||||||
def _construct_topic(self, key): | ||||||||||||||
if key: | ||||||||||||||
key_str = key.decode("utf-8") if isinstance(key, bytes) else str(key) | ||||||||||||||
return f"{self.mqtt_topic_root}/{key_str}" | ||||||||||||||
else: | ||||||||||||||
return self.mqtt_topic_root | ||||||||||||||
|
||||||||||||||
def on_paused(self, topic: str, partition: int): | ||||||||||||||
# not used | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
def flush(self, topic: str, partition: str): | ||||||||||||||
# not used | ||||||||||||||
pass | ||||||||||||||
|
||||||||||||||
def cleanup(self): | ||||||||||||||
self.mqtt_client.loop_stop() | ||||||||||||||
self.mqtt_client.disconnect() | ||||||||||||||
|
||||||||||||||
def __del__(self): | ||||||||||||||
self.cleanup() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from datetime import datetime | ||
from unittest.mock import patch | ||
|
||
import pytest | ||
|
||
from quixstreams.sinks.community.mqtt import MQTTSink | ||
|
||
|
||
@pytest.fixture() | ||
def mqtt_sink_factory(): | ||
def factory( | ||
mqtt_client_id: str = "test_client", | ||
mqtt_server: str = "localhost", | ||
mqtt_port: int = 1883, | ||
mqtt_topic_root: str = "test/topic", | ||
mqtt_username: str = None, | ||
mqtt_password: str = None, | ||
mqtt_version: str = "3.1.1", | ||
tls_enabled: bool = True, | ||
qos: int = 1, | ||
) -> MQTTSink: | ||
with patch("paho.mqtt.client.Client") as MockClient: | ||
mock_mqtt_client = MockClient.return_value | ||
sink = MQTTSink( | ||
mqtt_client_id=mqtt_client_id, | ||
mqtt_server=mqtt_server, | ||
mqtt_port=mqtt_port, | ||
mqtt_topic_root=mqtt_topic_root, | ||
mqtt_username=mqtt_username, | ||
mqtt_password=mqtt_password, | ||
mqtt_version=mqtt_version, | ||
tls_enabled=tls_enabled, | ||
qos=qos, | ||
) | ||
sink.mqtt_client = mock_mqtt_client | ||
return sink, mock_mqtt_client | ||
|
||
return factory | ||
|
||
|
||
class TestMQTTSink: | ||
def test_mqtt_connect(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory() | ||
mock_mqtt_client.connect.assert_called_once_with("localhost", 1883) | ||
|
||
def test_mqtt_tls_enabled(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=True) | ||
mock_mqtt_client.tls_set.assert_called_once() | ||
|
||
def test_mqtt_tls_disabled(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory(tls_enabled=False) | ||
mock_mqtt_client.tls_set.assert_not_called() | ||
|
||
def test_mqtt_publish(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory() | ||
data = "test_data" | ||
key = b"test_key" | ||
timestamp = datetime.now() | ||
headers = [] | ||
|
||
sink.add( | ||
topic="test-topic", | ||
partition=0, | ||
offset=1, | ||
key=key, | ||
value=data.encode("utf-8"), | ||
timestamp=timestamp, | ||
headers=headers, | ||
) | ||
|
||
mock_mqtt_client.publish.assert_called_once_with( | ||
"test/topic/test_key", payload='"test_data"', qos=1 | ||
) | ||
|
||
def test_mqtt_authentication(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory( | ||
mqtt_username="user", mqtt_password="pass" | ||
) | ||
mock_mqtt_client.username_pw_set.assert_called_once_with("user", "pass") | ||
|
||
def test_mqtt_disconnect_on_delete(self, mqtt_sink_factory): | ||
sink, mock_mqtt_client = mqtt_sink_factory() | ||
sink.cleanup() # Explicitly call cleanup | ||
mock_mqtt_client.loop_stop.assert_called_once() | ||
mock_mqtt_client.disconnect.assert_called_once() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.