3 changes: 1 addition & 2 deletions pyproject.toml
@@ -33,7 +33,7 @@ dependencies = [
"standard-distutils~=3.11.9; python_version>='3.11'",
"databricks-bb-analyzer~=0.1.9",
"sqlglot==28.5.0",
"databricks-labs-blueprint[yaml]>=0.11.4,<0.12.0",
"databricks-labs-blueprint[yaml]>=0.12.0,<0.13.0",
"databricks-labs-lsql==0.16.0",
"cryptography>=44.0.2,<46.1.0",
"pyodbc~=5.2.0",
@@ -42,7 +42,6 @@ dependencies = [
"duckdb~=1.2.2",
"databricks-switch-plugin~=0.1.5", # Temporary, until Switch is migrated to be a transpiler (LSP) plugin.
"requests>=2.28.1,<3" # Matches databricks-sdk (and 'types-requests' below), to avoid conflicts.

]

[project.urls]
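The floor bump on databricks-labs-blueprint matters for the rest of this change: the engine refactor below imports the readlines helper from databricks.labs.blueprint.logger, which presumably is only available from the 0.12.x line onward. A minimal sanity check, assuming a matching blueprint release is installed in the current environment:

from importlib.metadata import version

from databricks.labs.blueprint.logger import readlines  # helper used by lsp_engine.py below

print(version("databricks-labs-blueprint"))  # expected to report a 0.12.x release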
93 changes: 19 additions & 74 deletions src/databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py
@@ -38,6 +38,7 @@
from pygls.lsp.client import LanguageClient

from databricks.labs.blueprint.installation import JsonValue, RootJsonValue
from databricks.labs.blueprint.logger import readlines
from databricks.labs.blueprint.wheels import ProductInfo
from databricks.labs.lakebridge.config import LSPConfigOptionV1, TranspileConfig, TranspileResult, extract_string_field
from databricks.labs.lakebridge.errors.exceptions import IllegalStateException
@@ -376,86 +377,30 @@ async def start_io(self, cmd: str, *args, limit: int = _DEFAULT_LIMIT, **kwargs)
await super().start_io(cmd, *args, limit=limit, **kwargs)
# forward stderr
task = asyncio.create_task(self.pipe_stderr(limit=limit), name="pipe-lsp-stderr")
task.add_done_callback(self._detect_pipe_stderr_exception)
self._async_tasks.append(task)

async def pipe_stderr(self, *, limit: int = _DEFAULT_LIMIT) -> None:
assert (server := self._server) is not None
assert (stderr := server.stderr) is not None

return await self.pipe_stream(stream=stderr, limit=limit)

@staticmethod
async def pipe_stream(*, stream: asyncio.StreamReader, limit: int) -> None:
"""Read lines from the LSP server's stderr and log them.

The lines will be logged in real-time as they arrive, once the newline character is seen. Trailing whitespace
is stripped from each line before logging, and empty lines are ignored.

On EOF any pending line will be logged, even if it is incomplete (i.e. does not end with a newline).

Logs are treated as UTF-8, with invalid byte sequences replaced with the Unicode replacement character.

Long lines will be split into chunks with a maximum length of the limit. If the split falls in the middle of a
multi-byte UTF-8 character, the bytes on either side of the boundary will likely be invalid and logged as such.

Args:
stream: The stream to mirror as logger output.
limit: The maximum number of bytes for a line to be logged as a single line. Longer lines will be split
into chunks and logged as each chunk arrives.
"""
# Maximum size of pending buffer is the limit argument.
pending_buffer = bytearray()

# Loop, reading whatever data is available as it arrives.
while chunk := await stream.read(limit - len(pending_buffer)):
# Process the chunk we've read, line by line.
line_from = 0
while -1 != (idx := chunk.find(b"\n", line_from)):
# Figure out the slice corresponding to this line, accounting for any pending data from the last read.
line_chunk = memoryview(chunk)[line_from:idx]
line_bytes: bytearray | bytes
if pending_buffer:
pending_buffer.extend(line_chunk)
line_bytes = pending_buffer
else:
line_bytes = bytes(line_chunk)
del line_chunk

# Invalid UTF-8 isn't great, but we can at least log it with the replacement character rather than
# dropping it silently or triggering an exception.
message = line_bytes.decode("utf-8", errors="replace").rstrip()
if message:
logger.debug(message)
del line_bytes, message

# Set up for handling the next line of this chunk.
pending_buffer.clear()
line_from = idx + 1
# Anything remaining in this chunk is pending data for the next read.
if remaining := memoryview(chunk)[line_from:]:
pending_buffer.extend(remaining)
if len(pending_buffer) >= limit:
# Line too long, log what we have and reset.
log_now = pending_buffer[:limit]
message = log_now.decode("utf-8", errors="replace").rstrip()
if message:
# Note: the very next character might be a '\n', but we don't know that yet. So might be more
# for this line, might not be.
logger.debug(f"{message}[..?]")
del log_now, message, pending_buffer[:limit]
del remaining
if pending_buffer:
# Here we've hit EOF but have an incomplete line pending. Log it anyway.
message = pending_buffer.decode("utf-8", errors="replace").rstrip()
if message:
logger.debug(f"{message} <missing EOL at EOF>")

def _detect_pipe_stderr_exception(self, task: asyncio.Task) -> None:
if (err := task.exception()) is not None:
logger.critical("An error occurred while processing LSP server output", exc_info=err)
elif not self._stop_event.is_set():
logger.warning("LSP server stderr closed prematurely, no more output will be logged.")
try:
async for line in readlines(stream=stderr, limit=limit):
logger.debug(str(line))
except Exception as e: # pylint: disable=broad-exception-caught
logger.critical("An error occurred while reading LSP server output; now draining.", exc_info=e)
# Drain to prevent blocking of the subprocess if the pipe is unread.
try:
while await stderr.read(limit):
pass
except Exception as drain_error: # pylint: disable=broad-exception-caught
# Exception while draining, situation seems unrecoverable.
logger.warning(
"Uncoverable error draining LSP server output; beware of deadlock.", exc_info=drain_error
)
else:
if not self._stop_event.is_set():
logger.warning("LSP server stderr closed prematurely, no more output will be logged.")
logger.debug("Finished piping stderr from subprocess.")


class ChangeManager(abc.ABC):
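For context, pipe_stderr now delegates the buffering, line-splitting, and UTF-8 handling that the removed pipe_stream implemented by hand to blueprint's readlines async generator, and simply logs each yielded line; on a logging failure it keeps draining the pipe so the subprocess cannot block on a full stderr buffer. A standalone sketch of the same pattern, assuming only what the diff shows (readlines takes stream=/limit= keywords and yields objects whose str() is the decoded line); mirror_stderr and the demo command are hypothetical:

import asyncio
import logging

from databricks.labs.blueprint.logger import readlines

logger = logging.getLogger("lsp-stderr-demo")


async def mirror_stderr(cmd: list[str], *, limit: int = 64 * 1024) -> None:
    # Pipe stderr so the subprocess hands us an asyncio.StreamReader.
    proc = await asyncio.create_subprocess_exec(*cmd, stderr=asyncio.subprocess.PIPE)
    assert proc.stderr is not None
    try:
        # Lines are logged as they arrive; decoding and length-limiting are readlines' job.
        async for line in readlines(stream=proc.stderr, limit=limit):
            logger.debug(str(line))
    finally:
        await proc.wait()


# Usage: asyncio.run(mirror_stderr(["python", "-c", "import sys; print('hi', file=sys.stderr)"]))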
119 changes: 2 additions & 117 deletions tests/unit/transpiler/test_lsp_err.py
@@ -1,4 +1,3 @@
import asyncio
import logging
import re
from collections.abc import AsyncGenerator, Generator, Sequence
@@ -11,7 +10,6 @@
from databricks.labs.lakebridge.config import TranspileConfig
from databricks.labs.lakebridge.transpiler.lsp.lsp_engine import (
LSPEngine,
LakebridgeLanguageClient,
logger as lsp_logger,
)

@@ -21,7 +19,7 @@ class LSPServerLogs:
# The log-level at which the LSP engine writes out stderr lines from the LSP server.
stderr_log_level: ClassVar[int] = logging.DEBUG
# The function name from which the stderr lines are logged, to help filter out other logs.
stderr_log_function: ClassVar[str] = "pipe_stream"
stderr_log_function: ClassVar[str] = "pipe_stderr"

def __init__(self, caplog: pytest.LogCaptureFixture):
self._caplog = caplog
@@ -71,7 +69,7 @@ async def run_lsp_server() -> AsyncGenerator[LSPEngine]:

@pytest.mark.asyncio
async def test_stderr_captured_as_logs(capture_lsp_server_logs: LSPServerLogs) -> None:
"""Verify that output from the LSP engine is captured as logs at INFO level."""
"""Verify that output from the LSP engine is captured as logs."""
# The LSP engine logs a message to stderr when it starts; look for that message in the logs.
with capture_lsp_server_logs.capture():
async with run_lsp_server() as lsp_engine:
@@ -123,116 +121,3 @@ async def test_stderr_with_long_lines(
expected_matcher = re.compile(r"SELECT '(?P<padding>.*?)';", re.DOTALL | re.MULTILINE)
assert (expected_match := expected_matcher.search(log_text)) is not None
assert expected_match.group("padding").count("X") == padding_size


class MockStreamReader(asyncio.StreamReader):
"""Mock asyncio.StreamReader that returns pre-configured chunks."""

def __init__(self, data_chunks: Sequence[bytes]) -> None:
super().__init__()
# Chunks represent data that could be returned on successive reads, mimicking the nature of non-blocking I/O
# where partial data may be returned. The chunk boundaries represent the splits where partial data is returned.
self._remaining_data = data_chunks

async def read(self, n: int = -1) -> bytes:
match n:
case -1:
# Read all remaining data.
data = b"".join(self._remaining_data)
self._remaining_data = []
case 0:
# Empty read.
data = b""
case max_read if max_read > 0:
# Read up to n, but only from the first chunk.
match self._remaining_data:
case []:
data = b""
case [head, *tail]:
if len(head) <= max_read:
data = head
self._remaining_data = tail
else:
data = head[:max_read]
self._remaining_data = [head[max_read:], *tail]
case _:
raise ValueError(f"Unsupported read size: {n}")
return data


async def assert_pipe_stream_chunks_produce_logs(
data_chunks: Sequence[bytes],
expected_messages: Sequence[str],
caplog: pytest.LogCaptureFixture,
*,
limit: int = 128,
) -> None:
stream = MockStreamReader(data_chunks)
with caplog.at_level(LSPServerLogs.stderr_log_level):
await LakebridgeLanguageClient.pipe_stream(stream=stream, limit=limit)

messages = tuple(LSPServerLogs.get_pipe_stream_logs(caplog))
assert messages == expected_messages


async def test_pipe_stream_normal_lines(caplog: pytest.LogCaptureFixture) -> None:
"""Verify the simple case of each line fitting within the limit: one line per log message."""
data_chunks = (b"first line\n", b"second line\n", b"third line\n")
expected_messages = ("first line", "second line", "third line")
await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog)


async def test_pipe_stream_whitespace_handling(caplog: pytest.LogCaptureFixture) -> None:
"""Verify that trailing whitespace is stripped, and that empty log-lines are skipped."""
data_chunks = (b" first line \r\n", b"\tsecond line\t\r\n", b" \t \r\n", b"\n", b"last\tproper\tline\n", b" \t ")
expected_messages = (" first line", "\tsecond line", "last\tproper\tline")
await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog)


@pytest.mark.parametrize(
("data_chunks", "expected_messages"),
(
# Note: limit for all examples is 10.
# Single line split over 2 reads.
((b"1234567", b"89\n"), ("123456789",)),
# Single read, exactly on the limit.
((b"123456789\n",), ("123456789",)),
# Single read, exactly on the minimum limit to trigger premature flush.
((b"1234567890",), ("1234567890[..?]",)),
# Maximum line length.
((b"123456789", b"123456789\n"), ("1234567891[..?]", "23456789")),
# Multiple lines in one read, with existing data from the previous read.
((b"1", b"12\n45\n78\n0", b"12\n"), ("112", "45", "78", "012")),
# A very long line, with some existing data in the buffer, and leaving some remainder.
(
(b"12", b"3456789012" b"3456789012" b"3456789012" b"34567890\n1234"),
("1234567890[..?]", "1234567890[..?]", "1234567890[..?]", "1234567890[..?]", "1234 <missing EOL at EOF>"),
),
),
)
async def test_pipe_stream_line_exceeds_limit(
data_chunks: Sequence[bytes],
expected_messages: Sequence[str],
caplog: pytest.LogCaptureFixture,
) -> None:
"""Verify that line buffering and splitting is handled, including if a line is (much!) longer than the limit."""
await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog, limit=10)


async def test_pipe_stream_incomplete_line_at_eof(caplog: pytest.LogCaptureFixture) -> None:
"""Verify that an incomplete line at EOF is logged."""
data_chunks = (b"normal_line\n", b"incomplete_line")
expected_messages = ("normal_line", "incomplete_line <missing EOL at EOF>")
await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog)


async def test_pipe_stream_invalid_utf8(caplog: pytest.LogCaptureFixture) -> None:
"""Test invalid UTF-8 sequences are replaced with replacement character."""
data_chunks = (
# A line with invalid UTF-8 bytes in it.
b"bad[\xc0\xc0]utf8\n",
# A long line, that will be split across the utf-8 sequence.
"123456789abcd\U0001f596efgh\n".encode("utf-8"),
)
expected_messages = ("bad[\ufffd\ufffd]utf8", "123456789abcd\ufffd[..?]", "\ufffdefgh")
await assert_pipe_stream_chunks_produce_logs(data_chunks, expected_messages, caplog, limit=16)
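The MockStreamReader-based tests above exercised the removed pipe_stream directly; that low-level line-splitting coverage now belongs with blueprint's readlines rather than this repository. If an equivalent smoke test were still wanted here, a rough sketch (test name hypothetical; it assumes readlines accepts a plain asyncio.StreamReader, and it only checks containment because the exact string form of each yielded line is not guaranteed):

import asyncio

import pytest

from databricks.labs.blueprint.logger import readlines


@pytest.mark.asyncio
async def test_readlines_yields_stderr_lines() -> None:
    stream = asyncio.StreamReader()
    stream.feed_data(b"first line\nsecond line\n")
    stream.feed_eof()

    lines = [str(line) async for line in readlines(stream=stream, limit=128)]

    # Containment rather than equality: readlines may or may not strip trailing newlines.
    assert any("first line" in line for line in lines)
    assert any("second line" in line for line in lines)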