Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 20 additions & 25 deletions quarkchain/cluster/jsonrpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,7 @@
import websockets
import rlp
from aiohttp import web
from async_armor import armor
from decorator import decorator
from jsonrpcserver import config
from jsonrpcserver.async_methods import AsyncMethods
from jsonrpcserver.exceptions import InvalidParams, InvalidRequest, ServerError

from quarkchain.cluster.master import MasterServer
from quarkchain.cluster.rpc import AccountBranchData
Expand All @@ -38,6 +34,7 @@
import uuid
from quarkchain.cluster.log_filter import LogFilter
from quarkchain.cluster.subscription import SUB_LOGS
from quarkchain.cluster.jsonrpc_server import RpcMethods, InvalidParams

# defaults
DEFAULT_STARTGAS = 100 * 1000
Expand All @@ -47,13 +44,9 @@
# TODO: revisit this parameter
JSON_RPC_CLIENT_REQUEST_MAX_SIZE = 16 * 1024 * 1024

# Disable jsonrpcserver logging
config.log_requests = False
config.log_responses = False

EMPTY_TX_ID = "0x" + "0" * Constant.TX_ID_HEX_LENGTH


def quantity_decoder(hex_str, allow_optional=False):
"""Decode `hexStr` representing a quantity."""
if allow_optional and hex_str is None:
Expand Down Expand Up @@ -463,8 +456,8 @@ def _parse_log_request(
return addresses, topics


public_methods = AsyncMethods()
private_methods = AsyncMethods()
public_methods = RpcMethods()
private_methods = RpcMethods()


# noinspection PyPep8Naming
Expand Down Expand Up @@ -495,7 +488,7 @@ def start_private_server(cls, env, master_server):

@classmethod
def start_test_server(cls, env, master_server):
methods = AsyncMethods()
methods = RpcMethods()
for method in public_methods.values():
methods.add(method)
for method in private_methods.values():
Expand All @@ -511,7 +504,7 @@ def start_test_server(cls, env, master_server):
return server

def __init__(
self, env, master_server: MasterServer, port, host, methods: AsyncMethods
self, env, master_server: MasterServer, port, host, methods: RpcMethods
):
self.loop = asyncio.get_event_loop()
self.port = port
Expand All @@ -521,7 +514,7 @@ def __init__(
self.counters = dict()

# Bind RPC handler functions to this instance
self.handlers = AsyncMethods()
self.handlers = RpcMethods()
for rpc_name in methods:
func = methods[rpc_name]
self.handlers[rpc_name] = func.__get__(self, self.__class__)
Expand All @@ -540,14 +533,14 @@ async def __handle(self, request):
self.counters[method] += 1
else:
self.counters[method] = 1
# Use armor to prevent the handler from being cancelled when
# Use asyncio.shield to prevent the handler from being cancelled when
# aiohttp server loses connection to client
response = await armor(self.handlers.dispatch(request))
response = await asyncio.shield(self.handlers.dispatch(d))
if response is None:
return web.Response()
if "error" in response:
Logger.error(response)
if response.is_notification:
return web.Response()
return web.json_response(response, status=response.http_status)
return web.json_response(response)

def start(self):
app = web.Application(client_max_size=JSON_RPC_CLIENT_REQUEST_MAX_SIZE)
Expand Down Expand Up @@ -1464,7 +1457,7 @@ def start_websocket_server(cls, env, slave_server):
return server

def __init__(
self, env, slave_server: SlaveServer, port, host, methods: AsyncMethods
self, env, slave_server: SlaveServer, port, host, methods: RpcMethods
):
self.loop = asyncio.get_event_loop()
self.port = port
Expand All @@ -1475,14 +1468,14 @@ def __init__(
self.pending_tx_cache = LRUCache(maxsize=1024)

# Bind RPC handler functions to this instance
self.handlers = AsyncMethods()
self.handlers = RpcMethods()
for rpc_name in methods:
func = methods[rpc_name]
self.handlers[rpc_name] = func.__get__(self, self.__class__)

self.shard_subscription_managers = self.slave.shard_subscription_managers

async def __handle(self, websocket, path):
async def __handle(self, websocket):
sub_ids = dict() # per-websocket var, Dict[sub_id, full_shard_id]
try:
async for message in websocket:
Expand All @@ -1501,14 +1494,16 @@ async def __handle(self, websocket, path):
msg_id = d.get("id", 0)

response = await self.handlers.dispatch(
message,
d,
context={
"websocket": websocket,
"msg_id": msg_id,
"sub_ids": sub_ids,
},
)

if response is None:
continue
if "error" in response:
Logger.error(response)
else:
Expand All @@ -1519,8 +1514,7 @@ async def __handle(self, websocket, path):
elif method == "unsubscribe":
sub_id = d.get("params")[0]
del sub_ids[sub_id]
if not response.is_notification:
await websocket.send(json.dumps(response))
await websocket.send(json.dumps(response))
finally: # current websocket connection terminates, remove subscribers in this connection
for sub_id, full_shard_id in sub_ids.items():
try:
Expand All @@ -1536,7 +1530,8 @@ def start(self):
self.loop.run_until_complete(start_server)

def shutdown(self):
pass # TODO
if hasattr(self, '_server') and self._server is not None:
self._server.close()

@staticmethod
def response_transcoder(sub_id, result):
Expand Down
173 changes: 173 additions & 0 deletions quarkchain/cluster/jsonrpc_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import inspect
import logging
from typing import Any, Callable, Dict, Optional, Awaitable

from aiohttp import web

logger = logging.getLogger(__name__)


class JsonRpcError(Exception):
code = -32000
message = "Server error"

def __init__(self, message=None, data=None):
super().__init__(message or self.message)
self.message = message or self.message
self.data = data

def to_dict(self):
error = {
"code": self.code,
"message": self.message,
}
if self.data is not None:
error["data"] = self.data
return error

class InvalidRequest(JsonRpcError):
code = -32600
message = "Invalid Request"

class MethodNotFound(JsonRpcError):
code = -32601
message = "Method not found"

class InvalidParams(JsonRpcError):
code = -32602
message = "Invalid params"

class ServerError(JsonRpcError):
code = -32000
message = "Server error"


class RpcMethods:
def __init__(self):
self._methods: Dict[str, Callable[..., Awaitable[Any]]] = {}

# ========== dict ==========
def __iter__(self):
return iter(self._methods)

def __getitem__(self, key):
return self._methods[key]

def __setitem__(self, key, value):
self._methods[key] = value

def items(self):
return self._methods.items()

def keys(self):
return self._methods.keys()

def values(self):
return self._methods.values()

# ========== decorator ==========
def add(self, func: Callable[..., Awaitable[Any]] = None, *, name: str = None):
"""
Usage:

@methods.add
async def foo(...):

or:

@methods.add(name="customName")
async def foo(...):
"""
if func is None:
def wrapper(f):
method_name = name or f.__name__
self._methods[method_name] = f
return f
return wrapper

method_name = name or func.__name__
self._methods[method_name] = func
return func

async def dispatch(self, request_json: Dict[str, Any], context=None) -> Optional[Dict[str, Any]]:
req_id = None

try:
if not isinstance(request_json, dict):
raise InvalidRequest("Request must be object")

req_id = request_json.get("id")

if request_json.get("jsonrpc") != "2.0":
raise InvalidRequest("Invalid JSON-RPC version")

method = request_json.get("method")
if not isinstance(method, str):
raise InvalidRequest("Method must be string")

is_notification = "id" not in request_json

if method not in self._methods:
raise MethodNotFound()

handler = self._methods[method]
params = request_json.get("params", [])

# Check if handler accepts a context parameter
sig = inspect.signature(handler)
pass_context = context is not None and "context" in sig.parameters

if isinstance(params, list):
result = await handler(*params, context=context) if pass_context else await handler(*params)
elif isinstance(params, dict):
result = await handler(**params, context=context) if pass_context else await handler(**params)
else:
raise InvalidParams()

if is_notification:
return None

return {
"jsonrpc": "2.0",
"result": result,
"id": req_id,
}

except JsonRpcError as e:
return {
"jsonrpc": "2.0",
"error": e.to_dict(),
"id": req_id,
}

except TypeError as e:
# Could be missing/extra arguments → treat as invalid params
return {
"jsonrpc": "2.0",
"error": {
"code": -32602,
"message": str(e),
},
"id": req_id,
}
except Exception:
logger.exception("Internal JSON-RPC error for method %s", method)
return {
"jsonrpc": "2.0",
"error": {
"code": -32603,
"message": "Internal error",
},
"id": req_id,
}

async def aiohttp_handler(self, request: web.Request) -> web.Response:
body = await request.json()

# support batch
if isinstance(body, list):
responses = [await self.dispatch(item) for item in body]
return web.json_response(responses)

response = await self.dispatch(body)
return web.json_response(response)
9 changes: 1 addition & 8 deletions quarkchain/cluster/prom.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,6 @@
print("======")
raise e

import jsonrpcclient

# Disable jsonrpcclient verbose logging.
logging.getLogger("jsonrpcclient.client.request").setLevel(logging.WARNING)
logging.getLogger("jsonrpcclient.client.response").setLevel(logging.WARNING)

TIMEOUT = 10
fetcher = None
Expand Down Expand Up @@ -54,9 +49,7 @@ def get_highest() -> int:
global fetcher
assert isinstance(fetcher, Fetcher)

res = fetcher.cli.send(
jsonrpcclient.Request("getRootBlockByHeight"), timeout=TIMEOUT
)
res = fetcher.cli.call("getRootBlockByHeight")
if not res:
raise RuntimeError("Failed to get latest block height")
return int(res["height"], 16)
Expand Down
7 changes: 3 additions & 4 deletions quarkchain/cluster/subscription.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import asyncio
import json
from typing import List, Dict, Tuple, Optional, Callable
from typing import Any, List, Dict, Tuple, Optional, Callable

from jsonrpcserver.exceptions import InvalidParams
from websockets import WebSocketServerProtocol
from quarkchain.cluster.jsonrpc_server import InvalidParams

from quarkchain.core import MinorBlock

Expand All @@ -20,7 +19,7 @@ def __init__(self):
SUB_NEW_PENDING_TX: {},
SUB_LOGS: {},
SUB_SYNC: {},
} # type: Dict[str, Dict[str, WebSocketServerProtocol]]
} # type: Dict[str, Dict[str, Any]]
self.log_filter_gen = {} # type: Dict[str, Callable]

def add_subscriber(self, sub_type, sub_id, conn, extra=None):
Expand Down
Loading
Loading