diff --git a/pyproject.toml b/pyproject.toml
index 3bd4962662..608bb9630d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
     "standard-distutils~=3.11.9; python_version>='3.11'",
     "databricks-bb-analyzer~=0.1.9",
     "sqlglot==28.5.0",
-    "databricks-labs-blueprint[yaml]>=0.11.4,<0.12.0",
+    "databricks-labs-blueprint[yaml]>=0.12.0,<0.13.0",
     "databricks-labs-lsql==0.16.0",
     "cryptography>=44.0.2,<46.1.0",
     "pyodbc~=5.2.0",
@@ -42,7 +42,6 @@ dependencies = [
     "duckdb~=1.2.2",
     "databricks-switch-plugin~=0.1.5", # Temporary, until Switch is migrated to be a transpiler (LSP) plugin.
     "requests>=2.28.1,<3" # Matches databricks-sdk (and 'types-requests' below), to avoid conflicts.
-
 ]
 
 [project.urls]
diff --git a/src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py b/src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py
index 9c9930ecc9..a8af40580c 100644
--- a/src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py
+++ b/src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py
@@ -38,6 +38,7 @@
 from pygls.lsp.client import LanguageClient
 
 from databricks.labs.blueprint.installation import JsonValue, RootJsonValue
+from databricks.labs.blueprint.logger import readlines
 from databricks.labs.blueprint.wheels import ProductInfo
 from databricks.labs.lakebridge.config import LSPConfigOptionV1, TranspileConfig, TranspileResult, extract_string_field
 from databricks.labs.lakebridge.errors.exceptions import IllegalStateException
@@ -376,86 +377,30 @@ async def start_io(self, cmd: str, *args, limit: int = _DEFAULT_LIMIT, **kwargs)
         await super().start_io(cmd, *args, limit=limit, **kwargs)
 
         # forward stderr
         task = asyncio.create_task(self.pipe_stderr(limit=limit), name="pipe-lsp-stderr")
-        task.add_done_callback(self._detect_pipe_stderr_exception)
         self._async_tasks.append(task)
 
     async def pipe_stderr(self, *, limit: int = _DEFAULT_LIMIT) -> None:
         assert (server := self._server) is not None
         assert (stderr := server.stderr) is not None
-        return await self.pipe_stream(stream=stderr, limit=limit)
-
-    @staticmethod
-    async def pipe_stream(*, stream: asyncio.StreamReader, limit: int) -> None:
-        """Read lines from the LSP server's stderr and log them.
-
-        The lines will be logged in real-time as they arrive, once the newline character is seen. Trailing whitespace
-        is stripped from each line before logging, and empty lines are ignored.
-
-        On EOF any pending line will be logged, even if it is incomplete (i.e. does not end with a newline).
-
-        Logs are treated as UTF-8, with invalid byte sequences replaced with the Unicode replacement character.
-
-        Long lines will be split into chunks with a maximum length of the limit. If the split falls in the middle of a
-        multi-byte UTF-8 character, the bytes on either side of the boundary will likely be invalid and logged as such.
-
-        Args:
-            stream: The stream to mirror as logger output.
-            limit: The maximum number of bytes for a line to be logged as a single line. Longer lines will be split
-                into chunks and logged as each chunk arrives.
-        """
-        # Maximum size of pending buffer is the limit argument.
-        pending_buffer = bytearray()
-
-        # Loop, reading whatever data is available as it arrives.
-        while chunk := await stream.read(limit - len(pending_buffer)):
-            # Process the chunk we've read, line by line.
-            line_from = 0
-            while -1 != (idx := chunk.find(b"\n", line_from)):
-                # Figure out the slice corresponding to this line, accounting for any pending data the last read.
-                line_chunk = memoryview(chunk)[line_from:idx]
-                line_bytes: bytearray | bytes
-                if pending_buffer:
-                    pending_buffer.extend(line_chunk)
-                    line_bytes = pending_buffer
-                else:
-                    line_bytes = bytes(line_chunk)
-                del line_chunk
-
-                # Invalid UTF-8 isn't great, but we can at least log it with the replacement character rather than
-                # dropping it silently or triggering an exception.
-                message = line_bytes.decode("utf-8", errors="replace").rstrip()
-                if message:
-                    logger.debug(message)
-                del line_bytes, message
-
-                # Set up for handling the next line of this chunk.
-                pending_buffer.clear()
-                line_from = idx + 1
-            # Anything remaining in this chunk is pending data for the next read.
-            if remaining := memoryview(chunk)[line_from:]:
-                pending_buffer.extend(remaining)
-                if len(pending_buffer) >= limit:
-                    # Line too long, log what we have and reset.
-                    log_now = pending_buffer[:limit]
-                    message = log_now.decode("utf-8", errors="replace").rstrip()
-                    if message:
-                        # Note: the very next character might be a '\n', but we don't know that yet. So might be more
-                        # for this line, might not be.
-                        logger.debug(f"{message}[..?]")
-                    del log_now, message, pending_buffer[:limit]
-            del remaining
-        if pending_buffer:
-            # Here we've hit EOF but have an incomplete line pending. Log it anyway.
-            message = pending_buffer.decode("utf-8", errors="replace").rstrip()
-            if message:
-                logger.debug(f"{message} ")
-
-    def _detect_pipe_stderr_exception(self, task: asyncio.Task) -> None:
-        if (err := task.exception()) is not None:
-            logger.critical("An error occurred while processing LSP server output", exc_info=err)
-        elif not self._stop_event.is_set():
-            logger.warning("LSP server stderr closed prematurely, no more output will be logged.")
+        try:
+            async for line in readlines(stream=stderr, limit=limit):
+                logger.debug(str(line))
+        except Exception as e:  # pylint: disable=broad-exception-caught
+            logger.critical("An error occurred while reading LSP server output; now draining.", exc_info=e)
+            # Drain to prevent blocking of the subprocess if the pipe is unread.
+            try:
+                while await stderr.read(limit):
+                    pass
+            except Exception as drain_error:  # pylint: disable=broad-exception-caught
+                # Exception while draining; the situation seems unrecoverable.
+                logger.warning(
+                    "Unrecoverable error draining LSP server output; beware of deadlock.", exc_info=drain_error
+                )
+        else:
+            if not self._stop_event.is_set():
+                logger.warning("LSP server stderr closed prematurely, no more output will be logged.")
+        logger.debug("Finished piping stderr from subprocess.")
 
 
 class ChangeManager(abc.ABC):
diff --git a/tests/unit/transpiler/test_lsp_err.py b/tests/unit/transpiler/test_lsp_err.py
index 544830175d..fc859508d0 100644
--- a/tests/unit/transpiler/test_lsp_err.py
+++ b/tests/unit/transpiler/test_lsp_err.py
@@ -1,4 +1,3 @@
-import asyncio
 import logging
 import re
 from collections.abc import AsyncGenerator, Generator, Sequence
@@ -11,7 +10,6 @@
 from databricks.labs.lakebridge.config import TranspileConfig
 from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import (
     LSPEngine,
-    LakebridgeLanguageClient,
     logger as lsp_logger,
 )
 
@@ -21,7 +19,7 @@ class LSPServerLogs:
     # The log-level at which the LSP engine writes out stderr lines from the LSP server.
     stderr_log_level: ClassVar[int] = logging.DEBUG
     # The function name from the stderr lines are logged, to help filter out other logs.
-    stderr_log_function: ClassVar[str] = "pipe_stream"
+    stderr_log_function: ClassVar[str] = "pipe_stderr"
 
     def __init__(self, caplog: pytest.LogCaptureFixture):
         self._caplog = caplog
@@ -71,7 +69,7 @@ async def run_lsp_server() -> AsyncGenerator[LSPEngine]:
 
 @pytest.mark.asyncio
 async def test_stderr_captured_as_logs(capture_lsp_server_logs: LSPServerLogs) -> None:
-    """Verify that output from the LSP engine is captured as logs at INFO level."""
+    """Verify that output from the LSP engine is captured as logs."""
     # The LSP engine logs a message to stderr when it starts; look for that message in the logs.
     with capture_lsp_server_logs.capture():
         async with run_lsp_server() as lsp_engine:
@@ -123,116 +121,3 @@ async def test_stderr_with_long_lines(
     expected_matcher = re.compile(r"SELECT '(?P<padding>.*?)';", re.DOTALL | re.MULTILINE)
     assert (expected_match := expected_matcher.search(log_text)) is not None
     assert expected_match.group("padding").count("X") == padding_size
-
-
-class MockStreamReader(asyncio.StreamReader):
-    """Mock asyncio.StreamReader that returns pre-configured chunks."""
-
-    def __init__(self, data_chunks: Sequence[bytes]) -> None:
-        super().__init__()
-        # Chunks represent data that could be returned on successive reads, mimicking the nature of non-blocking I/O
-        # where partial data may be returned. The chunk boundaries represent the splits where partial data is returned.
-        self._remaining_data = data_chunks
-
-    async def read(self, n: int = -1) -> bytes:
-        match n:
-            case -1:
-                # Read all remaining data.
-                data = b"".join(self._remaining_data)
-                self._remaining_data = []
-            case 0:
-                # Empty read.
-                data = b""
-            case max_read if max_read > 0:
-                # Read up to n, but only from the first chunk.
-                match self._remaining_data:
-                    case []:
-                        data = b""
-                    case [head, *tail]:
-                        if len(head) <= max_read:
-                            data = head
-                            self._remaining_data = tail
-                        else:
-                            data = head[:max_read]
-                            self._remaining_data = [head[max_read:], *tail]
-            case _:
-                raise ValueError(f"Unsupported read size: {n}")
-        return data
-
-
-async def assert_pipe_stream_chunks_produce_logs(
-    data_chunks: Sequence[bytes],
-    expected_messages: Sequence[str],
-    caplog: pytest.LogCaptureFixture,
-    *,
-    limit: int = 128,
-) -> None:
-    stream = MockStreamReader(data_chunks)
-    with caplog.at_level(LSPServerLogs.stderr_log_level):
-        await LakebridgeLanguageClient.pipe_stream(stream=stream, limit=limit)
-
-    messages = tuple(LSPServerLogs.get_pipe_stream_logs(caplog))
-    assert messages == expected_messages
-
-
-async def test_pipe_stream_normal_lines(caplog: pytest.LogCaptureFixture) -> None:
-    """Verify the simple case of each line fitting within the limit: one line per log message."""
-    data_chunks = (b"first line\n", b"second line\n", b"third line\n")
-    expected_messages = ("first line", "second line", "third line")
-    await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog)
-
-
-async def test_pipe_stream_whitespace_handling(caplog: pytest.LogCaptureFixture) -> None:
-    """Verify that trailing whitespace is stripped, and that empty log-lines are skipped."""
-    data_chunks = (b" first line \r\n", b"\tsecond line\t\r\n", b" \t \r\n", b"\n", b"last\tproper\tline\n", b" \t ")
-    expected_messages = (" first line", "\tsecond line", "last\tproper\tline")
-    await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog)
-
-
-@pytest.mark.parametrize(
-    ("data_chunks", "expected_messages"),
-    (
-        # Note: limit for all examples is 10.
-        # Single line split over 2 reads.
- ((b"1234567", b"89\n"), ("123456789",)), - # Single read, exactly on the limit. - ((b"123456789\n",), ("123456789",)), - # Single read, exactly on the minimum limit to trigger premature flush. - ((b"1234567890",), ("1234567890[..?]",)), - # Maximum line length. - ((b"123456789", b"123456789\n"), ("1234567891[..?]", "23456789")), - # Multiple lines in one read, with existing data from the previous read. - ((b"1", b"12\n45\n78\n0", b"12\n"), ("112", "45", "78", "012")), - # A very long line, with some existing data in the buffer, and leaving some remainder. - ( - (b"12", b"3456789012" b"3456789012" b"3456789012" b"34567890\n1234"), - ("1234567890[..?]", "1234567890[..?]", "1234567890[..?]", "1234567890[..?]", "1234 "), - ), - ), -) -async def test_pipe_stream_line_exceeds_limit( - data_chunks: Sequence[bytes], - expected_messages: Sequence[str], - caplog: pytest.LogCaptureFixture, -) -> None: - """Verify that line buffering and splitting is handled, including if a line is (much!) longer than the limit.""" - await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog, limit=10) - - -async def test_pipe_stream_incomplete_line_at_eof(caplog: pytest.LogCaptureFixture) -> None: - """Verify that an incomplete line at EOF is logged.""" - data_chunks = (b"normal_line\n", b"incomplete_line") - expected_messages = ("normal_line", "incomplete_line ") - await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog) - - -async def test_pipe_stream_invalid_utf8(caplog: pytest.LogCaptureFixture) -> None: - """Test invalid UTF-8 sequences are replaced with replacement character.""" - data_chunks = ( - # A line with invalid UTF-8 bytes in it. - b"bad[\xc0\xc0]utf8\n", - # A long line, that will be split across the utf-8 sequence. - "123456789abcd\U0001f596efgh\n".encode("utf-8"), - ) - expected_messages = ("bad[\ufffd\ufffd]utf8", "123456789abcd\ufffd[..?]", "\ufffdefgh") - await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog, limit=16)