Skip to content

Commit

Permalink
Feat: Faster per-record processing (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers authored Jul 18, 2024
1 parent e784ed9 commit ce20b56
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 136 deletions.
17 changes: 11 additions & 6 deletions airbyte/_future_cdk/record_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from airbyte import exceptions as exc
from airbyte._future_cdk.state_writers import StdOutStateWriter
from airbyte.records import StreamRecordHandler
from airbyte.strategies import WriteStrategy


Expand Down Expand Up @@ -156,7 +157,7 @@ def process_input_stream(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
stream_record_handler: StreamRecordHandler,
) -> None:
"""Write a record.
Expand All @@ -180,22 +181,26 @@ def process_airbyte_messages(
context={"write_strategy": write_strategy},
)

stream_schemas: dict[str, dict] = {}
stream_record_handlers: dict[str, StreamRecordHandler] = {}

# Process messages, writing to batches as we go
for message in messages:
if message.type is Type.RECORD:
record_msg = cast(AirbyteRecordMessage, message.record)
stream_name = record_msg.stream

if stream_name not in stream_schemas:
stream_schemas[stream_name] = self.catalog_provider.get_stream_json_schema(
stream_name=stream_name
if stream_name not in stream_record_handlers:
stream_record_handlers[stream_name] = StreamRecordHandler(
json_schema=self.catalog_provider.get_stream_json_schema(
stream_name=stream_name,
),
normalize_keys=True,
prune_extra_fields=False,
)

self.process_record_message(
record_msg,
stream_schema=stream_schemas[stream_name],
stream_record_handler=stream_record_handlers[stream_name],
)

elif message.type is Type.STATE:
Expand Down
5 changes: 3 additions & 2 deletions airbyte/_future_cdk/sql_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from airbyte._future_cdk.catalog_providers import CatalogProvider
from airbyte._future_cdk.state_writers import StateWriterBase
from airbyte._processors.file.base import FileWriterBase
from airbyte.records import StreamRecordHandler
from airbyte.secrets.base import SecretString


Expand Down Expand Up @@ -227,7 +228,7 @@ def get_sql_table(
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
stream_record_handler: StreamRecordHandler,
) -> None:
"""Write a record to the cache.
Expand All @@ -238,7 +239,7 @@ def process_record_message(
"""
self.file_writer.process_record_message(
record_msg,
stream_schema=stream_schema,
stream_record_handler=stream_record_handler,
)

# Protected members (non-public interface):
Expand Down
9 changes: 3 additions & 6 deletions airbyte/_processors/file/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@

from airbyte import exceptions as exc
from airbyte._batch_handles import BatchHandle
from airbyte._util.name_normalizers import LowerCaseNormalizer
from airbyte.progress import progress
from airbyte.records import StreamRecord
from airbyte.records import StreamRecord, StreamRecordHandler


if TYPE_CHECKING:
Expand Down Expand Up @@ -142,7 +141,7 @@ def cleanup_all(self) -> None:
def process_record_message(
self,
record_msg: AirbyteRecordMessage,
stream_schema: dict,
stream_record_handler: StreamRecordHandler,
) -> None:
"""Write a record to the cache.
Expand All @@ -167,9 +166,7 @@ def process_record_message(
self._write_record_dict(
record_dict=StreamRecord.from_record_message(
record_message=record_msg,
expected_keys=stream_schema["properties"].keys(),
normalizer=LowerCaseNormalizer,
prune_extra_fields=self.prune_extra_fields,
stream_record_handler=stream_record_handler,
),
open_file_writer=batch_handle.open_file_writer,
)
Expand Down
Loading

0 comments on commit ce20b56

Please sign in to comment.