From 203ee6a4cb7f47a9b711ee8aec5ba1794d615d73 Mon Sep 17 00:00:00 2001
From: Sourcery AI
Date: Fri, 11 Mar 2022 01:46:23 +0000
Subject: [PATCH] 'Refactored by Sourcery'

---
 compile.py | 9 ++-
 .../echo_asyncio/asyncio_protocol_client.py | 11 +--
 .../echo_asyncio/asyncio_protocol_server.py | 2 +-
 .../echo_asyncio/asyncio_stream_client.py | 8 +-
 .../echo_asyncio/asyncio_stream_server.py | 1 -
 examples/echo_pymaid/datagram_client.py | 9 +--
 examples/echo_ws/client.py | 9 +--
 examples/heartbeat/client.py | 7 +-
 examples/net/client.py | 7 +-
 examples/pb/client.py | 7 +-
 examples/pb/stub.py | 20 +----
 examples/pb/ws_client.py | 9 +--
 plugin/js/jsimpl.py | 21 +++---
 plugin/js/jsrpc.py | 15 ++--
 pymaid/cli/parser.py | 3 +-
 pymaid/conf/backend.py | 3 +-
 pymaid/error/base.py | 35 +++------
 pymaid/net/http/h11.py | 9 ---
 pymaid/net/raw.py | 17 ++---
 pymaid/net/utils/uri.py | 2 +-
 pymaid/net/ws/protocol.py | 75 ++++++++++---------
 pymaid/rpc/context.py | 5 +-
 pymaid/rpc/pb/router.py | 24 ++----
 pymaid/rpc/router.py | 2 +-
 pymaid/utils/autoreload.py | 7 +-
 pymaid/utils/hash.py | 14 ++--
 pymaid/utils/timeout.py | 5 +-
 27 files changed, 141 insertions(+), 195 deletions(-)

diff --git a/compile.py b/compile.py
index 26e5450..683235a 100644
--- a/compile.py
+++ b/compile.py
@@ -52,9 +52,12 @@ def parse_args():
 def get_protos(path):
     protos = []
     for root, dirnames, filenames in os.walk(path):
-        for filename in filenames:
-            if filename.endswith('.proto'):
-                protos.append(os.path.join(root, filename))
+        protos.extend(
+            os.path.join(root, filename)
+            for filename in filenames
+            if filename.endswith('.proto')
+        )
+
     return protos
diff --git a/examples/echo_asyncio/asyncio_protocol_client.py b/examples/echo_asyncio/asyncio_protocol_client.py
index 409fdea..80095c8 100644
--- a/examples/echo_asyncio/asyncio_protocol_client.py
+++ b/examples/echo_asyncio/asyncio_protocol_client.py
@@ -10,7 +10,7 @@ def __init__(self):
 
     def connection_made(self, transport):
         sock = transport.get_extra_info('socket')
-        args.debug('Connection to {}'.format(sock))
+        args.debug(f'Connection to {sock}')
         self.transport = transport
 
     def data_received(self, data):
@@ -38,7 +38,7 @@ async def wrapper(loop, address, count):
     write = transport.write
     req = b'a' * args.msize
    receive_event = protocol.receive_event = asyncio.Event()
-    for x in range(count):
+    for _ in range(count):
         write(req)
         await receive_event.wait()
         receive_event.clear()
@@ -52,12 +52,9 @@ async def main():
     global args
     args = parse_args(get_client_parser())
     loop = asyncio.get_running_loop()
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(asyncio.create_task(
+    tasks = [asyncio.create_task(
             wrapper(loop, args.address, args.request)
-        ))
-
+    ) for _ in range(args.concurrency)]
     await asyncio.gather(*tasks)
diff --git a/examples/echo_asyncio/asyncio_protocol_server.py b/examples/echo_asyncio/asyncio_protocol_server.py
index 12dfa3b..76f5348 100644
--- a/examples/echo_asyncio/asyncio_protocol_server.py
+++ b/examples/echo_asyncio/asyncio_protocol_server.py
@@ -7,7 +7,7 @@ class EchoProtocol(asyncio.Protocol):
 
     def connection_made(self, transport):
         sock = transport.get_extra_info('socket')
-        args.debug('Connection from {}'.format(sock))
+        args.debug(f'Connection from {sock}')
         self.transport = transport
 
     def data_received(self, data):
diff --git a/examples/echo_asyncio/asyncio_stream_client.py b/examples/echo_asyncio/asyncio_stream_client.py
index 194fb0c..5c9de1f 100644
--- a/examples/echo_asyncio/asyncio_stream_client.py
+++ b/examples/echo_asyncio/asyncio_stream_client.py
@@ -14,7 +14,6 @@ async def wrapper(address, count):
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
     except (OSError, NameError):
         args.debug('set nodelay failed')
-        pass
 
     args.debug(f'[conn][reader|{reader}][writer|{writer}] connected')
     req = b'a' * args.msize
@@ -35,9 +34,10 @@ async def main():
     global args
     args = parse_args(get_client_parser())
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(asyncio.create_task(wrapper(args.address, args.request)))
+    tasks = [
+        asyncio.create_task(wrapper(args.address, args.request))
+        for _ in range(args.concurrency)
+    ]
     await asyncio.gather(*tasks)
diff --git a/examples/echo_asyncio/asyncio_stream_server.py b/examples/echo_asyncio/asyncio_stream_server.py
index 2ea39fb..da7acf8 100644
--- a/examples/echo_asyncio/asyncio_stream_server.py
+++ b/examples/echo_asyncio/asyncio_stream_server.py
@@ -11,7 +11,6 @@ async def handler(reader, writer):
         sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
     except (OSError, NameError) as ex:
         args.debug(f'set nodelay failed {ex}')
-        pass
     read, write = reader.read, writer.write
     while 1:
         data = await read(256 * 1024)
diff --git a/examples/echo_pymaid/datagram_client.py b/examples/echo_pymaid/datagram_client.py
index 82bb0b6..1ef9153 100644
--- a/examples/echo_pymaid/datagram_client.py
+++ b/examples/echo_pymaid/datagram_client.py
@@ -29,7 +29,7 @@ async def wrapper(loop, address, count):
     write = transport.sendto
     req = b'a' * args.msize
     receive_event = protocol.receive_event = pymaid.Event()
-    for x in range(count):
+    for _ in range(count):
         write(req)
         await receive_event.wait()
         receive_event.clear()
@@ -42,12 +42,9 @@ async def main():
     global args
     args = parse_args(get_client_parser())
     loop = pymaid.get_event_loop()
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(pymaid.create_task(
+    tasks = [pymaid.create_task(
             wrapper(loop, args.address, args.request)
-        ))
-
+    ) for _ in range(args.concurrency)]
     await pymaid.gather(*tasks)
diff --git a/examples/echo_ws/client.py b/examples/echo_ws/client.py
index 609e418..e7248d0 100644
--- a/examples/echo_ws/client.py
+++ b/examples/echo_ws/client.py
@@ -29,11 +29,10 @@ async def wrapper(address, count, msize):
 
 async def main():
     args = parse_args(get_client_parser())
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(
-            pymaid.create_task(wrapper(args.address, args.request, args.msize))
-        )
+    tasks = [
+        pymaid.create_task(wrapper(args.address, args.request, args.msize))
+        for _ in range(args.concurrency)
+    ]
     # await pymaid.wait(tasks, timeout=args.timeout)
     await pymaid.gather(*tasks)
diff --git a/examples/heartbeat/client.py b/examples/heartbeat/client.py
index 1808095..6b9ee32 100644
--- a/examples/heartbeat/client.py
+++ b/examples/heartbeat/client.py
@@ -11,9 +11,10 @@ async def wrapper(address):
 
 async def main():
     args = parse_args(get_client_parser())
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(pymaid.create_task(wrapper(args.address)))
+    tasks = [
+        pymaid.create_task(wrapper(args.address))
+        for _ in range(args.concurrency)
+    ]
     # await pymaid.wait(tasks, timeout=args.timeout)
     await pymaid.gather(*tasks)
diff --git a/examples/net/client.py b/examples/net/client.py
index cdd868e..506f2fb 100644
--- a/examples/net/client.py
+++ b/examples/net/client.py
@@ -25,9 +25,10 @@ async def wrapper(address, count):
 
 async def main():
     args = parse_args(get_client_parser())
-    tasks = []
-    for x in range(args.concurrency):
-        tasks.append(pymaid.create_task(wrapper(args.address, args.request)))
+    tasks = [
+        pymaid.create_task(wrapper(args.address, args.request))
+        for _ in range(args.concurrency)
+    ]
     # await pymaid.wait(tasks, timeout=args.timeout)
     await pymaid.gather(*tasks)
diff --git a/examples/pb/client.py b/examples/pb/client.py
index 4932678..e3c6d0d 100644
--- a/examples/pb/client.py
+++ b/examples/pb/client.py
@@ -10,11 +10,12 @@ async def main():
     args = parse_args(get_client_parser())
     service = pymaid.rpc.pb.router.PBRouterStub(EchoService_Stub)
-    tasks = []
     address = args.address
     request = args.request
-    for x in range(args.concurrency):
-        tasks.append(pymaid.create_task(worker(address, service, request)))
+    tasks = [
+        pymaid.create_task(worker(address, service, request))
+        for _ in range(args.concurrency)
+    ]
     # await pymaid.wait(tasks, timeout=args.timeout)
     await pymaid.gather(*tasks)
diff --git a/examples/pb/stub.py b/examples/pb/stub.py
index c297e5b..8f0f543 100644
--- a/examples/pb/stub.py
+++ b/examples/pb/stub.py
@@ -13,7 +13,7 @@ async def get_requests():
 
 async def worker(address, service, count, **kwargs):
     conn = await pymaid.rpc.pb.dial_stream(address, **kwargs)
-    for x in range(count):
+    for _ in range(count):
         # UnaryUnaryEcho
         resp = await service.UnaryUnaryEcho(request, conn=conn)
         assert len(resp.message) == 8000
@@ -55,24 +55,6 @@ async def worker(address, service, count, **kwargs):
         async for resp in service.StreamStreamEcho(get_requests(), conn=conn):
             assert len(resp.message) == 8000
 
-        # # This block performs the same STREAM_STREAM interaction as above
-        # # while showing more advanced stream control features.
-        # async with service.StreamStreamEcho.open(conn=conn) as context:
-        #     async for req in get_requests():
-        #         await context.send_message(request)
-        #         # you can still do something here
-        #     resp = await context.recv_message()
-        #     assert len(resp.message) == 8000
-        #     # or you can send requests first, then wait for responses
-        #     async for req in get_requests():
-        #         await context.send_message(request)
-        #         # you can still do something here
-        #     async for resp in context:
-        #         # you can still do something here
-        #         assert len(resp.message) == 8000
-        #     # you can send end message yourself
-        #     # or let context handle this at cleanup for you
-        #     await context.send_message(end=True)
     conn.shutdown()
     conn.close()
     await conn.wait_closed()
diff --git a/examples/pb/ws_client.py b/examples/pb/ws_client.py
index 64f787a..886c0c4 100644
--- a/examples/pb/ws_client.py
+++ b/examples/pb/ws_client.py
@@ -13,21 +13,16 @@ async def main():
     args = parse_args(get_client_parser())
     service = pymaid.rpc.pb.router.PBRouterStub(EchoService_Stub)
-    tasks = []
     address = args.address
     request = args.request
-    for x in range(args.concurrency):
-        tasks.append(
-            pymaid.create_task(
+    tasks = [pymaid.create_task(
                 worker(
                     address,
                     service,
                     request,
                     transport_class=WebSocket | Connection,
                 )
-            )
-        )
-
+    ) for _ in range(args.concurrency)]
     # await pymaid.wait(tasks, timeout=args.timeout)
     await pymaid.gather(*tasks)
diff --git a/plugin/js/jsimpl.py b/plugin/js/jsimpl.py
index 5ca4e24..06af5b1 100644
--- a/plugin/js/jsimpl.py
+++ b/plugin/js/jsimpl.py
@@ -89,9 +89,12 @@ def parse_args():
 def get_modules(root_path):
     modules = []
     for root, dirnames, filenames in os.walk(root_path):
-        for filename in filenames:
-            if filename.endswith('_pb2.py'):
-                modules.append(os.path.join(root, filename))
+        modules.extend(
+            os.path.join(root, filename)
+            for filename in filenames
+            if filename.endswith('_pb2.py')
+        )
+
     print('modules', modules)
     return modules
@@ -116,7 +119,7 @@ def extra_message(message, indent=' '):
         text = f'{indent}{field.name}: {LABELS[field.label]} '
         if field.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
             fields.append(text + field.message_type.name)
-            fields.extend(extra_message(field.message_type, indent + ' '))
+            fields.extend(extra_message(field.message_type, f'{indent} '))
         else:
             fields.append(text + TYPES[field.type])
     # print (fields)
@@ -130,12 +133,12 @@ def generate_jsimpl(service_descriptor, package, prefix):
     service_name = service_descriptor.name
     print(f'generating {service_descriptor.full_name}')
     for method in service_descriptor.methods:
-        req = star_indent + 'req: ' + method.input_type.name + star_indent
+        req = f'{star_indent}req: {method.input_type.name}{star_indent}'
         req += star_indent.join(extra_message(method.input_type))
-        resp = star_indent + 'resp: ' + method.output_type.name + star_indent
+        resp = f'{star_indent}resp: {method.output_type.name}{star_indent}'
         resp += star_indent.join(extra_message(method.output_type))
-        input_type = prefix + '.' + method.input_type.full_name
-        output_type = prefix + '.' + method.output_type.full_name
+        input_type = f'{prefix}.{method.input_type.full_name}'
+        output_type = f'{prefix}.{method.output_type.full_name}'
         requires.update([REQUIRE_TEMPLATE.safe_substitute(name=input_type),
                          REQUIRE_TEMPLATE.safe_substitute(name=output_type)])
         in_out_types.extend([
@@ -175,7 +178,7 @@ def generate(path, output, package, prefix, root):
         if not os.path.exists(output_path):
             os.makedirs(output_path)
         file_path = os.path.join(output_path, splits[-1][:-7])
-        with open(file_path + '_broadcast.js', 'w') as fp:
+        with open(f'{file_path}_broadcast.js', 'w') as fp:
             fp.write(content)
diff --git a/plugin/js/jsrpc.py b/plugin/js/jsrpc.py
index d1a17fb..7a9159e 100644
--- a/plugin/js/jsrpc.py
+++ b/plugin/js/jsrpc.py
@@ -56,9 +56,12 @@ def parse_args():
 def get_modules(root_path):
     modules = []
     for root, dirnames, filenames in os.walk(root_path):
-        for filename in filenames:
-            if filename.endswith('_pb2.py'):
-                modules.append(os.path.join(root, filename))
+        modules.extend(
+            os.path.join(root, filename)
+            for filename in filenames
+            if filename.endswith('_pb2.py')
+        )
+
     print('modules', modules)
     return modules
@@ -83,8 +86,8 @@ def generate_js_rpc(service_descriptor, package, prefix):
     service_name = service_descriptor.name
     print('generating %s' % service_descriptor.full_name)
     for method in service_descriptor.methods:
-        input_type = prefix + '.' + method.input_type.full_name
-        output_type = prefix + '.' + method.output_type.full_name
+        input_type = f'{prefix}.{method.input_type.full_name}'
+        output_type = f'{prefix}.{method.output_type.full_name}'
         requires.update([REQUIRE_TEMPLATE.safe_substitute(name=input_type),
                          REQUIRE_TEMPLATE.safe_substitute(name=output_type)])
         mstr = METHOD_TEMPLATE.safe_substitute(
@@ -112,7 +115,7 @@ def generate(path, output, package, prefix, root):
         if not os.path.exists(output_path):
             os.makedirs(output_path)
         file_path = os.path.join(output_path, splits[-1][:-7])
-        with open(file_path + '_rpc.js', 'w') as fp:
+        with open(f'{file_path}_rpc.js', 'w') as fp:
             fp.write(content)
diff --git a/pymaid/cli/parser.py b/pymaid/cli/parser.py
index 3979df6..1533e61 100644
--- a/pymaid/cli/parser.py
+++ b/pymaid/cli/parser.py
@@ -41,8 +41,7 @@ def on_parse_callback(self, args):
         if self.on_parse:
             self.on_parse(args)
         if self.subparsers:
-            subcmd = self.get_subcmd(args)
-            if subcmd:
+            if subcmd := self.get_subcmd(args):
                 sub_parser = self.subparsers._name_parser_map[subcmd]
                 sub_parser.on_parse_callback(args)
diff --git a/pymaid/conf/backend.py b/pymaid/conf/backend.py
index 2e5108b..c92e997 100644
--- a/pymaid/conf/backend.py
+++ b/pymaid/conf/backend.py
@@ -137,8 +137,7 @@ async def run(self):
             for item in update:
                 ns = item['namespaceName'].rsplit('.', 1)[0]
                 nid = item['notificationId']
-                data = get_data(ns, self.subscriptions[ns]['format'])
-                if data:
+                if data := get_data(ns, self.subscriptions[ns]['format']):
                     self.subscriptions[ns]['notificationId'] = nid
                     delta[ns] = data
diff --git a/pymaid/error/base.py b/pymaid/error/base.py
index ff03fc8..c7c787f 100644
--- a/pymaid/error/base.py
+++ b/pymaid/error/base.py
@@ -10,10 +10,7 @@ class BaseEx(Exception, metaclass=abc.ABCMeta):
     message = 'BaseEx'
 
     def __init__(self, *args, **kwargs):
-        if '_message_' in kwargs:
-            message = kwargs.pop('_message_')
-        else:
-            message = self.message
+        message = kwargs.pop('_message_') if '_message_' in kwargs else self.message
         if args or kwargs:
             message = message.format(*args, **kwargs)
         self.message = message
@@ -28,29 +25,25 @@ def wraps(cls, target: Exception):
 
 class Error(BaseEx):
     def __str__(self):
-        return '[ERROR][code|{}][message|{}][data|{}]'.format(
-            self.code, self.message, self.data
-        )
+        return f'[ERROR][code|{self.code}][message|{self.message}][data|{self.data}]'
 
     __repr__ = __str__
 
     def __bytes__(self):
-        return '[ERROR][code|{}][message|{}][data|{}]'.format(
-            self.code, self.message, self.data
-        ).encode('utf-8')
+        return f'[ERROR][code|{self.code}][message|{self.message}][data|{self.data}]'.encode(
+            'utf-8'
+        )
 
 class Warning(BaseEx):
     def __str__(self):
-        return '[WARN][code|{}][message|{}][data|{}]'.format(
-            self.code, self.message, self.data
-        )
+        return f'[WARN][code|{self.code}][message|{self.message}][data|{self.data}]'
 
     __repr__ = __str__
 
     def __bytes__(self):
-        return '[WARN][code|{}][message|{}][data|{}]'.format(
-            self.code, self.message, self.data
-        ).encode('utf-8')
+        return f'[WARN][code|{self.code}][message|{self.message}][data|{self.data}]'.encode(
+            'utf-8'
+        )
 
 
 class ErrorManager(metaclass=abc.ABCMeta):
@@ -79,10 +72,7 @@ def add_manager(cls, name, manager):
 
     @classmethod
     def add_error(cls, name, message, *, code=None):
         frame = getframe(1)  # get caller frame
-        if cls.__fullname__:
-            fullname = f'{cls.__fullname__}.{name}'
-        else:
-            fullname = name
+        fullname = f'{cls.__fullname__}.{name}' if cls.__fullname__ else name
         error = type(
             name, (Error, cls), {
@@ -99,10 +89,7 @@ def add_warning(cls, name, message, *, code=None):
         frame = getframe(1)  # get caller frame
-        if cls.__fullname__:
-            fullname = f'{cls.__fullname__}.{name}'
-        else:
-            fullname = name
+        fullname = f'{cls.__fullname__}.{name}' if cls.__fullname__ else name
         warning = type(
             name,
             (Warning, cls),
diff --git a/pymaid/net/http/h11.py b/pymaid/net/http/h11.py
index 1add0cb..57f95c2 100644
--- a/pymaid/net/http/h11.py
+++ b/pymaid/net/http/h11.py
@@ -102,15 +102,6 @@ def append_header(self, name: str, value: str):
         # TODO: what to do with `Transfer-Encoding: chunked`
         if name in self.HEADER_SINGLETON and name in self.headers:
             return
-            raise HttpError.BadRequest(
-                _message_='multiple value for singleton header',
-                data={
-                    'name': name,
-                    'value': value,
-                    'in_header': self.headers.getall(name),
-                },
-            )
-
         if name in self.HEADER_MUST_HAVE_VALUE and not value:
             raise HttpError.BadRequest(
                 _message_='header that must have value get empty value',
diff --git a/pymaid/net/raw.py b/pymaid/net/raw.py
index 1a622cc..45397ef 100644
--- a/pymaid/net/raw.py
+++ b/pymaid/net/raw.py
@@ -76,10 +76,10 @@ async def getaddrinfo(
 
 
 def set_sock_options(sock: socket.socket):
-    setsockopt = sock.setsockopt
-
     # stream opts
     if sock.type == socket.SOCK_STREAM and sock.family != socket.AF_UNIX:
+        setsockopt = sock.setsockopt
+
         setsockopt(socket.SOL_TCP, socket.TCP_NODELAY, 1)
         setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
 
@@ -132,14 +132,13 @@ async def sock_connect(
             sock.close()
             break
 
-    if err is not None:
-        try:
-            raise err
-        finally:
-            # Break explicitly a reference cycle
-            err = None
-    else:
+    if err is None:
         raise socket.error('getaddrinfo returns an empty list')
+    try:
+        raise err
+    finally:
+        # Break explicitly a reference cycle
+        err = None
 
 
 async def sock_listen(
diff --git a/pymaid/net/utils/uri.py b/pymaid/net/utils/uri.py
index cc8921e..73265fe 100755
--- a/pymaid/net/utils/uri.py
+++ b/pymaid/net/utils/uri.py
@@ -70,7 +70,7 @@ def parse_uri(uri: str) -> URI:
     query = parsed.query
     fragment = parsed.fragment
 
-    if 'unix' == scheme:
+    if scheme == 'unix':
         port = None
         # when using unix domain socket, assume path is the address
         if host:
diff --git a/pymaid/net/ws/protocol.py b/pymaid/net/ws/protocol.py
index 07f6b49..1c65c0e 100644
--- a/pymaid/net/ws/protocol.py
+++ b/pymaid/net/ws/protocol.py
@@ -253,8 +253,10 @@ def build_request(cls, hostname, resource, key, **headers):
 
     @classmethod
     def build_response(cls, headers: CIMultiDict) -> bytes:
-        if ('websocket' != headers.get('Upgrade', '').lower()
-                or 'upgrade' != headers.get('Connection', '').lower()):
+        if (
+            headers.get('Upgrade', '').lower() != 'websocket'
+            or headers.get('Connection', '').lower() != 'upgrade'
+        ):
             raise ProtocolError(
                 f'invalid websocket handshake header: {headers}'
             )
@@ -274,8 +276,10 @@
 
     @classmethod
     def validate_upgrade(cls, headers: CIMultiDict, upgrade_key: bytes):
-        if ('websocket' != headers.get('Upgrade', '').lower()
-                or 'upgrade' != headers.get('Connection', '').lower()):
+        if (
+            headers.get('Upgrade', '').lower() != 'websocket'
+            or headers.get('Connection', '').lower() != 'upgrade'
+        ):
             raise ProtocolError(
                 f'invalid websocket handshake header: {headers}'
             )
@@ -335,38 +339,39 @@ def __init__(
         self.prepare()
 
     def prepare(self):
-        if self.opcode == self.OPCODE_CLOSE:
-            payload = self.payload
-            length = self.length
-            if not length:
-                self.close_reason = CloseReason.NO_STATUS_RCVD
-            elif length == 1:
-                raise ProtocolError(f'Invalid close frame: {self} {payload}')
+        if self.opcode != self.OPCODE_CLOSE:
+            return
+        payload = self.payload
+        length = self.length
+        if not length:
+            self.close_reason = CloseReason.NO_STATUS_RCVD
+        elif length == 1:
+            raise ProtocolError(f'Invalid close frame: {self} {payload}')
+        else:
+            code = unpack_H(payload[:2])[0]
+            if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
+                raise ProtocolError('invalid close code range')
+            try:
+                code = CloseReason(code)
+            except ValueError:
+                pass
+            if code in LOCAL_ONLY_CLOSE_REASONS:
+                raise ProtocolError('remote CLOSE with local-only reason')
+            if (not isinstance(code, CloseReason)
+                    and code <= MAX_PROTOCOL_CLOSE_REASON):
+                raise ProtocolError('CLOSE with unknown reserved code')
+            try:
+                reason = payload[2:].decode('utf-8')
+            except UnicodeDecodeError:
+                raise ProtocolError(
+                    'close reason is not valid UTF-8',
+                    CloseReason.INVALID_FRAME_PAYLOAD_DATA,
+                )
+            if isinstance(code, CloseReason):
+                code.reason = reason
             else:
-                code = unpack_H(payload[:2])[0]
-                if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON:
-                    raise ProtocolError('invalid close code range')
-                try:
-                    code = CloseReason(code)
-                except ValueError:
-                    pass
-                if code in LOCAL_ONLY_CLOSE_REASONS:
-                    raise ProtocolError('remote CLOSE with local-only reason')
-                if (not isinstance(code, CloseReason)
-                        and code <= MAX_PROTOCOL_CLOSE_REASON):
-                    raise ProtocolError('CLOSE with unknown reserved code')
-                try:
-                    reason = payload[2:].decode('utf-8')
-                except UnicodeDecodeError:
-                    raise ProtocolError(
-                        'close reason is not valid UTF-8',
-                        CloseReason.INVALID_FRAME_PAYLOAD_DATA,
-                    )
-                if isinstance(code, CloseReason):
-                    code.reason = reason
-                else:
-                    code = (code, reason)
-                self.close_reason = code
+                code = (code, reason)
+            self.close_reason = code
 
     def __repr__(self):
         return (
diff --git a/pymaid/rpc/context.py b/pymaid/rpc/context.py
index 6d4da0d..ceac436 100644
--- a/pymaid/rpc/context.py
+++ b/pymaid/rpc/context.py
@@ -214,10 +214,7 @@ def __init__(self, initiative: bool):
         # for initiative side, the id will be EVEN
         # for passive side, the id will be ODD
         self.initiative = initiative
-        if initiative:
-            self.outbound_transmission_id = 1
-        else:
-            self.outbound_transmission_id = 2
+        self.outbound_transmission_id = 1 if initiative else 2
         self.contexts = {}
 
     def next_transmission_id(self) -> int:
diff --git a/pymaid/rpc/pb/router.py b/pymaid/rpc/pb/router.py
index 01f8235..5b3419d 100644
--- a/pymaid/rpc/pb/router.py
+++ b/pymaid/rpc/pb/router.py
@@ -25,18 +25,15 @@ def get_service_methods(self, service: GeneratedServiceType):
             method.CopyToProto(mdp)
             if not mdp.client_streaming and not mdp.server_streaming:
                 method_class = UnaryUnaryMethod
-            elif not mdp.client_streaming and mdp.server_streaming:
+            elif not mdp.client_streaming:
                 method_class = UnaryStreamMethod
-            elif mdp.client_streaming and not mdp.server_streaming:
+            elif not mdp.server_streaming:
                 method_class = StreamUnaryMethod
-            elif mdp.client_streaming and mdp.server_streaming:
-                method_class = StreamStreamMethod
             else:
-                assert False, 'should be one of above'
-
+                method_class = StreamStreamMethod
             request_class = service.GetRequestClass(method)
             response_class = service.GetResponseClass(method)
-            method_ins = method_class(
+            yield method_class(
                 method.name,
                 method.full_name,
                 method_impl,
@@ -48,7 +45,6 @@
                 'void_response': issubclass(response_class, Void),
                 },
             )
-            yield method_ins
 
     def feed_messages(self, conn, messages):
         Request = Meta.PacketType.REQUEST
@@ -117,18 +113,15 @@ def get_router_stubs(self, stub):
             method.CopyToProto(mdp)
             if not mdp.client_streaming and not mdp.server_streaming:
                 method_class = UnaryUnaryMethodStub
-            elif not mdp.client_streaming and mdp.server_streaming:
+            elif not mdp.client_streaming:
                 method_class = UnaryStreamMethodStub
-            elif mdp.client_streaming and not mdp.server_streaming:
+            elif not mdp.server_streaming:
                 method_class = StreamUnaryMethodStub
-            elif mdp.client_streaming and mdp.server_streaming:
-                method_class = StreamStreamMethodStub
             else:
-                assert False, 'should be one of above'
-
+                method_class = StreamStreamMethodStub
             request_class = stub.GetRequestClass(method)
             response_class = stub.GetResponseClass(method)
-            method_stub = method_class(
+            yield method_class(
                 method.name,
                 method.full_name,
                 request_class,
@@ -139,4 +132,3 @@
                 'void_response': issubclass(response_class, Void),
                 },
             )
-            yield method_stub
diff --git a/pymaid/rpc/router.py b/pymaid/rpc/router.py
index 6ba102f..1514f31 100644
--- a/pymaid/rpc/router.py
+++ b/pymaid/rpc/router.py
@@ -28,7 +28,7 @@ def include_service(self, service: ServiceType):
             assert method.full_name not in routes
             routes[method.full_name] = method
             # js/lua pb lib will format as '.service.method'
-            routes['.' + method.full_name] = method
+            routes[f'.{method.full_name}'] = method
 
     def include_services(self, services: Sequence[ServiceType]):
         for service in services:
diff --git a/pymaid/utils/autoreload.py b/pymaid/utils/autoreload.py
index b5f4fbe..2f6911d 100644
--- a/pymaid/utils/autoreload.py
+++ b/pymaid/utils/autoreload.py
@@ -43,12 +43,9 @@ def source_from_cache(path):
     if ext not in ('.pyc', '.pyo'):
         raise ValueError('Not a cached Python file extension', ext)
     # Should we look for .pyw files?
-    return basename + '.py'
+    return f'{basename}.py'
 
-if sys.version_info[0] >= 3:
-    PY3 = True
-else:
-    PY3 = False
+PY3 = sys.version_info[0] >= 3
 
 # ------------------------------------------------------------------------------
 # Autoreload functionality
diff --git a/pymaid/utils/hash.py b/pymaid/utils/hash.py
index 12b7945..ed193d9 100644
--- a/pymaid/utils/hash.py
+++ b/pymaid/utils/hash.py
@@ -30,9 +30,11 @@ def __init__(self, key, weight=16, enabled=True):
         self.enabled = enabled
 
     def __eq__(self, other):
-        if not isinstance(other, HashNode):
-            return NotImplemented
-        return self.hashed_key == other.hashed_key
+        return (
+            self.hashed_key == other.hashed_key
+            if isinstance(other, HashNode)
+            else NotImplemented
+        )
 
     def __ne__(self, other):
         return self != other
@@ -191,8 +193,8 @@ def rehash(self):
         entry_count = primes[pos if pos < len(primes) else -1]
         for node in self.nodes:
             key = node.key
-            offset = hash_func('cat' + key) % entry_count
-            skip = (hash_func('lee' + key) % (entry_count - 1)) + 1
+            offset = hash_func(f'cat{key}') % entry_count
+            skip = hash_func(f'lee{key}') % (entry_count - 1) + 1
             permutation.append([
                 (offset + idx * skip) % entry_count
                 for idx in range(entry_count)
@@ -218,7 +220,7 @@ def rehash(self):
     def get_node(self, key):
         if not self.nodes:
             return
-        key = self.hash_func('cat' + key)
+        key = self.hash_func(f'cat{key}')
         return self.nodes[self.lookup_table[key % len(self.lookup_table)]]
 
     def reset(self):
diff --git a/pymaid/utils/timeout.py b/pymaid/utils/timeout.py
index 31df689..0bbac6e 100644
--- a/pymaid/utils/timeout.py
+++ b/pymaid/utils/timeout.py
@@ -24,10 +24,7 @@ def timeout(delay: Optional[float]) -> 'Timeout':
     delay - value in seconds or None to disable timeout logic
     '''
-    if delay is not None:
-        deadline = get_running_loop().time() + delay  # type: Optional[float]
-    else:
-        deadline = None
+    deadline = get_running_loop().time() + delay if delay is not None else None
     return Timeout(deadline)