Skip to content
Draft
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions airbyte/_executors/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from __future__ import annotations

import contextlib
import hashlib
import warnings
from pathlib import Path
Expand All @@ -15,7 +16,10 @@
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.types import StreamSlice

from airbyte import exceptions as exc
from airbyte._executors.base import Executor


Expand All @@ -38,6 +42,53 @@ def _suppress_cdk_pydantic_deprecation_warnings() -> None:
)


def _unwrap_to_declarative_stream(stream: object) -> object:
"""Unwrap a concurrent stream wrapper to access the underlying declarative stream.

This function uses duck-typing to navigate through various wrapper layers that may
exist around declarative streams, depending on the CDK version. It tries common
wrapper attribute names and returns the first object that has a 'retriever' attribute.

Args:
stream: A stream object that may be wrapped (e.g., AbstractStream wrapper).

Returns:
The underlying declarative stream object with a retriever attribute.

Raises:
NotImplementedError: If unable to locate a declarative stream with a retriever.
"""
if hasattr(stream, "retriever"):
return stream

wrapper_attrs = [
"declarative_stream",
"wrapped_stream",
"stream",
"_stream",
"underlying_stream",
"inner",
]

for attr_name in wrapper_attrs:
if hasattr(stream, attr_name):
unwrapped = getattr(stream, attr_name)
if unwrapped is not None and hasattr(unwrapped, "retriever"):
return unwrapped

for branch_attr in ["full_refresh_stream", "incremental_stream"]:
if hasattr(stream, branch_attr):
branch_stream = getattr(stream, branch_attr)
if branch_stream is not None and hasattr(branch_stream, "retriever"):
return branch_stream

stream_type = type(stream).__name__
raise NotImplementedError(
f"Unable to locate declarative stream with retriever from {stream_type}. "
f"fetch_record() requires access to the stream's retriever component."
)


class DeclarativeExecutor(Executor):
"""An executor for declarative sources."""

Expand Down Expand Up @@ -140,3 +191,153 @@ def install(self) -> None:
def uninstall(self) -> None:
"""No-op. The declarative source is included with PyAirbyte."""
pass

def fetch_record( # noqa: PLR0914, PLR0912, PLR0915
self,
stream_name: str,
primary_key_value: str,
) -> dict[str, Any]:
"""Fetch a single record by primary key from a declarative stream.

This method uses the already-instantiated streams from the declarative source
to access the stream's retriever and make an HTTP GET request by appending
the primary key value to the stream's base path (e.g., /users/123).

Args:
stream_name: The name of the stream to fetch from.
primary_key_value: The primary key value as a string.

Returns:
The fetched record as a dictionary.

Raises:
exc.AirbyteStreamNotFoundError: If the stream is not found.
exc.AirbyteRecordNotFoundError: If the record is not found (empty response).
NotImplementedError: If the stream does not use SimpleRetriever.
"""
streams = self.declarative_source.streams(self._config_dict)

target_stream = None
for stream in streams:
stream_name_attr = getattr(stream, "name", None)
if stream_name_attr == stream_name:
target_stream = stream
break
try:
unwrapped = _unwrap_to_declarative_stream(stream)
if getattr(unwrapped, "name", None) == stream_name:
target_stream = stream
break
except NotImplementedError:
continue

if target_stream is None:
available_streams = []
for s in streams:
name = getattr(s, "name", None)
if name:
available_streams.append(name)
raise exc.AirbyteStreamNotFoundError(
stream_name=stream_name,
connector_name=self.name,
available_streams=available_streams,
message=f"Stream '{stream_name}' not found in source.",
)

declarative_stream = _unwrap_to_declarative_stream(target_stream)

retriever = declarative_stream.retriever # type: ignore[attr-defined]

if not isinstance(retriever, SimpleRetriever):
raise NotImplementedError(
f"Stream '{stream_name}' uses {type(retriever).__name__}, but fetch_record() "
"only supports SimpleRetriever."
)

empty_slice = StreamSlice(partition={}, cursor_slice={})
base_path = retriever.requester.get_path(
stream_state={},
stream_slice=empty_slice,
next_page_token=None,
)

if base_path:
fetch_path = f"{base_path.rstrip('/')}/{primary_key_value}"
else:
fetch_path = primary_key_value

response = retriever.requester.send_request(
path=fetch_path,
stream_state={},
stream_slice=empty_slice,
next_page_token=None,
request_headers=retriever._request_headers( # noqa: SLF001
stream_slice=empty_slice,
next_page_token=None,
),
request_params=retriever._request_params( # noqa: SLF001
stream_slice=empty_slice,
next_page_token=None,
),
request_body_data=retriever._request_body_data( # noqa: SLF001
stream_slice=empty_slice,
next_page_token=None,
),
request_body_json=retriever._request_body_json( # noqa: SLF001
stream_slice=empty_slice,
next_page_token=None,
),
)

if response is None:
msg = (
f"No response received when fetching record with primary key "
f"'{primary_key_value}' from stream '{stream_name}'."
)
raise exc.AirbyteRecordNotFoundError(
stream_name=stream_name,
primary_key_value=primary_key_value,
connector_name=self.name,
message=msg,
)

records_schema = {}
if hasattr(declarative_stream, "schema_loader"):
schema_loader = declarative_stream.schema_loader
if hasattr(schema_loader, "get_json_schema"):
with contextlib.suppress(Exception):
records_schema = schema_loader.get_json_schema()

records = list(
retriever.record_selector.select_records(
response=response,
stream_state={},
records_schema=records_schema,
stream_slice=empty_slice,
next_page_token=None,
)
)

if not records:
try:
response_json = response.json()
if isinstance(response_json, dict) and response_json:
return response_json
except Exception:
pass

msg = (
f"Record with primary key '{primary_key_value}' "
f"not found in stream '{stream_name}'."
)
raise exc.AirbyteRecordNotFoundError(
stream_name=stream_name,
primary_key_value=primary_key_value,
connector_name=self.name,
message=msg,
)

first_record = records[0]
if hasattr(first_record, "data"):
return dict(first_record.data) # type: ignore[arg-type]
return dict(first_record) # type: ignore[arg-type]
8 changes: 8 additions & 0 deletions airbyte/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ class AirbyteStateNotFoundError(AirbyteConnectorError, KeyError):
available_streams: list[str] | None = None


@dataclass
class AirbyteRecordNotFoundError(AirbyteConnectorError):
"""Record not found in stream."""

stream_name: str | None = None
primary_key_value: str | None = None


@dataclass
class PyAirbyteSecretNotFoundError(PyAirbyteError):
"""Secret not found."""
Expand Down
111 changes: 111 additions & 0 deletions airbyte/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from airbyte import exceptions as exc
from airbyte._connector_base import ConnectorBase
from airbyte._executors.declarative import DeclarativeExecutor
from airbyte._message_iterators import AirbyteMessageIterator
from airbyte._util.temp_files import as_temp_files
from airbyte.caches.util import get_default_cache
Expand Down Expand Up @@ -601,6 +602,116 @@ def get_documents(
render_metadata=render_metadata,
)

def _get_stream_primary_key(self, stream_name: str) -> list[str]:
"""Get the primary key for a stream.

Returns the primary key as a flat list of field names.
Handles the Airbyte protocol's nested list structure.
"""
catalog = self.configured_catalog
for configured_stream in catalog.streams:
if configured_stream.stream.name == stream_name:
pk = configured_stream.primary_key
if not pk:
return []
if isinstance(pk, list) and len(pk) > 0:
if isinstance(pk[0], list):
return [field[0] if isinstance(field, list) else field for field in pk]
return list(pk) # type: ignore[arg-type]
return []
raise exc.AirbyteStreamNotFoundError(
stream_name=stream_name,
connector_name=self.name,
available_streams=self.get_available_streams(),
)

def _normalize_and_validate_pk_value(
self,
stream_name: str,
pk_value: Any, # noqa: ANN401
) -> str:
"""Normalize and validate a primary key value.

Accepts:
- A string or int (converted to string)
- A dict with a single entry matching the stream's primary key field

Returns the PK value as a string.
"""
primary_key_fields = self._get_stream_primary_key(stream_name)

if not primary_key_fields:
raise exc.PyAirbyteInputError(
message=f"Stream '{stream_name}' does not have a primary key defined.",
input_value=str(pk_value),
)

if len(primary_key_fields) > 1:
raise NotImplementedError(
f"Stream '{stream_name}' has a composite primary key {primary_key_fields}. "
"Fetching by composite primary key is not yet supported."
)

pk_field = primary_key_fields[0]

if isinstance(pk_value, dict):
if len(pk_value) != 1:
raise exc.PyAirbyteInputError(
message="When providing pk_value as a dict, it must contain exactly one entry.",
input_value=str(pk_value),
)
provided_key = next(iter(pk_value.keys()))
if provided_key != pk_field:
msg = (
f"Primary key field '{provided_key}' does not match "
f"stream's primary key '{pk_field}'."
)
raise exc.PyAirbyteInputError(
message=msg,
input_value=str(pk_value),
)
return str(pk_value[provided_key])

return str(pk_value)

def get_record(
self,
stream_name: str,
*,
pk_value: Any, # noqa: ANN401
) -> dict[str, Any]:
"""Fetch a single record by primary key value.

This method is currently only supported for declarative (YAML-based) sources.

Args:
stream_name: The name of the stream to fetch from.
pk_value: The primary key value. Can be:
- A string or integer value (e.g., "123" or 123)
- A dict with a single entry (e.g., {"id": "123"})

Returns:
The fetched record as a dictionary.

Raises:
exc.AirbyteStreamNotFoundError: If the stream does not exist.
exc.AirbyteRecordNotFoundError: If the record is not found.
exc.PyAirbyteInputError: If the pk_value format is invalid.
NotImplementedError: If the source is not declarative or uses composite keys.
"""
if not isinstance(self.executor, DeclarativeExecutor):
raise NotImplementedError(
f"get_record() is only supported for declarative sources. "
f"This source uses {type(self.executor).__name__}."
)

pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value)

return self.executor.fetch_record(
stream_name=stream_name,
primary_key_value=pk_value_str,
)

def get_samples(
self,
streams: list[str] | Literal["*"] | None = None,
Expand Down
Loading
Loading