Skip to content

feat: add hover information for columns #4362

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/sushi/models/customers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ MODEL (
grain customer_id,
description 'Sushi customer data',
column_descriptions (
customer_id = 'customer_id uniquely identifies customers'
customer_id = 'customer_id uniquely identifies customers',
status = 'status of the customer'
)
);

Expand Down
1 change: 0 additions & 1 deletion package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

66 changes: 66 additions & 0 deletions sqlmesh/lsp/columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import typing as t
from dataclasses import dataclass

from lsprotocol.types import Range

from sqlmesh.core.model.definition import SqlModel
from sqlglot import exp

from sqlmesh.lsp.reference import range_from_token_position_details, TokenPositionDetails


@dataclass
class ColumnDescriptionMap:
range: Range
column_name: str
data_type: t.Optional[str] = None
description: t.Optional[str] = None


def get_columns_and_ranges_for_model(model: SqlModel) -> t.Optional[t.List[ColumnDescriptionMap]]:
"""
Get the top level columns and their position in the file to be able to provide hover information.

If the column information is not available, return None.
"""
type_definitions = model.columns_to_types
columns = model.column_descriptions
query = model.query

if not isinstance(query, exp.Query):
return None

path = model._path
if not path.is_file():
return None
with open(path, "r") as f:
lines = f.readlines()

# Get the top-level columns from the SELECT
outs = []
top_level_columns = query.expressions
for projection in top_level_columns:
if isinstance(projection, exp.Alias):
column = projection.get('alias')
elif isinstance(projection, exp.Column):
column = projection
else:
continue

if not isinstance(column, exp.Column):
continue

column_name = column.name
data_type = type_definitions[column_name] if type_definitions is not None else None
description = columns[column_name] if column_name in columns else None
token_details = TokenPositionDetails.from_meta(column.this.meta)
column_range = range_from_token_position_details(token_details, lines)
column_description_map = ColumnDescriptionMap(
range=column_range,
column_name=column_name,
data_type=str(data_type) if data_type else None,
description=description,
)
outs.append(column_description_map)

return outs
43 changes: 34 additions & 9 deletions sqlmesh/lsp/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,14 @@
from sqlmesh._version import __version__
from sqlmesh.core.context import Context
from sqlmesh.core.linter.definition import AnnotatedRuleViolation
from sqlmesh.core.model import SqlModel
from sqlmesh.lsp.columns import get_columns_and_ranges_for_model
from sqlmesh.lsp.completions import get_sql_completions
from sqlmesh.lsp.context import LSPContext, ModelTarget
from sqlmesh.lsp.custom import ALL_MODELS_FEATURE, AllModelsRequest, AllModelsResponse
from sqlmesh.lsp.reference import (
get_references,
is_position_in_range,
)


Expand Down Expand Up @@ -189,17 +192,39 @@ def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hov
references = get_references(
self.lsp_context, params.text_document.uri, params.position
)
if not references:
if references:
reference = references[0]
if reference.description:
return types.Hover(
contents=types.MarkupContent(
kind=types.MarkupKind.Markdown, value=reference.description
),
range=reference.range,
)

# Try columns if no description is found
models = self.lsp_context.map[params.text_document.uri]
if not models:
return None
reference = references[0]
if not reference.description:
if not isinstance(models, ModelTarget):
return None
return types.Hover(
contents=types.MarkupContent(
kind=types.MarkupKind.Markdown, value=reference.description
),
range=reference.range,
)
model = self.lsp_context.context.get_model(models.names[0])
if not isinstance(model, SqlModel):
return None
columns = get_columns_and_ranges_for_model(model)
if not columns:
return None

for column in columns:
if column.description and is_position_in_range(column.range, params.position):
return types.Hover(
contents=types.MarkupContent(
kind=types.MarkupKind.Markdown,
value=column.description,
),
range=column.range,
)
return None

except Exception as e:
ls.show_message(f"Error getting hover information: {e}", types.MessageType.Error)
Expand Down
29 changes: 15 additions & 14 deletions sqlmesh/lsp/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ class Reference(PydanticModel):
description: t.Optional[str] = None


def is_position_in_range(range: Range, position: Position) -> bool:
return (
range.start.line < position.line
or (range.start.line == position.line and range.start.character <= position.character)
) and (
range.end.line > position.line
or (range.end.line == position.line and range.end.character >= position.character)
)


def by_position(position: Position) -> t.Callable[[Reference], bool]:
"""
Filter reference to only filter references that contain the given position.
Expand All @@ -35,17 +45,8 @@ def by_position(position: Position) -> t.Callable[[Reference], bool]:
A function that returns True if the reference contains the position, False otherwise
"""

def contains_position(r: Reference) -> bool:
return (
r.range.start.line < position.line
or (
r.range.start.line == position.line
and r.range.start.character <= position.character
)
) and (
r.range.end.line > position.line
or (r.range.end.line == position.line and r.range.end.character >= position.character)
)
def contains_position(reference: Reference) -> bool:
return is_position_in_range(reference.range, position)

return contains_position

Expand Down Expand Up @@ -167,15 +168,15 @@ def get_model_definitions_for_a_path(

# Extract metadata for positioning
table_meta = TokenPositionDetails.from_meta(table.this.meta)
table_range = _range_from_token_position_details(table_meta, read_file)
table_range = range_from_token_position_details(table_meta, read_file)
start_pos = table_range.start
end_pos = table_range.end

# If there's a catalog or database qualifier, adjust the start position
catalog_or_db = table.args.get("catalog") or table.args.get("db")
if catalog_or_db is not None:
catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta)
catalog_or_db_range = _range_from_token_position_details(catalog_or_db_meta, read_file)
catalog_or_db_range = range_from_token_position_details(catalog_or_db_meta, read_file)
start_pos = catalog_or_db_range.start

references.append(
Expand Down Expand Up @@ -215,7 +216,7 @@ def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails":
)


def _range_from_token_position_details(
def range_from_token_position_details(
token_position_details: TokenPositionDetails, read_file: t.List[str]
) -> Range:
"""
Expand Down
28 changes: 28 additions & 0 deletions tests/lsp/test_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from sqlmesh import Context
from sqlmesh.core.model.definition import SqlModel
from sqlmesh.lsp.columns import get_columns_and_ranges_for_model
from sqlmesh.lsp.context import LSPContext


def test_get_columns_and_ranges_for_model():
context = Context(paths=["examples/sushi"])
lsp_context = LSPContext(context)

model = lsp_context.context.get_model("sushi.customers")
if not isinstance(model, SqlModel):
raise ValueError("Model is not a SqlModel")

columns = get_columns_and_ranges_for_model(model)
assert columns is not None

assert len(columns) == 3
assert columns[0].column_name == "customer_id"
assert columns[0].description == "customer_id uniquely identifies customers"
assert columns[0].data_type == "INT"
assert columns[0].range is not None
assert columns[0].range.start.line == 27
assert columns[0].range.end.line == 27
assert columns[1].column_name == "status"
assert columns[1].description is None
assert columns[2].column_name == "zip"
assert columns[2].description is None