diff --git a/airbyte/_executors/declarative.py b/airbyte/_executors/declarative.py
index e227eca3..9240f9f5 100644
--- a/airbyte/_executors/declarative.py
+++ b/airbyte/_executors/declarative.py
@@ -15,7 +15,11 @@
 from airbyte_cdk.sources.declarative.concurrent_declarative_source import (
     ConcurrentDeclarativeSource,
 )
+from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
+from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
+from airbyte_cdk.sources.types import StreamSlice
+
+from airbyte import exceptions as exc
 from airbyte._executors.base import Executor
@@ -140,3 +144,139 @@ def install(self) -> None:
     def uninstall(self) -> None:
         """No-op. The declarative source is included with PyAirbyte."""
         pass
+
+    def fetch_record(
+        self,
+        stream_name: str,
+        primary_key_value: str,
+    ) -> dict[str, Any]:
+        """Fetch a single record by primary key from a declarative stream.
+
+        This method uses the already-instantiated streams from the declarative source
+        to access the stream's retriever and make an HTTP GET request by appending
+        the primary key value to the stream's base path (e.g., /users/123).
+
+        Args:
+            stream_name: The name of the stream to fetch from.
+            primary_key_value: The primary key value as a string.
+
+        Returns:
+            The fetched record as a dictionary.
+
+        Raises:
+            exc.AirbyteStreamNotFoundError: If the stream is not found.
+            exc.AirbyteRecordNotFoundError: If the record is not found (empty response).
+            NotImplementedError: If the stream does not use SimpleRetriever.
+        """
+        streams = self.declarative_source.streams(self._config_dict)
+
+        target_stream = None
+        for stream in streams:
+            if stream.name == stream_name:
+                if not isinstance(stream, AbstractStream):
+                    raise NotImplementedError(
+                        f"Stream '{stream_name}' is type {type(stream).__name__}; "
+                        "fetch_record() supports only AbstractStream."
+                    )
+                target_stream = stream
+                break
+
+        if target_stream is None:
+            available_streams = [s.name for s in streams]
+            raise exc.AirbyteStreamNotFoundError(
+                stream_name=stream_name,
+                connector_name=self.name,
+                available_streams=available_streams,
+                message=f"Stream '{stream_name}' not found in source.",
+            )
+
+        if not hasattr(target_stream, "retriever"):
+            raise NotImplementedError(
+                f"Stream '{stream_name}' does not have a retriever attribute. "
+                "fetch_record() requires access to the stream's retriever component."
+            )
+
+        retriever = target_stream.retriever
+
+        # Guard: the retriever must be a SimpleRetriever.
+        if not isinstance(retriever, SimpleRetriever):
+            raise NotImplementedError(
+                f"Stream '{stream_name}' uses {type(retriever).__name__}, but fetch_record() "
+                "only supports SimpleRetriever."
+            )
+
+        empty_slice = StreamSlice(partition={}, cursor_slice={})
+        base_path = retriever.requester.get_path(
+            stream_state={},
+            stream_slice=empty_slice,
+            next_page_token=None,
+        )
+
+        if base_path:
+            fetch_path = f"{base_path.rstrip('/')}/{primary_key_value}"
+        else:
+            fetch_path = primary_key_value
+
+        response = retriever.requester.send_request(
+            path=fetch_path,
+            stream_state={},
+            stream_slice=empty_slice,
+            next_page_token=None,
+            request_headers=retriever._request_headers(  # noqa: SLF001
+                stream_slice=empty_slice,
+                next_page_token=None,
+            ),
+            request_params=retriever._request_params(  # noqa: SLF001
+                stream_slice=empty_slice,
+                next_page_token=None,
+            ),
+            request_body_data=retriever._request_body_data(  # noqa: SLF001
+                stream_slice=empty_slice,
+                next_page_token=None,
+            ),
+            request_body_json=retriever._request_body_json(  # noqa: SLF001
+                stream_slice=empty_slice,
+                next_page_token=None,
+            ),
+        )
+
+        # Guard: the response must not be None.
+        if response is None:
+            raise exc.AirbyteRecordNotFoundError(
+                stream_name=stream_name,
+                primary_key_value=primary_key_value,
+                connector_name=self.name,
+                message=f"No response received when fetching record with primary key "
+                f"'{primary_key_value}' from stream '{stream_name}'.",
+            )
+
+        records_schema = {}
+        if hasattr(target_stream, "schema_loader"):
+            schema_loader = target_stream.schema_loader
+            if hasattr(schema_loader, "get_json_schema"):
+                records_schema = schema_loader.get_json_schema()
+
+        records = list(
+            retriever.record_selector.select_records(
+                response=response,
+                stream_state={},
+                records_schema=records_schema,
+                stream_slice=empty_slice,
+                next_page_token=None,
+            )
+        )
+
+        # Guard: the record list must not be empty.
+        if not records:
+            raise exc.AirbyteRecordNotFoundError(
+                stream_name=stream_name,
+                primary_key_value=primary_key_value,
+                connector_name=self.name,
+                message=f"Record with primary key '{primary_key_value}' "
+                f"not found in stream '{stream_name}'.",
+            )
+
+        first_record = records[0]
+        if hasattr(first_record, "data"):
+            return dict(first_record.data)  # type: ignore[arg-type]
+        return dict(first_record)  # type: ignore[arg-type]
diff --git a/airbyte/exceptions.py b/airbyte/exceptions.py
index e082f7a8..f7cb1d13 100644
--- a/airbyte/exceptions.py
+++ b/airbyte/exceptions.py
@@ -412,6 +412,14 @@ class AirbyteStateNotFoundError(AirbyteConnectorError, KeyError):
     available_streams: list[str] | None = None
 
 
+@dataclass
+class AirbyteRecordNotFoundError(AirbyteConnectorError):
+    """Record not found in stream."""
+
+    stream_name: str | None = None
+    primary_key_value: str | None = None
+
+
 @dataclass
 class PyAirbyteSecretNotFoundError(PyAirbyteError):
     """Secret not found."""
diff --git a/airbyte/mcp/local_ops.py b/airbyte/mcp/local_ops.py
index df018a2b..7ba35569 100644
--- a/airbyte/mcp/local_ops.py
+++ b/airbyte/mcp/local_ops.py
@@ -461,6 +461,124 @@ def read_source_stream_records(
     return records
 
 
+@mcp_tool(
+    domain="local",
+    read_only=True,
+    idempotent=True,
+    extra_help_text=_CONFIG_HELP,
+)
+def get_source_record(  # noqa: PLR0913, PLR0917
+    source_connector_name: Annotated[
+        str,
+        Field(description="The name of the source connector."),
+    ],
+    stream_name: Annotated[
+        str,
+        Field(description="The name of the stream to fetch the record from."),
+    ],
+    pk_value: Annotated[
+        str | int | dict[str, Any],
+        Field(
+            description=(
+                "The primary key value to fetch. "
+                "Can be a string, int, or dict with PK field name(s) as keys."
+            )
+        ),
+    ],
+    config: Annotated[
+        dict | str | None,
+        Field(
+            description="The configuration for the source connector as a dict or JSON string.",
+            default=None,
+        ),
+    ],
+    config_file: Annotated[
+        str | Path | None,
+        Field(
+            description="Path to a YAML or JSON file containing the source connector config.",
+            default=None,
+        ),
+    ],
+    config_secret_name: Annotated[
+        str | None,
+        Field(
+            description="The name of the secret containing the configuration.",
+            default=None,
+        ),
+    ],
+    override_execution_mode: Annotated[
+        Literal["docker", "python", "yaml", "auto"],
+        Field(
+            description="Optionally override the execution method to use for the connector. "
+            "This parameter is ignored if manifest_path is provided (yaml mode will be used).",
+            default="auto",
+        ),
+    ],
+    manifest_path: Annotated[
+        str | Path | None,
+        Field(
+            description="Path to a local YAML manifest file for declarative connectors.",
+            default=None,
+        ),
+    ],
+    allow_scanning: Annotated[
+        bool,
+        Field(
+            description="If True, fall back to scanning stream records if the direct fetch fails.",
+            default=False,
+        ),
+    ],
+    scan_timeout_seconds: Annotated[
+        int,
+        Field(
+            description="Maximum time in seconds to spend scanning for the record.",
+            default=60,
+        ),
+    ],
+) -> dict[str, Any] | str:
+    """Fetch a single record from a source connector by primary key value.
+
+    This operation requires a valid configuration. Direct fetch is only
+    supported for declarative (YAML-based) sources: for streams backed by a
+    SimpleRetriever, the appropriate API request is constructed and sent
+    directly. If allow_scanning is True and the direct fetch fails, the tool
+    falls back to scanning through stream records.
+    """
+    try:
+        source: Source = _get_mcp_source(
+            connector_name=source_connector_name,
+            override_execution_mode=override_execution_mode,
+            manifest_path=manifest_path,
+        )
+        config_dict = resolve_config(
+            config=config,
+            config_file=config_file,
+            config_secret_name=config_secret_name,
+            config_spec_jsonschema=source.config_spec,
+        )
+        source.set_config(config_dict)
+
+        record = source.get_record(
+            stream_name=stream_name,
+            pk_value=pk_value,
+            allow_scanning=allow_scanning,
+            scan_timeout_seconds=scan_timeout_seconds,
+        )
+
+        print(
+            f"Retrieved record from stream '{stream_name}' with pk_value={pk_value!r}",
+            file=sys.stderr,
+        )
+
+    except Exception as ex:
+        tb_str = traceback.format_exc()
+        return (
+            f"Error fetching record from source '{source_connector_name}': {ex!r}, {ex!s}\n{tb_str}"
+        )
+    else:
+        return record
+
+
 @mcp_tool(
     domain="local",
     read_only=True,
diff --git a/airbyte/sources/base.py b/airbyte/sources/base.py
index 7fd5093c..82891797 100644
--- a/airbyte/sources/base.py
+++ b/airbyte/sources/base.py
@@ -30,7 +30,9 @@
 from airbyte import exceptions as exc
 from airbyte._connector_base import ConnectorBase
+from airbyte._executors.declarative import DeclarativeExecutor
 from airbyte._message_iterators import AirbyteMessageIterator
+from airbyte._util.name_normalizers import LowerCaseNormalizer
 from airbyte._util.temp_files import as_temp_files
 from airbyte.caches.util import get_default_cache
 from airbyte.datasets._lazy import LazyDataset
@@ -417,6 +419,11 @@ def configured_catalog(self) -> ConfiguredAirbyteCatalog:
         streams_filter: list[str] = self._selected_stream_names or self.get_available_streams()
         return self.get_configured_catalog(streams=streams_filter)
 
+    @property
+    def catalog_provider(self) -> CatalogProvider:
+        """Return a catalog provider for this source."""
+        return CatalogProvider(self.configured_catalog)
+
     def get_configured_catalog(
         self,
         streams: Literal["*"] | list[str] | None = None,
@@ -601,6 +608,136 @@ def get_documents(
             render_metadata=render_metadata,
         )
 
+    def _normalize_and_validate_pk_value(
+        self,
+        stream_name: str,
+        pk_value: Any,  # noqa: ANN401
+    ) -> str:
+        """Normalize and validate a primary key value.
+
+        Accepts:
+        - A string or int (converted to a string)
+        - A dict with a single entry matching the stream's primary key field
+
+        Returns the PK value as a string.
+        """
+        primary_key_fields = self.catalog_provider.get_primary_keys(stream_name)
+
+        if not primary_key_fields:
+            raise exc.PyAirbyteInputError(
+                message=f"Stream '{stream_name}' does not have a primary key defined.",
+                input_value=str(pk_value),
+            )
+
+        if len(primary_key_fields) > 1:
+            raise NotImplementedError(
+                f"Stream '{stream_name}' has a composite primary key {primary_key_fields}. "
+                "Fetching by composite primary key is not yet supported."
+            )
+
+        pk_field = primary_key_fields[0]
+
+        if isinstance(pk_value, dict):
+            if len(pk_value) != 1:
+                raise exc.PyAirbyteInputError(
+                    message="When providing pk_value as a dict, it must contain exactly one entry.",
+                    input_value=str(pk_value),
+                )
+            provided_key = next(iter(pk_value.keys()))
+            normalized_provided_key = LowerCaseNormalizer.normalize(provided_key)
+            if normalized_provided_key != pk_field:
+                msg = (
+                    f"Primary key field '{provided_key}' does not match "
+                    f"stream's primary key '{pk_field}'."
+                )
+                raise exc.PyAirbyteInputError(
+                    message=msg,
+                    input_value=str(pk_value),
+                )
+            return str(pk_value[provided_key])
+
+        return str(pk_value)
+
+    def get_record(
+        self,
+        stream_name: str,
+        *,
+        pk_value: Any,  # noqa: ANN401
+        allow_scanning: bool = False,
+        scan_timeout_seconds: int = 5,
+    ) -> dict[str, Any]:
+        """Fetch a single record by primary key value.
+
+        Direct fetch is currently only supported for declarative (YAML-based)
+        sources; for other sources, set allow_scanning=True to locate the
+        record by scanning the stream.
+
+        Args:
+            stream_name: The name of the stream to fetch from.
+            pk_value: The primary key value. Can be:
+                - A string or integer value (e.g., "123" or 123)
+                - A dict with a single entry (e.g., {"id": "123"})
+            allow_scanning: If True, fall back to scanning the stream if the
+                direct fetch fails or is not supported.
+            scan_timeout_seconds: Maximum time to spend scanning for the record.
+
+        Returns:
+            The fetched record as a dictionary.
+
+        Raises:
+            exc.AirbyteStreamNotFoundError: If the stream does not exist.
+            exc.AirbyteRecordNotFoundError: If the record is not found.
+            exc.PyAirbyteInputError: If the pk_value format is invalid.
+            NotImplementedError: If the source is not declarative or uses composite keys.
+        """
+        if isinstance(self.executor, DeclarativeExecutor):
+            pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value)
+            try:
+                return self.executor.fetch_record(
+                    stream_name=stream_name,
+                    primary_key_value=pk_value_str,
+                )
+            except (NotImplementedError, exc.AirbyteRecordNotFoundError) as e:
+                if not allow_scanning:
+                    raise
+                scan_reason = type(e).__name__
+
+        elif not allow_scanning:
+            raise NotImplementedError(
+                "get_record() direct fetch is only supported for declarative sources. "
+                f"This source uses {type(self.executor).__name__}. "
+                "Set allow_scanning=True to enable scanning fallback."
+            )
+        else:
+            scan_reason = "non-declarative source"
+
+        pk_value_str = self._normalize_and_validate_pk_value(stream_name, pk_value)
+        primary_key_fields = self.catalog_provider.get_primary_keys(stream_name)
+        pk_field = primary_key_fields[0]
+
+        start_time = time.monotonic()
+        for record in self.get_records(stream_name):
+            if time.monotonic() - start_time > scan_timeout_seconds:
+                raise exc.AirbyteRecordNotFoundError(
+                    stream_name=stream_name,
+                    context={
+                        "primary_key_field": pk_field,
+                        "primary_key_value": pk_value_str,
+                        "scan_timeout_seconds": scan_timeout_seconds,
+                        "scan_reason": scan_reason,
+                    },
+                )
+
+            record_data = record if isinstance(record, dict) else record.data
+            if str(record_data.get(pk_field)) == pk_value_str:
+                return record_data
+
+        raise exc.AirbyteRecordNotFoundError(
+            stream_name=stream_name,
+            context={
+                "primary_key_field": pk_field,
+                "primary_key_value": pk_value_str,
+                "scan_reason": scan_reason,
+            },
+        )
+
     def get_samples(
         self,
         streams: list[str] | Literal["*"] | None = None,
diff --git a/tests/unit_tests/test_get_record.py b/tests/unit_tests/test_get_record.py
new file mode 100644
index 00000000..277088c6
--- /dev/null
+++ b/tests/unit_tests/test_get_record.py
@@ -0,0 +1,219 @@
+# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+"""Unit tests for Source.get_record() and DeclarativeExecutor.fetch_record()."""
+
+from __future__ import annotations
+
+from typing import Any
+from unittest.mock import Mock, PropertyMock, patch
+
+import pytest
+
+from airbyte_cdk.sources.declarative.retrievers.simple_retriever import SimpleRetriever
+from airbyte_cdk.sources.streams.concurrent.abstract_stream import AbstractStream
+
+from airbyte import exceptions as exc
+from airbyte._executors.declarative import DeclarativeExecutor
+from airbyte.sources.base import Source
+
+
+@pytest.mark.parametrize(
+    "stream_name,pk_value,expected_error",
+    [
+        pytest.param("users", "123", None, id="valid_stream_and_pk"),
+        pytest.param(
+            "nonexistent", "123", exc.AirbyteStreamNotFoundError, id="stream_not_found"
+        ),
+    ],
+)
+def test_declarative_executor_fetch_record_stream_validation(
+    stream_name: str,
+    pk_value: str,
+    expected_error: type[Exception] | None,
+) -> None:
+    """Test stream validation in DeclarativeExecutor.fetch_record()."""
+    manifest = {
+        "streams": [
+            {
+                "name": "users",
+                "retriever": {
+                    "type": "SimpleRetriever",
+                    "requester": {
+                        "url_base": "https://api.example.com",
+                        "path": "/users",
+                    },
+                    "record_selector": {"extractor": {"field_path": []}},
+                },
+            }
+        ]
+    }
+
+    executor = DeclarativeExecutor(
+        name="test-source",
+        manifest=manifest,
+    )
+
+    mock_stream = Mock(spec=AbstractStream)
+    mock_stream.name = "users"
+
+    mock_retriever = Mock(spec=SimpleRetriever)
+    mock_retriever.requester = Mock()
+    mock_retriever.requester.get_path = Mock(return_value="/users")
+    mock_retriever.requester.send_request = Mock(
+        return_value=Mock(json=lambda: {"id": "123"})
+    )
+    mock_retriever._request_headers = Mock(return_value={})
+    mock_retriever._request_params = Mock(return_value={})
+    mock_retriever._request_body_data = Mock(return_value=None)
+    mock_retriever._request_body_json = Mock(return_value=None)
+    mock_retriever.record_selector = Mock()
+    mock_retriever.record_selector.select_records = Mock(return_value=[{"id": "123"}])
+
+    mock_stream.retriever = mock_retriever
+
+    mock_streams = [mock_stream] if stream_name == "users" else []
+
+    mock_declarative_source = Mock()
+    mock_declarative_source.streams = Mock(return_value=mock_streams)
+
+    with patch.object(
+        type(executor), "declarative_source", new_callable=PropertyMock
+    ) as mock_prop:
+        mock_prop.return_value = mock_declarative_source
+        if expected_error:
+            with pytest.raises(expected_error):
+                executor.fetch_record(stream_name, pk_value)
+        else:
+            result = executor.fetch_record(stream_name, pk_value)
+            assert result == {"id": "123"}
+
+
+@pytest.mark.parametrize(
+    "pk_value,primary_key_fields,expected_result,expected_error",
+    [
+        pytest.param("123", ["id"], "123", None, id="string_value"),
+        pytest.param(123, ["id"], "123", None, id="int_value"),
+        pytest.param({"id": "123"}, ["id"], "123", None, id="dict_with_correct_key"),
+        pytest.param(
+            {"wrong_key": "123"},
+            ["id"],
+            None,
+            exc.PyAirbyteInputError,
+            id="dict_with_wrong_key",
+        ),
+        pytest.param(
+            {"id": "123", "extra": "456"},
+            ["id"],
+            None,
+            exc.PyAirbyteInputError,
+            id="dict_with_multiple_entries",
+        ),
+        pytest.param(
+            "123",
+            ["id", "org_id"],
+            None,
+            NotImplementedError,
+            id="composite_primary_key",
+        ),
+        pytest.param("123", [], None, exc.PyAirbyteInputError, id="no_primary_key"),
+    ],
+)
+def test_source_normalize_and_validate_pk_value(
+    pk_value: Any,
+    primary_key_fields: list[str],
+    expected_result: str | None,
+    expected_error: type[Exception] | None,
+) -> None:
+    """Test _normalize_and_validate_pk_value() handles various input formats."""
+    from airbyte.shared.catalog_providers import CatalogProvider
+
+    mock_executor = Mock()
+    source = Source(
+        executor=mock_executor,
+        name="test-source",
+        config={"api_key": "test"},
+    )
+
+    mock_catalog_provider = Mock(spec=CatalogProvider)
+    mock_catalog_provider.get_primary_keys.return_value = primary_key_fields
+
+    with patch.object(
+        type(source), "catalog_provider", new_callable=PropertyMock
+    ) as mock_provider_prop:
+        mock_provider_prop.return_value = mock_catalog_provider
+
+        if expected_error:
+            with pytest.raises(expected_error):
+                source._normalize_and_validate_pk_value("test_stream", pk_value)
+        else:
+            result = source._normalize_and_validate_pk_value("test_stream", pk_value)
+            assert result == expected_result
+
+
+def test_source_get_record_requires_declarative_executor() -> None:
+    """Test get_record() raises NotImplementedError for non-declarative executors."""
+    from airbyte._executors.python import VenvExecutor
+
+    mock_executor = Mock(spec=VenvExecutor)
+    source = Source(
+        executor=mock_executor,
+        name="test-source",
+        config={"api_key": "test"},
+    )
+
+    with pytest.raises(
+        NotImplementedError, match="only supported for declarative sources"
+    ):
+        source.get_record("test_stream", pk_value="123")
+
+
+def test_source_get_record_calls_executor_fetch_record() -> None:
+    """Test get_record() calls executor.fetch_record() with correct parameters."""
+    mock_executor = Mock(spec=DeclarativeExecutor)
+    mock_executor.fetch_record.return_value = {"id": "123", "name": "Test"}
+
+    source = Source(
+        executor=mock_executor,
+        name="test-source",
+        config={"api_key": "test"},
+    )
+    source._config_dict = {"api_key": "test"}
+
+    with patch.object(source, "_normalize_and_validate_pk_value", return_value="123"):
+        result = source.get_record("test_stream", pk_value="123")
+
+    assert result == {"id": "123", "name": "Test"}
+    mock_executor.fetch_record.assert_called_once_with(
+        stream_name="test_stream",
+        primary_key_value="123",
+    )
+
+
+@pytest.mark.parametrize(
+    "pk_value",
+    [
pytest.param("123", id="string_pk"), + pytest.param(123, id="int_pk"), + pytest.param({"id": "123"}, id="dict_pk"), + ], +) +def test_source_get_record_accepts_various_pk_formats(pk_value: Any) -> None: + """Test get_record() accepts various PK value formats.""" + mock_executor = Mock(spec=DeclarativeExecutor) + mock_executor.fetch_record.return_value = {"id": "123", "name": "Test"} + + source = Source( + executor=mock_executor, + name="test-source", + config={"api_key": "test"}, + ) + source._config_dict = {"api_key": "test"} + + with patch.object(source, "_normalize_and_validate_pk_value", return_value="123"): + result = source.get_record("test_stream", pk_value=pk_value) + + assert result == {"id": "123", "name": "Test"} + mock_executor.fetch_record.assert_called_once()