Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions airbyte/_executors/declarative.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
ConcurrentDeclarativeSource,
)
from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
from airbyte_cdk.sources.types import StreamSlice

from airbyte import exceptions as exc
from airbyte._executors.base import Executor


Expand Down Expand Up @@ -140,3 +144,139 @@ def install(self) -> None:
def uninstall(self) -> None:
"""No-op. The declarative source is included with PyAirbyte."""
pass

def fetch_record(
    self,
    stream_name: str,
    primary_key_value: str,
) -> dict[str, Any]:
    """Fetch a single record by primary key from a declarative stream.

    This method uses the already-instantiated streams from the declarative source
    to access the stream's retriever and make an HTTP GET request by appending
    the primary key value to the stream's base path (e.g., /users/123).

    Args:
        stream_name: The name of the stream to fetch from.
        primary_key_value: The primary key value as a string.

    Returns:
        The fetched record as a dictionary.

    Raises:
        exc.AirbyteStreamNotFoundError: If the stream is not found.
        exc.AirbyteRecordNotFoundError: If the record is not found (empty response).
        NotImplementedError: If the stream does not use SimpleRetriever.
    """
    # Instantiate the source's streams from the user-supplied config dict.
    streams = self.declarative_source.streams(self._config_dict)

    # Locate the requested stream by name; reject non-AbstractStream
    # implementations before touching their internals.
    target_stream = None
    for stream in streams:
        if stream.name == stream_name:
            if not isinstance(stream, AbstractStream):
                raise NotImplementedError(
                    f"Stream '{stream_name}' is type {type(stream).__name__}; "
                    "fetch_record() supports only AbstractStream."
                )
            target_stream = stream
            break

    if target_stream is None:
        available_streams = [s.name for s in streams]
        raise exc.AirbyteStreamNotFoundError(
            stream_name=stream_name,
            connector_name=self.name,
            available_streams=available_streams,
            message=f"Stream '{stream_name}' not found in source.",
        )

    # `retriever` is not part of the AbstractStream contract, so probe for it
    # explicitly rather than assuming every stream exposes one.
    if not hasattr(target_stream, "retriever"):
        raise NotImplementedError(
            f"Stream '{stream_name}' does not have a retriever attribute. "
            f"fetch_record() requires access to the stream's retriever component."
        )

    retriever = target_stream.retriever

    # Guard: Retriever must be SimpleRetriever
    if not isinstance(retriever, SimpleRetriever):
        raise NotImplementedError(
            f"Stream '{stream_name}' uses {type(retriever).__name__}, but fetch_record() "
            "only supports SimpleRetriever."
        )

    # The path/header/param helpers below all accept a stream slice; we use an
    # empty slice since we are fetching a single record outside any sync.
    empty_slice = StreamSlice(partition={}, cursor_slice={})
    base_path = retriever.requester.get_path(
        stream_state={},
        stream_slice=empty_slice,
        next_page_token=None,
    )

    # REST convention: GET {base_path}/{pk}. Some connectors have an empty
    # base path, in which case the PK value itself is the path.
    if base_path:
        fetch_path = f"{base_path.rstrip('/')}/{primary_key_value}"
    else:
        fetch_path = primary_key_value

    # NOTE(review): relies on SimpleRetriever private helpers (_request_headers
    # et al.) to reproduce the stream's normal request options; these are not
    # public CDK API and may change between airbyte-cdk versions.
    response = retriever.requester.send_request(
        path=fetch_path,
        stream_state={},
        stream_slice=empty_slice,
        next_page_token=None,
        request_headers=retriever._request_headers(  # noqa: SLF001
            stream_slice=empty_slice,
            next_page_token=None,
        ),
        request_params=retriever._request_params(  # noqa: SLF001
            stream_slice=empty_slice,
            next_page_token=None,
        ),
        request_body_data=retriever._request_body_data(  # noqa: SLF001
            stream_slice=empty_slice,
            next_page_token=None,
        ),
        request_body_json=retriever._request_body_json(  # noqa: SLF001
            stream_slice=empty_slice,
            next_page_token=None,
        ),
    )

    # Guard: Response must not be None
    if response is None:
        raise exc.AirbyteRecordNotFoundError(
            stream_name=stream_name,
            primary_key_value=primary_key_value,
            connector_name=self.name,
            message=f"No response received when fetching record with primary key "
            f"'{primary_key_value}' from stream '{stream_name}'.",
        )

    # Best-effort schema lookup: the record selector accepts an empty schema,
    # so missing schema_loader/get_json_schema simply degrades gracefully.
    records_schema = {}
    if hasattr(target_stream, "schema_loader"):
        schema_loader = target_stream.schema_loader
        if hasattr(schema_loader, "get_json_schema"):
            records_schema = schema_loader.get_json_schema()

    # Reuse the stream's own record selector so extraction/transformation
    # matches what a normal read would produce.
    records = list(
        retriever.record_selector.select_records(
            response=response,
            stream_state={},
            records_schema=records_schema,
            stream_slice=empty_slice,
            next_page_token=None,
        )
    )

    # Guard: Records must not be empty
    if not records:
        raise exc.AirbyteRecordNotFoundError(
            stream_name=stream_name,
            primary_key_value=primary_key_value,
            connector_name=self.name,
            message=f"Record with primary key '{primary_key_value}' "
            f"not found in stream '{stream_name}'.",
        )

    # Selected records may be Record objects (with a `.data` mapping) or plain
    # mappings; normalize either to a plain dict. Only the first match is
    # returned — a by-PK endpoint is expected to yield a single record.
    first_record = records[0]
    if hasattr(first_record, "data"):
        return dict(first_record.data)  # type: ignore[arg-type]
    return dict(first_record)  # type: ignore[arg-type]
8 changes: 8 additions & 0 deletions airbyte/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,14 @@ class AirbyteStateNotFoundError(AirbyteConnectorError, KeyError):
available_streams: list[str] | None = None


@dataclass
class AirbyteRecordNotFoundError(AirbyteConnectorError):
    """Record not found in stream."""

    stream_name: str | None = None  # Name of the stream that was searched.
    primary_key_value: str | None = None  # Primary key value that found no record.


@dataclass
class PyAirbyteSecretNotFoundError(PyAirbyteError):
"""Secret not found."""
Expand Down
137 changes: 137 additions & 0 deletions airbyte/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@

from airbyte import exceptions as exc
from airbyte._connector_base import ConnectorBase
from airbyte._executors.declarative import DeclarativeExecutor
from airbyte._message_iterators import AirbyteMessageIterator
from airbyte._util.name_normalizers import LowerCaseNormalizer
from airbyte._util.temp_files import as_temp_files
from airbyte.caches.util import get_default_cache
from airbyte.datasets._lazy import LazyDataset
Expand Down Expand Up @@ -417,6 +419,11 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
streams_filter: list[str] = self._selected_stream_names or self.get_available_streams()
return self.get_configured_catalog(streams=streams_filter)

@property
def catalog_provider(self) -> CatalogProvider:
    """Expose this source's configured catalog wrapped in a ``CatalogProvider``.

    A fresh provider is built on each access from the current configured
    catalog, so stream selection changes are always reflected.
    """
    configured = self.configured_catalog
    return CatalogProvider(configured)

def get_configured_catalog(
self,
streams: Literal["*"] | list[str] | None = None,
Expand Down Expand Up @@ -601,6 +608,136 @@ def get_documents(
render_metadata=render_metadata,
)

def _normalize_and_validate_pk_value(
    self,
    stream_name: str,
    pk_value: Any,  # noqa: ANN401
) -> str:
    """Validate a primary key value against the stream's catalog and stringify it.

    Accepted forms:
    - a scalar (string or int), converted with ``str()``
    - a single-entry dict whose key names the stream's primary key field

    Returns the PK value as a string.
    """
    pk_fields = self.catalog_provider.get_primary_keys(stream_name)

    # Guard: the stream must declare a primary key at all.
    if not pk_fields:
        raise exc.PyAirbyteInputError(
            message=f"Stream '{stream_name}' does not have a primary key defined.",
            input_value=str(pk_value),
        )

    # Guard: composite keys are not supported by the fetch path.
    if len(pk_fields) > 1:
        raise NotImplementedError(
            f"Stream '{stream_name}' has a composite primary key {pk_fields}. "
            "Fetching by composite primary key is not yet supported."
        )

    expected_field = pk_fields[0]

    # Scalar form: any non-dict value is taken as the PK value itself.
    if not isinstance(pk_value, dict):
        return str(pk_value)

    # Dict form: exactly one entry, keyed by the stream's PK field.
    if len(pk_value) != 1:
        raise exc.PyAirbyteInputError(
            message="When providing pk_value as a dict, it must contain exactly one entry.",
            input_value=str(pk_value),
        )
    (given_field,) = pk_value
    # NOTE(review): only the caller-supplied key is lower-case-normalized here;
    # this assumes get_primary_keys() already returns normalized field names —
    # confirm against CatalogProvider.
    if LowerCaseNormalizer.normalize(given_field) != expected_field:
        msg = (
            f"Primary key field '{given_field}' does not match "
            f"stream's primary key '{expected_field}'."
        )
        raise exc.PyAirbyteInputError(
            message=msg,
            input_value=str(pk_value),
        )
    return str(pk_value[given_field])

def get_record(
    self,
    stream_name: str,
    *,
    pk_value: Any,  # noqa: ANN401
    allow_scanning: bool = False,
    scan_timeout_seconds: int = 5,
) -> dict[str, Any]:
    """Fetch a single record by primary key value.

    This method is currently only supported for declarative (YAML-based) sources.

    Args:
        stream_name: The name of the stream to fetch from.
        pk_value: The primary key value. Can be:
            - A string or integer value (e.g., "123" or 123)
            - A dict with a single entry (e.g., {"id": "123"})
        allow_scanning: If True, fall back to scanning the stream if direct fetch fails.
        scan_timeout_seconds: Maximum time to spend scanning for the record.

    Returns:
        The fetched record as a dictionary.

    Raises:
        exc.AirbyteStreamNotFoundError: If the stream does not exist.
        exc.AirbyteRecordNotFoundError: If the record is not found.
        exc.PyAirbyteInputError: If the pk_value format is invalid.
        NotImplementedError: If the source is not declarative or uses composite keys.
    """
    # Fast path: declarative sources can fetch directly by PK (GET /path/{pk}).
    if isinstance(self.executor, DeclarativeExecutor):
        pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value)
        try:
            return self.executor.fetch_record(
                stream_name=stream_name,
                primary_key_value=pk_value_str,
            )
        # Only these two failures trigger the scanning fallback; anything else
        # (e.g. AirbyteStreamNotFoundError) propagates regardless of
        # allow_scanning, since scanning a missing stream cannot succeed.
        except (NotImplementedError, exc.AirbyteRecordNotFoundError) as e:
            if not allow_scanning:
                raise
            # Recorded for error context if the scan also fails.
            scan_reason = type(e).__name__

    elif not allow_scanning:
        raise NotImplementedError(
            f"get_record() direct fetch is only supported for declarative sources. "
            f"This source uses {type(self.executor).__name__}. "
            f"Set allow_scanning=True to enable scanning fallback."
        )
    else:
        scan_reason = "non-declarative source"

    # Scanning fallback: read records until a PK match or timeout.
    # (Re-normalizing here is redundant on the declarative path but keeps the
    # fallback self-contained for the non-declarative path.)
    pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value)
    primary_key_fields = self.catalog_provider.get_primary_keys(stream_name)
    pk_field = primary_key_fields[0]

    start_time = time.monotonic()
    for record in self.get_records(stream_name):
        # Timeout is checked per record, so a single slow read may overshoot
        # scan_timeout_seconds slightly.
        if time.monotonic() - start_time > scan_timeout_seconds:
            raise exc.AirbyteRecordNotFoundError(
                stream_name=stream_name,
                context={
                    "primary_key_field": pk_field,
                    "primary_key_value": pk_value_str,
                    "scan_timeout_seconds": scan_timeout_seconds,
                    "scan_reason": scan_reason,
                },
            )

        # Records may be plain dicts or objects exposing `.data`.
        record_data = record if isinstance(record, dict) else record.data
        # Compare as strings so int PKs match string pk_value input and vice versa.
        if str(record_data.get(pk_field)) == pk_value_str:
            return record_data

    # Stream exhausted without a match.
    raise exc.AirbyteRecordNotFoundError(
        stream_name=stream_name,
        context={
            "primary_key_field": pk_field,
            "primary_key_value": pk_value_str,
            "scan_reason": scan_reason,
        },
    )

def get_samples(
self,
streams: list[str] | Literal["*"] | None = None,
Expand Down
Loading
Loading