Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ venv/
dist/
__pycache__/
Pipfile.lock
uv.lock
.ruff_cache/
.vscode
python/example/test.py
Expand Down
6 changes: 6 additions & 0 deletions python/README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
### Usage

#### Features

- Arrow flight server (defaults to port 8815)
- HTTP server (disabled by default)
- Prometheus metrics server (disabled by default)

#### 1. Define your functions in a Python file
```python
from databend_udf import *
Expand Down
89 changes: 83 additions & 6 deletions python/databend_udf/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import json
import time
import logging
import inspect
from concurrent.futures import ThreadPoolExecutor
Expand All @@ -21,6 +22,14 @@
from prometheus_client import start_http_server
import threading

from fastapi import FastAPI, Request, Response
from fastapi.responses import StreamingResponse
from typing import Any, Dict
from uvicorn import run
import pyarrow as pa
from pyarrow import ipc
from io import BytesIO

import pyarrow as pa
from pyarrow.flight import FlightServerBase, FlightInfo

Expand Down Expand Up @@ -216,7 +225,6 @@ def gcd(x, y):
batch_mode=batch_mode,
)


class UDFServer(FlightServerBase):
"""
A server that provides user-defined functions to clients.
Expand All @@ -232,11 +240,19 @@ class UDFServer(FlightServerBase):
_location: str
_functions: Dict[str, UserDefinedFunction]

def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs):
def __init__(
self,
location="0.0.0.0:8815",
metric_location=None,
http_location=None,
**kwargs,
):
super(UDFServer, self).__init__("grpc://" + location, **kwargs)
self._location = location
self._metric_location = metric_location
self._http_location = http_location
self._functions = {}
self.app = FastAPI()

# Initialize Prometheus metrics
self.requests_count = Counter(
Expand Down Expand Up @@ -296,16 +312,42 @@ def _start_metrics_server(self):

def start_server():
start_http_server(port, host)
logger.info(
f"Prometheus metrics server started on {self._metric_location}"
)

metrics_thread = threading.Thread(target=start_server, daemon=True)
metrics_thread.start()
except Exception as e:
logger.error(f"Failed to start metrics server: {e}")
raise

def _start_httpudf_server(self):
    """Launch the FastAPI-based UDF HTTP server on a daemon thread.

    Parses ``self._http_location`` ("host:port"), attaches a request-timing
    middleware plus a root identification endpoint to ``self.app``, and
    serves the app with uvicorn in the background. Raises if the location
    cannot be parsed or the thread cannot be started.
    """
    try:
        host, port_str = self._http_location.split(":")
        port = int(port_str)
        app = self.app

        @app.middleware("http")
        async def log_elapsed_time(request: Request, call_next):
            # Measure wall-clock latency of every request and log it.
            start_time = time.time()
            response = await call_next(request)
            elapsed_time = time.time() - start_time
            logger.info(f"{request.method} {request.url.path} - From {start_time}, Elapsed time: {elapsed_time:.4f} seconds")
            return response

        @app.get("/")
        async def root():
            # Identification endpoint so clients/probes can recognize us.
            return {"protocol": "http", "description": "databend-udf-server"}

        def start_server():
            run(app, host=host, port=port)

        # Daemon thread: the HTTP server must not block process shutdown.
        threading.Thread(target=start_server, daemon=True).start()
    except Exception as e:
        logger.error(f"Failed to start http udf server: {e}")
        raise

def get_flight_info(self, context, descriptor):
"""Return the result schema of a function."""
func_name = descriptor.path[0].decode("utf-8")
Expand Down Expand Up @@ -377,6 +419,39 @@ def add_function(self, udf: UserDefinedFunction):
f"RETURNS {output_type} LANGUAGE python "
f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
)

# HTTP route registration for this UDF: GET /<name> serves the schema,
# POST /<name> (below) evaluates the function.
@self.app.get("/" + name)
async def flight_info():
    # Concatenate the input fields and result fields into one schema --
    # presumably the same shape the Flight get_flight_info endpoint
    # advertises; verify against get_flight_info if changing either.
    full_schema = pa.schema(list(udf._input_schema) + list(udf._result_schema))
    # Opening and immediately closing an IPC stream writer emits only the
    # schema message, which is exactly the payload we want to serve.
    buf = BytesIO()
    writer = pa.ipc.new_stream(buf, full_schema)
    writer.close()
    # Return the serialized schema as a single binary response body
    # (a plain Response, not a streamed one).
    return Response(content=buf.getvalue(), media_type="application/octet-stream")

@self.app.post("/" + name)
async def handle(request: Request):
    """Evaluate this UDF over an Arrow IPC stream POSTed by the client.

    The request body is a serialized Arrow IPC stream of RecordBatches;
    the response streams back an IPC stream of result batches with the
    UDF's result schema.
    """
    # Deserialize the IPC stream from the request body.
    body = await request.body()
    reader = pa.ipc.open_stream(BytesIO(body))

    # Generator that applies the UDF and yields the serialized results.
    async def generate_batches():
        buf = BytesIO()
        writer = pa.ipc.new_stream(buf, udf._result_schema)
        # Fix: iterate over EVERY batch in the request stream. The
        # previous code only consumed the first batch (next(reader)),
        # silently dropping any additional batches the client sent.
        for input_batch in reader:
            for result_batch in udf.eval_batch(input_batch):
                writer.write_batch(result_batch)
        writer.close()
        yield buf.getvalue()

    return StreamingResponse(generate_batches(), media_type="application/octet-stream")

logger.info(f"added function: {name}, SQL:\n{sql}\n")

def serve(self):
Expand All @@ -387,7 +462,9 @@ def serve(self):
logger.info(
f"Prometheus metrics available at http://{self._metric_location}/metrics"
)

if self._http_location:
self._start_httpudf_server()
logger.info(f"UDF HTTP SERVER available at http://{self._http_location}")
super(UDFServer, self).serve()


Expand Down
46 changes: 46 additions & 0 deletions python/example/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import requests
import pyarrow as pa
from pyarrow import ipc
from io import BytesIO


def main():
    """Exercise the UDF HTTP endpoints: fetch the schema of `gcd`, then
    POST a RecordBatch and print the result batches."""
    # Declare the input schema first, then build arrays that match it.
    # Fix: pa.array([...]) infers int64 by default, which conflicts with
    # the declared int32 fields and makes RecordBatch.from_arrays reject
    # the schema -- construct the arrays with an explicit int32 type.
    batch_schema = pa.schema(
        [
            ("a", pa.int32()),
            ("b", pa.int32()),
        ]
    )
    data = [
        pa.array([1, 2, 3, 4], type=pa.int32()),
        pa.array([5, 6, 7, 8], type=pa.int32()),
    ]
    batch = pa.RecordBatch.from_arrays(data, schema=batch_schema)

    # Fetch the function's full (input + result) schema from the server.
    response = requests.get("http://localhost:8818/gcd")
    response.raise_for_status()
    schema_reader = pa.ipc.open_stream(BytesIO(response.content))
    print("schema \n\n", schema_reader.schema)

    # Serialize the RecordBatch as an Arrow IPC stream.
    buf = BytesIO()
    writer = pa.ipc.new_stream(buf, batch.schema)
    writer.write_batch(batch)
    writer.close()
    serialized_batch = buf.getvalue()

    # Send the serialized RecordBatch to the server and decode the reply.
    response = requests.post("http://localhost:8818/gcd", data=serialized_batch)
    response.raise_for_status()
    result_reader = pa.ipc.open_stream(BytesIO(response.content))
    for result_batch in result_reader:
        print("res \n", result_batch)


if __name__ == "__main__":
    main()
17 changes: 14 additions & 3 deletions python/example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import time
from typing import List, Dict, Any, Tuple, Optional

import sys, os

cwd = os.getcwd()
sys.path.append(cwd)

from databend_udf import udf, UDFServer
# from test import udf, UDFServer

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -48,6 +52,7 @@ def bool_select(condition, a, b):
name="gcd",
input_types=["INT", "INT"],
result_type="INT",
io_threads=1,
skip_null=True,
)
def gcd(x: int, y: int) -> int:
Expand Down Expand Up @@ -300,7 +305,6 @@ def return_all_non_nullable(
json,
)


@udf(input_types=["INT"], result_type="INT")
def wait(x):
time.sleep(0.1)
Expand All @@ -312,9 +316,15 @@ def wait_concurrent(x):
time.sleep(0.1)
return x

# batch_mode=True: the handler receives the whole column as one Python
# list and must return a list of the same length (contrast with the
# row-at-a-time `wait` above). The sleep simulates slow batch work.
@udf(input_types=["INT"], result_type="INT", batch_mode=True)
def wait_batch(x: List[int]) -> List[int]:
    time.sleep(0.1)
    return x

if __name__ == "__main__":
udf_server = UDFServer("0.0.0.0:8815", metric_location="0.0.0.0:8816")
udf_server = UDFServer(
"0.0.0.0:8815", metric_location="0.0.0.0:8816", http_location="0.0.0.0:8818"
)
udf_server.add_function(add_signed)
udf_server.add_function(add_unsigned)
udf_server.add_function(add_float)
Expand All @@ -338,4 +348,5 @@ def wait_concurrent(x):
udf_server.add_function(return_all_non_nullable)
udf_server.add_function(wait)
udf_server.add_function(wait_concurrent)
udf_server.add_function(wait_batch)
udf_server.serve()
8 changes: 7 additions & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ readme = "README.md"
requires-python = ">=3.7"
dependencies = [
"pyarrow",
"prometheus-client>=0.17.0"
"prometheus-client>=0.17.0",
"fastapi>=0.103.2",
"uvicorn>=0.22.0",
]

# NOTE: `[dev-dependencies]` is not a valid PEP 621 pyproject table
# (it is a Cargo/Poetry convention); use PEP 735 dependency groups.
[dependency-groups]
dev = ["requests>=2.31.0"]

[project.optional-dependencies]
lint = ["ruff"]

Expand Down
Loading