Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions osprey_worker/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"confluent-kafka>=2.14.0",
"flask-cors>=6.0.1",
"osprey_rpc",
]
Expand Down
12 changes: 9 additions & 3 deletions osprey_worker/src/osprey/worker/_stdlibplugin/sink_register.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List, Sequence

from kafka import KafkaProducer
from osprey.worker._stdlibplugin.execution_result_store_chooser import get_rules_execution_result_storage_backend
from osprey.worker.adaptor.plugin_manager import hookimpl_osprey
from osprey.worker.lib.config import Config
Expand All @@ -19,10 +18,17 @@ def register_output_sinks(config: Config) -> Sequence[BaseOutputSink]:
output_topic = config.expect_str('OSPREY_KAFKA_OUTPUT_TOPIC')
bootstrap_servers = config.expect_str_list('OSPREY_KAFKA_BOOTSTRAP_SERVERS')
client_id = config.expect_str('OSPREY_KAFKA_OUTPUT_CLIENT_ID')
auto_create_topic = config.get_bool('OSPREY_KAFKA_AUTO_CREATE_TOPIC', True)
num_partitions = config.get_int('OSPREY_KAFKA_NUM_PARTITIONS', 1)
replication_factor = config.get_int('OSPREY_KAFKA_REPLICATION_FACTOR', 1)
sinks.append(
KafkaOutputSink(
kafka_topic=output_topic,
kafka_producer=KafkaProducer(bootstrap_servers=bootstrap_servers, client_id=client_id),
bootstrap_servers=bootstrap_servers,
output_topic=output_topic,
client_id=client_id,
auto_create_topic=auto_create_topic,
num_partitions=num_partitions,
replication_factor=replication_factor,
)
)

Expand Down
49 changes: 41 additions & 8 deletions osprey_worker/src/osprey/worker/cli/sinks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import click
import gevent
import kafka
import sentry_sdk
from confluent_kafka import Consumer
from google.api_core.exceptions import AlreadyExists
from google.cloud import pubsub_v1
from osprey.worker.lib import instruments
Expand All @@ -51,7 +51,6 @@
from osprey.worker.sinks.sink.input_stream import PostgresInputStream
from osprey.worker.sinks.sink.osprey_coordinator_input_stream import OspreyCoordinatorInputStream
from osprey.worker.sinks.sink.rules_sink import RulesSink
from osprey.worker.sinks.utils.kafka import PatchedKafkaConsumer

LOGGER = get_logger()

Expand Down Expand Up @@ -81,9 +80,25 @@ def tail_kafka_output_sink() -> None:
output_topic = config.get_str('OSPREY_KAFKA_OUTPUT_SINK_TOPIC', 'osprey.execution_results')
bootstrap_servers = config.get_str_list('OSPREY_KAFKA_BOOTSTRAP_SERVERS', ['localhost'])

kakfa_consumer = kafka.KafkaConsumer(output_topic, bootstrap_servers=bootstrap_servers)
for event in kakfa_consumer:
print(event)
consumer = Consumer(
{
'bootstrap.servers': ','.join(bootstrap_servers),
'group.id': 'osprey-tail-output',
'auto.offset.reset': 'latest',
}
)
consumer.subscribe([output_topic])
try:
while True:
msg = consumer.poll(timeout=1.0)
if msg is None:
continue
if msg.error():
print(f'Error: {msg.error()}')
continue
print(msg.value())
finally:
consumer.close()


@cli.command()
Expand All @@ -101,9 +116,27 @@ def tail_kafka_input_sink() -> None:
client_id = config.get_str('OSPREY_KAFKA_INPUT_STREAM_CLIENT_ID', 'localhost')
bootstrap_servers = config.get_str_list('OSPREY_KAFKA_BOOTSTRAP_SERVERS', ['localhost'])
input_topic = config.get_str('OSPREY_KAFKA_INPUT_STREAM_TOPIC', 'osprey.actions_input')
kafka_consumer = PatchedKafkaConsumer(input_topic, bootstrap_servers=bootstrap_servers, client_id=client_id)
for event in kafka_consumer:
print(event)

consumer = Consumer(
{
'bootstrap.servers': ','.join(bootstrap_servers),
'client.id': client_id,
'group.id': 'osprey-tail-input',
'auto.offset.reset': 'latest',
}
)
consumer.subscribe([input_topic])
try:
while True:
msg = consumer.poll(timeout=1.0)
if msg is None:
continue
if msg.error():
print(f'Error: {msg.error()}')
continue
print(msg.value())
finally:
consumer.close()


@cli.command()
Expand Down
25 changes: 15 additions & 10 deletions osprey_worker/src/osprey/worker/sinks/input_stream_chooser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import random
from datetime import datetime, timedelta

from confluent_kafka import Consumer
from google.cloud import pubsub_v1
from kafka.coordinator.assignors.roundrobin import RoundRobinPartitionAssignor
from osprey.engine.executor.execution_context import Action
from osprey.worker.adaptor.plugin_manager import bootstrap_input_stream
from osprey.worker.lib.singletons import CONFIG
Expand All @@ -15,7 +15,7 @@
)
from osprey.worker.sinks.sink.osprey_coordinator_input_stream import OspreyCoordinatorInputStream
from osprey.worker.sinks.utils.acking_contexts import BaseAckingContext, NoopAckingContext
from osprey.worker.sinks.utils.kafka import PatchedKafkaConsumer
from osprey.worker.sinks.utils.kafka import ThreadedKafkaConsumer


def get_rules_sink_input_stream(
Expand Down Expand Up @@ -113,17 +113,22 @@ def get_rules_sink_input_stream(
if client_id_suffix:
client_id = f'{client_id}-{client_id_suffix}'

consumer: PatchedKafkaConsumer = PatchedKafkaConsumer(
input_topic,
bootstrap_servers=input_bootstrap_servers,
client_id=client_id,
group_id=group_id,
partition_assignment_strategy=(RoundRobinPartitionAssignor,),
)
consumer_config = {
'bootstrap.servers': ','.join(input_bootstrap_servers),
'client.id': client_id,
'group.id': group_id or 'osprey-consumer',
'partition.assignment.strategy': 'roundrobin',
'auto.offset.reset': 'latest',
}

consumer = Consumer(consumer_config)
consumer.subscribe([input_topic])
threaded_consumer = ThreadedKafkaConsumer(consumer)

from osprey.worker.sinks.sink.input_stream import KafkaInputStream

return KafkaInputStream(
kafka_consumer=consumer,
kafka_consumer=threaded_consumer,
)
elif input_stream_source == InputStreamSource.PLUGIN:
stream = bootstrap_input_stream(config=config)
Expand Down
32 changes: 20 additions & 12 deletions osprey_worker/src/osprey/worker/sinks/sink/input_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import gevent
import msgpack
import sentry_sdk
from confluent_kafka import Message as KafkaMessage
from gevent.lock import RLock
from gevent.queue import Queue as GeventQueue
from google.api_core import retry
Expand All @@ -16,7 +17,6 @@
from google.protobuf.message import DecodeError
from google.protobuf.message import Message as ProtoMessage
from google.pubsub_v1 import PubsubMessage
from kafka.consumer.fetcher import ConsumerRecord
from osprey.engine.executor.execution_context import Action
from osprey.worker.lib.encryption.envelope import Envelope
from osprey.worker.lib.instruments import metrics
Expand All @@ -29,7 +29,6 @@
PubSubMessageAckingContext,
PullPubSubMessageContext,
)
from osprey.worker.sinks.utils.kafka import PatchedKafkaConsumer
from pydantic import BaseModel
from tenacity import RetryCallState, retry_if_exception_type, stop_never, wait_exponential
from tenacity import retry as tenacity_retry
Expand All @@ -43,6 +42,7 @@

if TYPE_CHECKING:
from google.cloud.pubsub_v1.types import PullResponse
from osprey.worker.sinks.utils.kafka import ThreadedKafkaConsumer


class BaseInputStream(abc.ABC, Generic[_T]):
Expand Down Expand Up @@ -413,30 +413,38 @@ def claim_with_retry() -> Optional[_ModelT]:
class KafkaInputStream(BaseInputStream[BaseAckingContext[Action]]):
"""An input stream that consumes messages from a Kafka topic and yields Action objects wrapped in an AckingContext."""

def __init__(self, kafka_consumer: PatchedKafkaConsumer):
def __init__(self, kafka_consumer: 'ThreadedKafkaConsumer'):
super().__init__()
self._consumer: PatchedKafkaConsumer = kafka_consumer
self._consumer = kafka_consumer

def _gen(self) -> Iterator[BaseAckingContext[Action]]:
while True:
try:
with metrics.timed('kafka_consumer.lock_time'):
with metrics.timed('kafka_consumer.poll_time'):
record: ConsumerRecord = next(self._consumer)
data = json.loads(record.value)
with metrics.timed('kafka_consumer.poll_time'):
msg: Optional[KafkaMessage] = self._consumer.poll(timeout=1.0)

if msg is None:
continue

if msg.error():
logger.error(f'Kafka consumer error: {msg.error()}')
sentry_sdk.capture_exception(Exception(str(msg.error())))
continue

value = msg.value()
if value is None:
continue

data = json.loads(value)
timestamp = parse_go_timestamp(data['send_time'])
action_data = data['data']
# this was here for when this was protobuf. If its json by default, we should just assume its all in one
# json blob.
# action_data = json.loads(action_data_json)

action = Action(
action_id=int(action_data['action_id']),
action_name=action_data['action_name'],
data=action_data['data'],
timestamp=timestamp,
)
# Wrap in NoopAckingContext for now, or implement a KafkaAckingContext if needed
yield NoopAckingContext(action)
except Exception as e:
logger.exception('Error while consuming from Kafka')
Expand Down
117 changes: 109 additions & 8 deletions osprey_worker/src/osprey/worker/sinks/sink/kafka_output_sink.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,135 @@
import platform
from typing import Any

import sentry_sdk
from kafka import KafkaProducer
from confluent_kafka import Producer
from confluent_kafka.admin import AdminClient, NewTopic
from osprey.engine.executor.execution_context import ExecutionResult
from osprey.worker.lib.osprey_shared.logging import get_logger
from osprey.worker.sinks.sink.output_sink import BaseOutputSink
from osprey.worker.sinks.utils.kafka import ThreadedKafkaProducer

logger = get_logger()


class EmptyBootstrapServersException(Exception):
"""Exception that is raised whenever the server list provided to KafkaOutputSink is empty."""


class InvalidOutputTopicException(Exception):
"""Exception that is raised whenever the output topic passed to KafkaOutputSink is empty."""


class KafkaOutputSink(BaseOutputSink):
"""An output sink that sends the extracted features to a given kafka topic."""

def __init__(self, kafka_topic: str, kafka_producer: KafkaProducer):
self._kafka_topic = kafka_topic
self._kafka_producer = kafka_producer
def __init__(
self,
bootstrap_servers: list[str],
output_topic: str,
client_id: str | None,
auto_create_topic: bool = True,
num_partitions: int = 1,
replication_factor: int = 1,
) -> None:
if len(bootstrap_servers) == 0:
raise EmptyBootstrapServersException()

if output_topic == '':
raise InvalidOutputTopicException()

self.logger = get_logger('KafkaOutputSink')

self._bootstrap_servers = bootstrap_servers
self._output_topic = output_topic
self._num_partitions = num_partitions
self._replication_factor = replication_factor

# NOTE(haileyok): this is...not necessary probably
self.topic_ensured = False

if client_id is None:
client_hostname = platform.node()
if client_hostname != '':
client_id = f'{client_hostname};host_override={bootstrap_servers[0]}'
else:
client_id = f'osprey-output-sink;host_override={bootstrap_servers[0]}'

self.logger.info(f'Creating Kafka producer with client id {client_id}')

config = {
'bootstrap.servers': ','.join(bootstrap_servers),
'client.id': client_id,
'queue.buffering.max.messages': 1_000_000,
'linger.ms': 10,
'retries': 10,
'request.timeout.ms': 30_000,
'socket.timeout.ms': 30_000,
'delivery.timeout.ms': 120_000,
'statistics.interval.ms': 10_000,
'log.connection.close': False,
'enable.idempotence': True,
'acks': 'all',
'max.in.flight.requests.per.connection': 5,
'message.max.bytes': 20_000_000,
}

self._producer = Producer(config)
self._threaded_producer = ThreadedKafkaProducer(self._producer)

if auto_create_topic:
self.ensure_topic()

super().__init__()

def ensure_topic(self) -> None:
"""Create the Kafka topic if it does not yet exist."""
admin_client = AdminClient({'bootstrap.servers': ','.join(self._bootstrap_servers)})

try:
metadata = admin_client.list_topics(timeout=10)
except Exception as e:
self.logger.error(f'Error listing topics, unable to ensure topic: {e}')
return

if self._output_topic in metadata.topics:
self.topic_ensured = True
return

self.logger.info(f'Creating topic {self._output_topic}')
try:
topic = NewTopic(
self._output_topic, num_partitions=self._num_partitions, replication_factor=self._replication_factor
)
fs = admin_client.create_topics([topic])
fs[self._output_topic].result()
self.topic_ensured = True
except Exception as e:
self.logger.error(f'Error creating topic, unable to ensure topic: {e}')

def will_do_work(self, result: ExecutionResult) -> bool:
return True

def push(self, result: ExecutionResult) -> None:
kafka_future: Any = self._kafka_producer.send(
topic=self._kafka_topic, value=result.extracted_features_json.encode('utf-8')
self._threaded_producer.produce(
self._output_topic,
value=result.extracted_features_json.encode('utf-8'),
on_delivery=self._on_delivery,
)
kafka_future.add_errback(self.push_err_to_sentry)

def _on_delivery(self, err: Any, msg: Any) -> None:
if err is not None:
self.push_err_to_sentry(err)

def flush(self, timeout: float = 30) -> int:
return self._threaded_producer.flush(timeout)

@classmethod
def push_err_to_sentry(cls, e: Exception) -> None:
logger.error(f'exception raised when pushing event to kafka: {str(e)}')
sentry_sdk.capture_exception(error=e)

def stop(self) -> None:
pass
remaining = self._threaded_producer.close(timeout=10)
if remaining > 0:
logger.warning(f'{remaining} messages were not delivered')
Loading
Loading