diff --git a/examples/asyncio/coroutines.py b/examples/asyncio/coroutines.py index 3d1ad7d6a..f4355664c 100644 --- a/examples/asyncio/coroutines.py +++ b/examples/asyncio/coroutines.py @@ -13,14 +13,14 @@ ctx = Context.instance() -async def ping(): +async def ping() -> None: """print dots to indicate idleness""" while True: await asyncio.sleep(0.5) print('.') -async def receiver(): +async def receiver() -> None: """receive messages with polling""" pull = ctx.socket(zmq.PULL) pull.connect(url) @@ -34,7 +34,7 @@ async def receiver(): print('recvd', msg) -async def sender(): +async def sender() -> None: """send a message every second""" tic = time.time() push = ctx.socket(zmq.PUSH) diff --git a/examples/asyncio/helloworld_pubsub_dealerrouter.py b/examples/asyncio/helloworld_pubsub_dealerrouter.py index dd6006754..f9ec53e41 100644 --- a/examples/asyncio/helloworld_pubsub_dealerrouter.py +++ b/examples/asyncio/helloworld_pubsub_dealerrouter.py @@ -12,17 +12,18 @@ import logging import traceback +import zmq import zmq.asyncio from zmq.asyncio import Context # set message based on language class HelloWorld: - def __init__(self): + def __init__(self) -> None: self.lang = 'eng' self.msg = "Hello World" - def change_language(self): + def change_language(self) -> None: if self.lang == 'eng': self.lang = 'jap' self.msg = "Hello Sekai" @@ -31,7 +32,7 @@ def change_language(self): self.lang = 'eng' self.msg = "Hello World" - def msg_pub(self): + def msg_pub(self) -> str: return self.msg @@ -39,13 +40,13 @@ def msg_pub(self): # changes "World" to "Sekai" and returns message 'sekai' class HelloWorldPrinter: # process received message - def msg_sub(self, msg): + def msg_sub(self, msg: str) -> None: print(f"message received world: {msg}") # manages message flow between publishers and subscribers class HelloWorldMessage: - def __init__(self, url='127.0.0.1', port='5555'): + def __init__(self, url: str = '127.0.0.1', port: int = 5555): # get ZeroMQ version print("Current libzmq version is %s" % zmq.zmq_version()) print("Current pyzmq version is %s" % zmq.__version__) @@ -57,6 +58,8 @@ def __init__(self, url='127.0.0.1', port='5555'): # init hello world publisher obj self.hello_world = HelloWorld() + def main(self) -> None: + # activate publishers / subscribers asyncio.get_event_loop().run_until_complete( asyncio.wait( @@ -70,7 +73,7 @@ def __init__(self, url='127.0.0.1', port='5555'): ) # generates message "Hello World" and publish to topic 'world' - async def hello_world_pub(self): + async def hello_world_pub(self) -> None: pub = self.ctx.socket(zmq.PUB) pub.connect(self.url) @@ -106,7 +109,7 @@ async def hello_world_pub(self): pass # processes message topic 'world'; "Hello World" or "Hello Sekai" - async def hello_world_sub(self): + async def hello_world_sub(self) -> None: print("Setting up world sub") obj = HelloWorldPrinter() # setup subscriber @@ -120,7 +123,7 @@ async def hello_world_sub(self): # keep listening to all published message on topic 'world' while True: [topic, msg] = await sub.recv_multipart() - print(f"world sub; topic: {topic}\tmessage: {msg}") + print(f"world sub; topic: {topic.decode()}\tmessage: {msg.decode()}") # process message obj.msg_sub(msg.decode('utf-8')) @@ -141,7 +144,7 @@ async def hello_world_sub(self): pass # Deal a message to topic 'lang' that language should be changed - async def lang_changer_dealer(self): + async def lang_changer_dealer(self) -> None: # setup dealer deal = self.ctx.socket(zmq.DEALER) deal.setsockopt(zmq.IDENTITY, b'lang_dealer') @@ -176,7 +179,7 @@ async def lang_changer_dealer(self): pass # changes Hello xxx message when a command is received from topic 'lang'; keeps listening for commands - async def lang_changer_router(self): + async def lang_changer_router(self) -> None: # setup router rout = self.ctx.socket(zmq.ROUTER) rout.bind(self.url[:-1] + f"{int(self.url[-1]) + 1}") @@ -188,7 +191,9 @@ async def lang_changer_router(self): # keep listening to all published message on topic 'world' while True: [id_dealer, msg] = await rout.recv_multipart() - print(f"Command rout; Sender ID: {id_dealer};\tmessage: {msg}") + print( + f"Command rout; Sender ID: {id_dealer!r};\tmessage: {msg.decode()}" + ) self.hello_world.change_language() print( @@ -208,5 +213,10 @@ async def lang_changer_router(self): pass +def main() -> None: + hello_world = HelloWorldMessage() + hello_world.main() + + if __name__ == '__main__': - HelloWorldMessage() + main() diff --git a/examples/asyncio/tornado_asyncio.py b/examples/asyncio/tornado_asyncio.py index fd3132b52..d3dbed277 100644 --- a/examples/asyncio/tornado_asyncio.py +++ b/examples/asyncio/tornado_asyncio.py @@ -5,35 +5,27 @@ import asyncio from tornado.ioloop import IOLoop -from tornado.platform.asyncio import AsyncIOMainLoop import zmq.asyncio -# Tell tornado to use asyncio -AsyncIOMainLoop().install() -# This must be instantiated after the installing the IOLoop -queue = asyncio.Queue() # type: ignore -ctx = zmq.asyncio.Context() - - -async def pushing(): - server = ctx.socket(zmq.PUSH) +async def pushing() -> None: + server = zmq.asyncio.Context.instance().socket(zmq.PUSH) server.bind('tcp://*:9000') while True: await server.send(b"Hello") await asyncio.sleep(1) -async def pulling(): - client = ctx.socket(zmq.PULL) +async def pulling() -> None: + client = zmq.asyncio.Context.instance().socket(zmq.PULL) client.connect('tcp://127.0.0.1:9000') while True: greeting = await client.recv() print(greeting) -def zmq_tornado_loop(): +def main() -> None: loop = IOLoop.current() loop.spawn_callback(pushing) loop.spawn_callback(pulling) @@ -41,4 +33,4 @@ def zmq_tornado_loop(): if __name__ == '__main__': - zmq_tornado_loop() + main() diff --git a/examples/chat/display.py b/examples/chat/display.py index dc4aad9a0..028904ae4 100644 --- a/examples/chat/display.py +++ b/examples/chat/display.py @@ -17,11 +17,12 @@ # # You should have received a copy of the Lesser GNU General Public License # along with this program. If not, see . +from typing import List import zmq -def main(addrs): +def main(addrs: List[str]): context = zmq.Context() socket = context.socket(zmq.SUB) diff --git a/examples/chat/prompt.py b/examples/chat/prompt.py index 9c540ad1d..707d02dc4 100644 --- a/examples/chat/prompt.py +++ b/examples/chat/prompt.py @@ -21,7 +21,7 @@ import zmq -def main(addr, who): +def main(addr: str, who: str): ctx = zmq.Context() socket = ctx.socket(zmq.PUB) diff --git a/examples/cython/example.py b/examples/cython/example.py index 111e29ba3..128d6ee51 100644 --- a/examples/cython/example.py +++ b/examples/cython/example.py @@ -7,7 +7,7 @@ import zmq -def python_sender(url, n): +def python_sender(url: str, n: int) -> None: """Use entirely high-level Python APIs to send messages""" ctx = zmq.Context() s = ctx.socket(zmq.PUSH) @@ -23,7 +23,7 @@ def python_sender(url, n): s.send(buf) -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser(description="send & recv messages with Cython") diff --git a/examples/draft/radio-dish.py b/examples/draft/radio-dish.py index 5d8afa7f1..51d093aef 100644 --- a/examples/draft/radio-dish.py +++ b/examples/draft/radio-dish.py @@ -13,13 +13,13 @@ for i in range(10): time.sleep(0.1) - radio.send(b'%03i' % i, group='numbers') + radio.send(f'{i:03}'.encode('ascii'), group='numbers') try: msg = dish.recv(copy=False) except zmq.Again: print('missed a message') continue - print("Received {}:{}".format(msg.group, msg.bytes.decode('utf8'))) + print(f"Received {msg.group}:{msg.bytes.decode('utf8')}") dish.close() radio.close() diff --git a/examples/eventloop/asyncweb.py b/examples/eventloop/asyncweb.py index 286fc63aa..305c73391 100644 --- a/examples/eventloop/asyncweb.py +++ b/examples/eventloop/asyncweb.py @@ -20,7 +20,7 @@ from zmq.eventloop.future import Context as FutureContext -def slow_responder(): +def slow_responder() -> None: """thread for slowly responding to replies.""" ctx = zmq.Context() socket = ctx.socket(zmq.ROUTER) @@ -35,14 +35,14 @@ def slow_responder(): i += 1 -def dot(): +def dot() -> None: """callback for showing that IOLoop is still responsive while we wait""" sys.stdout.write('.') sys.stdout.flush() class TestHandler(web.RequestHandler): - async def get(self): + async def get(self) -> None: ctx = FutureContext.instance() s = ctx.socket(zmq.DEALER) @@ -56,7 +56,7 @@ async def get(self): self.write(reply) -def main(): +def main() -> None: worker = threading.Thread(target=slow_responder) worker.daemon = True worker.start() diff --git a/examples/eventloop/echostream.py b/examples/eventloop/echostream.py index 310d7652a..28e31844c 100644 --- a/examples/eventloop/echostream.py +++ b/examples/eventloop/echostream.py @@ -1,19 +1,22 @@ #!/usr/bin/env python """Adapted echo.py to put the send in the event loop using a ZMQStream. """ +from typing import List + +from tornado import ioloop import zmq -from zmq.eventloop import ioloop, zmqstream +from zmq.eventloop import zmqstream -loop = ioloop.IOLoop.instance() +loop = ioloop.IOLoop.current() ctx = zmq.Context() s = ctx.socket(zmq.ROUTER) s.bind('tcp://127.0.0.1:5555') -stream = zmqstream.ZMQStream(s, loop) +stream = zmqstream.ZMQStream(s) -def echo(msg): +def echo(msg: List[bytes]): print(" ".join(map(repr, msg))) stream.send_multipart(msg) diff --git a/examples/gevent/poll.py b/examples/gevent/poll.py index 722bcfc5a..09701ba02 100644 --- a/examples/gevent/poll.py +++ b/examples/gevent/poll.py @@ -34,11 +34,11 @@ def sender(): while msgcnt < 10: socks = dict(poller.poll()) if receiver1 in socks and socks[receiver1] == zmq.POLLIN: - print("Message from receiver1: %s" % receiver1.recv()) + print(f"Message from receiver1: {receiver1.recv()!r}") msgcnt += 1 if receiver2 in socks and socks[receiver2] == zmq.POLLIN: - print("Message from receiver2: %s" % receiver2.recv()) + print(f"Message from receiver2: {receiver2.recv()!r}") msgcnt += 1 -print("%d messages received" % msgcnt) +print(f"{msgcnt} messages received") diff --git a/examples/gevent/simple.py b/examples/gevent/simple.py index 1425eb7e7..1996b2000 100644 --- a/examples/gevent/simple.py +++ b/examples/gevent/simple.py @@ -1,3 +1,5 @@ +from typing import Optional + from gevent import spawn, spawn_later import zmq.green as zmq @@ -23,7 +25,7 @@ sock.connect('ipc:///tmp/zmqtest') -def get_objs(sock): +def get_objs(sock: zmq.Socket): while True: o = sock.recv_pyobj() print('received python object:', o) @@ -32,7 +34,7 @@ def get_objs(sock): break -def print_every(s, t=None): +def print_every(s: str, t: Optional[float] = None): print(s) if t: spawn_later(t, print_every, s, t) diff --git a/examples/heartbeat/heartbeater.py b/examples/heartbeat/heartbeater.py index d5df19c9a..d6a8032b1 100644 --- a/examples/heartbeat/heartbeater.py +++ b/examples/heartbeat/heartbeater.py @@ -14,9 +14,12 @@ """ import time +from typing import Set + +from tornado import ioloop import zmq -from zmq.eventloop import ioloop, zmqstream +from zmq.eventloop import zmqstream class HeartBeater: @@ -24,7 +27,13 @@ class HeartBeater: pingstream: a PUB stream pongstream: an ROUTER stream""" - def __init__(self, loop, pingstream, pongstream, period=1000): + def __init__( + self, + loop: ioloop.IOLoop, + pingstream: zmqstream.ZMQStream, + pongstream: zmqstream.ZMQStream, + period: int = 1000, + ): self.loop = loop self.period = period @@ -32,16 +41,16 @@ def __init__(self, loop, pingstream, pongstream, period=1000): self.pongstream = pongstream self.pongstream.on_recv(self.handle_pong) - self.hearts = set() - self.responses = set() + self.hearts: Set = set() + self.responses: Set = set() self.lifetime = 0 - self.tic = time.time() + self.tic = time.monotonic() - self.caller = ioloop.PeriodicCallback(self.beat, period, self.loop) + self.caller = ioloop.PeriodicCallback(self.beat, period) self.caller.start() def beat(self): - toc = time.time() + toc = time.monotonic() self.lifetime += toc - self.tic self.tic = toc print(self.lifetime) @@ -50,18 +59,20 @@ def beat(self): heartfailures = self.hearts.difference(goodhearts) newhearts = self.responses.difference(goodhearts) # print(newhearts, goodhearts, heartfailures) - map(self.handle_new_heart, newhearts) - map(self.handle_heart_failure, heartfailures) + for heart in newhearts: + self.handle_new_heart(heart) + for heart in heartfailures: + self.handle_heart_failure(heart) self.responses = set() - print("%i beating hearts: %s" % (len(self.hearts), self.hearts)) + print(f"{len(self.hearts)} beating hearts: {self.hearts}") self.pingstream.send(str(self.lifetime)) def handle_new_heart(self, heart): - print("yay, got new heart %s!" % heart) + print(f"yay, got new heart {heart}!") self.hearts.add(heart) def handle_heart_failure(self, heart): - print("Heart %s failed :(" % heart) + print(f"Heart {heart} failed :(") self.hearts.remove(heart) def handle_pong(self, msg): diff --git a/examples/heartbeat/ping.py b/examples/heartbeat/ping.py index 49d8a9d1e..d5e3832f3 100644 --- a/examples/heartbeat/ping.py +++ b/examples/heartbeat/ping.py @@ -1,7 +1,7 @@ #!/usr/bin/env python """For use with pong.py -This script simply pings a process started by pong.py or tspong.py, to +This script simply pings a process started by pong.py or tspong.py, to demonstrate that zmq remains responsive while Python blocks. Authors @@ -9,7 +9,6 @@ * MinRK """ -import sys import time import numpy @@ -25,7 +24,8 @@ time.sleep(1) n = 0 while True: - time.sleep(numpy.random.random()) + t: float = numpy.random.random() + time.sleep(t) for i in range(4): n += 1 msg = 'ping %i' % n diff --git a/examples/logger/zmqlogger.py b/examples/logger/zmqlogger.py index 08acee1b7..196379f42 100644 --- a/examples/logger/zmqlogger.py +++ b/examples/logger/zmqlogger.py @@ -25,7 +25,7 @@ ) -def sub_logger(port, level=logging.DEBUG): +def sub_logger(port: int, level: int = logging.DEBUG) -> None: ctx = zmq.Context() sub = ctx.socket(zmq.SUB) sub.bind('tcp://127.0.0.1:%i' % port) @@ -33,23 +33,24 @@ def sub_logger(port, level=logging.DEBUG): logging.basicConfig(level=level) while True: - level, message = sub.recv_multipart() + level_name, message = sub.recv_multipart() + level_name = level_name.decode('ascii').lower() message = message.decode('ascii') if message.endswith('\n'): # trim trailing newline, which will get appended again message = message[:-1] - log = getattr(logging, level.lower().decode('ascii')) + log = getattr(logging, level_name) log(message) -def log_worker(port, interval=1, level=logging.DEBUG): +def log_worker(port: int, interval: float = 1, level: int = logging.DEBUG) -> None: ctx = zmq.Context() pub = ctx.socket(zmq.PUB) pub.connect('tcp://127.0.0.1:%i' % port) logger = logging.getLogger(str(os.getpid())) logger.setLevel(level) - handler = PUBHandler(pub) + handler: PUBHandler = PUBHandler(pub) logger.addHandler(handler) print("starting logger at %i with level=%s" % (os.getpid(), level)) diff --git a/examples/mongodb/client.py b/examples/mongodb/client.py index 39a5dbb40..3456e2980 100644 --- a/examples/mongodb/client.py +++ b/examples/mongodb/client.py @@ -6,6 +6,7 @@ # ----------------------------------------------------------------------------- import json +from typing import Any, Dict, List import zmq @@ -15,26 +16,26 @@ class MongoZMQClient: Client that connects with MongoZMQ server to add/fetch docs """ - def __init__(self, connect_addr='tcp://127.0.0.1:5000'): + def __init__(self, connect_addr: str = 'tcp://127.0.0.1:5000'): self._context = zmq.Context() self._socket = self._context.socket(zmq.DEALER) self._socket.connect(connect_addr) - def _send_recv_msg(self, msg): + def _send_recv_msg(self, msg: List[bytes]) -> str: self._socket.send_multipart(msg) - return self._socket.recv_multipart()[0] + return self._socket.recv_multipart()[0].decode("utf8") - def get_doc(self, keys): - msg = ['get', json.dumps(keys)] + def get_doc(self, keys: Dict[str, Any]) -> Dict: + msg = [b'get', json.dumps(keys).encode("utf8")] json_str = self._send_recv_msg(msg) return json.loads(json_str) - def add_doc(self, doc): - msg = ['add', json.dumps(doc)] + def add_doc(self, doc: Dict) -> str: + msg = [b'add', json.dumps(doc).encode("utf8")] return self._send_recv_msg(msg) -def main(): +def main() -> None: client = MongoZMQClient() for i in range(10): doc = {'job': str(i)} diff --git a/examples/mongodb/controller.py b/examples/mongodb/controller.py index 5c16cc9a1..4d51a0a24 100644 --- a/examples/mongodb/controller.py +++ b/examples/mongodb/controller.py @@ -7,6 +7,7 @@ import json import sys +from typing import Any, Dict, Optional, Union import pymongo import pymongo.json_util @@ -21,7 +22,9 @@ class MongoZMQ: NOTE: mongod must be started before using this class """ - def __init__(self, db_name, table_name, bind_addr="tcp://127.0.0.1:5000"): + def __init__( + self, db_name: str, table_name: str, bind_addr: str = "tcp://127.0.0.1:5000" + ): """ bind_addr: address to bind zmq socket on db_name: name of database to write to (created if doesn't exist) @@ -34,20 +37,21 @@ def __init__(self, db_name, table_name, bind_addr="tcp://127.0.0.1:5000"): self._db = self._conn[self._db_name] self._table = self._db[self._table_name] - def _doc_to_json(self, doc): + def _doc_to_json(self, doc: Any) -> str: return json.dumps(doc, default=pymongo.json_util.default) - def add_document(self, doc): + def add_document(self, doc: Dict) -> Optional[str]: """ Inserts a document (dictionary) into mongo database table """ - print('adding docment %s' % (doc)) + print(f'adding document {doc}') try: self._table.insert(doc) except Exception as e: return 'Error: %s' % e + return None - def get_document_by_keys(self, keys): + def get_document_by_keys(self, keys: Dict[str, Any]) -> Union[Dict, str]: """ Attempts to return a single document from database table that matches each key/value in keys dictionary. @@ -58,7 +62,7 @@ def get_document_by_keys(self, keys): except Exception as e: return 'Error: %s' % e - def start(self): + def start(self) -> None: context = zmq.Context() socket = context.socket(zmq.ROUTER) socket.bind(self._bind_addr) @@ -88,7 +92,7 @@ def start(self): socket.send_multipart(reply) -def main(): +def main() -> None: MongoZMQ('ipcontroller', 'jobs').start() diff --git a/examples/monitoring/simple_monitor.py b/examples/monitoring/simple_monitor.py index 890051695..597b2a3dd 100644 --- a/examples/monitoring/simple_monitor.py +++ b/examples/monitoring/simple_monitor.py @@ -10,11 +10,14 @@ import threading import time +from typing import Any, Dict import zmq from zmq.utils.monitor import recv_monitor_message -line = lambda: print('-' * 40) + +def line() -> None: + print('-' * 40) print("libzmq-%s" % zmq.zmq_version()) @@ -30,10 +33,12 @@ EVENT_MAP[value] = name -def event_monitor(monitor): +def event_monitor(monitor: zmq.Socket) -> None: while monitor.poll(): - evt = recv_monitor_message(monitor) - evt.update({'description': EVENT_MAP[evt['event']]}) + evt: Dict[str, Any] = {} + mon_evt = recv_monitor_message(monitor) + evt.update(mon_evt) + evt['description'] = EVENT_MAP[evt['event']] print(f"Event: {evt}") if evt['event'] == zmq.EVENT_MONITOR_STOPPED: break diff --git a/examples/pubsub/publisher.py b/examples/pubsub/publisher.py index 7f51f462c..fbfc6d4ff 100644 --- a/examples/pubsub/publisher.py +++ b/examples/pubsub/publisher.py @@ -18,7 +18,7 @@ import zmq -def sync(bind_to): +def sync(bind_to: str) -> None: # use bind socket + 1 sync_with = ':'.join( bind_to.split(':')[:-1] + [str(int(bind_to.split(':')[-1]) + 1)] @@ -32,7 +32,7 @@ def sync(bind_to): s.send(b'GO') -def main(): +def main() -> None: if len(sys.argv) != 4: print('usage: publisher ') sys.exit(1) diff --git a/examples/pubsub/subscriber.py b/examples/pubsub/subscriber.py index 6d7e9d9e6..061a9e525 100644 --- a/examples/pubsub/subscriber.py +++ b/examples/pubsub/subscriber.py @@ -19,7 +19,7 @@ import zmq -def sync(connect_to): +def sync(connect_to: str) -> None: # use connect socket + 1 sync_with = ':'.join( connect_to.split(':')[:-1] + [str(int(connect_to.split(':')[-1]) + 1)] @@ -31,7 +31,7 @@ def sync(connect_to): s.recv() -def main(): +def main() -> None: if len(sys.argv) != 3: print('usage: subscriber ') sys.exit(1) diff --git a/examples/pubsub/topics_pub.py b/examples/pubsub/topics_pub.py index 9f9cfafaf..dfad578b5 100755 --- a/examples/pubsub/topics_pub.py +++ b/examples/pubsub/topics_pub.py @@ -24,7 +24,7 @@ import zmq -def main(): +def main() -> None: if len(sys.argv) != 2: print('usage: publisher ') sys.exit(1) @@ -46,7 +46,7 @@ def main(): s.bind(bind_to) print("Starting broadcast on topics:") - print(" %s" % all_topics) + print(f" {all_topics}") print("Hit Ctrl-C to stop broadcasting.") print("Waiting so subscriber sockets can connect...") print("") @@ -55,9 +55,9 @@ def main(): msg_counter = itertools.count() try: for topic in itertools.cycle(all_topics): - msg_body = str(next(msg_counter)).encode('utf-8') - print(f' Topic: {topic}, msg:{msg_body}') - s.send_multipart([topic, msg_body]) + msg_body = str(next(msg_counter)) + print(f" Topic: {topic.decode('utf8')}, msg:{msg_body}") + s.send_multipart([topic, msg_body.encode("utf8")]) # short wait so we don't hog the cpu time.sleep(0.1) except KeyboardInterrupt: diff --git a/examples/pubsub/topics_sub.py b/examples/pubsub/topics_sub.py index 360f2b8ee..3915b4cb0 100755 --- a/examples/pubsub/topics_sub.py +++ b/examples/pubsub/topics_sub.py @@ -25,7 +25,7 @@ import zmq -def main(): +def main() -> None: if len(sys.argv) < 2: print('usage: subscriber [topic topic ...]') sys.exit(1) diff --git a/examples/security/asyncio-ironhouse.py b/examples/security/asyncio-ironhouse.py index 435bcf54c..69fbca752 100644 --- a/examples/security/asyncio-ironhouse.py +++ b/examples/security/asyncio-ironhouse.py @@ -23,7 +23,7 @@ from zmq.auth.asyncio import AsyncioAuthenticator -async def run(): +async def run() -> None: '''Run Ironhouse example''' # These directories are generated by the generate_certificates script @@ -82,7 +82,7 @@ async def run(): # use copy=False to allow access to message properties via the zmq.Frame API # default recv(copy=True) returns only bytes, discarding properties identity, msg = await server.recv_multipart(copy=False) - logging.info(f"Received {msg.bytes} from {msg['User-Id']!r}") + logging.info(f"Received {msg.bytes!r} from {msg['User-Id']!r}") if msg.bytes == b"Hello": logging.info("Ironhouse test OK") else: diff --git a/examples/security/generate_certificates.py b/examples/security/generate_certificates.py index 627c858af..2b1511989 100644 --- a/examples/security/generate_certificates.py +++ b/examples/security/generate_certificates.py @@ -12,11 +12,12 @@ import os import shutil +from typing import Union import zmq.auth -def generate_certificates(base_dir): +def generate_certificates(base_dir: Union[str, os.PathLike]) -> None: '''Generate client and server CURVE certificate files''' keys_dir = os.path.join(base_dir, 'certificates') public_keys_dir = os.path.join(base_dir, 'public_keys') diff --git a/examples/security/ioloop-ironhouse.py b/examples/security/ioloop-ironhouse.py index e318c861e..e5c3793c6 100644 --- a/examples/security/ioloop-ironhouse.py +++ b/examples/security/ioloop-ironhouse.py @@ -16,21 +16,24 @@ import logging import os import sys +from typing import List + +from tornado import ioloop import zmq import zmq.auth from zmq.auth.ioloop import IOLoopAuthenticator -from zmq.eventloop import ioloop, zmqstream +from zmq.eventloop import zmqstream -def echo(server, msg): +def echo(server: zmqstream.ZMQStream, msg: List[bytes]) -> None: logging.debug("server recvd %s", msg) reply = msg + [b'World'] logging.debug("server sending %s", reply) server.send_multipart(reply) -def setup_server(server_secret_file, endpoint='tcp://127.0.0.1:9000'): +def setup_server(server_secret_file: str, endpoint: str = 'tcp://127.0.0.1:9000'): """setup a simple echo server with CURVE auth""" server = zmq.Context.instance().socket(zmq.ROUTER) @@ -46,7 +49,7 @@ def setup_server(server_secret_file, endpoint='tcp://127.0.0.1:9000'): return server_stream -def client_msg_recvd(msg): +def client_msg_recvd(msg: List[bytes]): logging.debug("client recvd %s", msg) logging.info("Ironhouse test OK") # stop the loop when we get the reply @@ -54,7 +57,9 @@ def client_msg_recvd(msg): def setup_client( - client_secret_file, server_public_file, endpoint='tcp://127.0.0.1:9000' + client_secret_file: str, + server_public_file: str, + endpoint: str = 'tcp://127.0.0.1:9000', ): """setup a simple client with CURVE auth""" @@ -77,7 +82,7 @@ def setup_client( return client_stream -def run(): +def run() -> None: '''Run Ironhouse example''' # These direcotries are generated by the generate_certificates script diff --git a/examples/security/ironhouse.py b/examples/security/ironhouse.py index 8d683210f..02912b117 100644 --- a/examples/security/ironhouse.py +++ b/examples/security/ironhouse.py @@ -20,7 +20,7 @@ from zmq.auth.thread import ThreadAuthenticator -def run(): +def run() -> None: '''Run Ironhouse example''' # These directories are generated by the generate_certificates script diff --git a/examples/security/stonehouse.py b/examples/security/stonehouse.py index ce730b084..2e4f0eb02 100644 --- a/examples/security/stonehouse.py +++ b/examples/security/stonehouse.py @@ -21,7 +21,7 @@ from zmq.auth.thread import ThreadAuthenticator -def run(): +def run() -> None: '''Run Stonehouse example''' # These directories are generated by the generate_certificates script diff --git a/examples/security/strawhouse.py b/examples/security/strawhouse.py index 0cc6057e3..034c681e3 100644 --- a/examples/security/strawhouse.py +++ b/examples/security/strawhouse.py @@ -19,7 +19,7 @@ from zmq.auth.thread import ThreadAuthenticator -def run(): +def run() -> None: '''Run strawhouse client''' allow_test_pass = False diff --git a/examples/security/woodhouse.py b/examples/security/woodhouse.py index c1186cea6..2a5f6599d 100644 --- a/examples/security/woodhouse.py +++ b/examples/security/woodhouse.py @@ -18,7 +18,7 @@ from zmq.auth.thread import ThreadAuthenticator -def run(): +def run() -> None: '''Run woodhouse example''' valid_client_test_pass = False diff --git a/examples/serialization/serialsocket.py b/examples/serialization/serialsocket.py index 823c82890..9a070581c 100644 --- a/examples/serialization/serialsocket.py +++ b/examples/serialization/serialsocket.py @@ -2,6 +2,7 @@ import pickle import zlib +from typing import Any, Dict, cast import numpy @@ -18,20 +19,24 @@ class SerializingSocket(zmq.Socket): for reconstructing the array on the other side (dtype,shape). """ - def send_zipped_pickle(self, obj, flags=0, protocol=-1): + def send_zipped_pickle( + self, obj: Any, flags: int = 0, protocol: int = pickle.HIGHEST_PROTOCOL + ) -> None: """pack and compress an object with pickle and zlib.""" pobj = pickle.dumps(obj, protocol) zobj = zlib.compress(pobj) print('zipped pickle is %i bytes' % len(zobj)) return self.send(zobj, flags=flags) - def recv_zipped_pickle(self, flags=0): + def recv_zipped_pickle(self, flags: int = 0) -> Any: """reconstruct a Python object sent with zipped_pickle""" zobj = self.recv(flags) pobj = zlib.decompress(zobj) return pickle.loads(pobj) - def send_array(self, A, flags=0, copy=True, track=False): + def send_array( + self, A: numpy.ndarray, flags: int = 0, copy: bool = True, track: bool = False + ) -> Any: """send a numpy array with metadata""" md = dict( dtype=str(A.dtype), @@ -40,19 +45,21 @@ def send_array(self, A, flags=0, copy=True, track=False): self.send_json(md, flags | zmq.SNDMORE) return self.send(A, flags, copy=copy, track=track) - def recv_array(self, flags=0, copy=True, track=False): + def recv_array( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> numpy.ndarray: """recv a numpy array""" - md = self.recv_json(flags=flags) + md = cast(Dict[str, Any], self.recv_json(flags=flags)) msg = self.recv(flags=flags, copy=copy, track=track) A = numpy.frombuffer(msg, dtype=md['dtype']) return A.reshape(md['shape']) -class SerializingContext(zmq.Context): +class SerializingContext(zmq.Context[SerializingSocket]): _socket_class = SerializingSocket -def main(): +def main() -> None: ctx = SerializingContext() req = ctx.socket(zmq.REQ) rep = ctx.socket(zmq.REP) diff --git a/examples/win32-interrupt/display.py b/examples/win32-interrupt/display.py index 5a026c1fc..ee41ef9ca 100644 --- a/examples/win32-interrupt/display.py +++ b/examples/win32-interrupt/display.py @@ -1,12 +1,13 @@ """The display part of a simply two process chat app.""" # This file has been placed in the public domain. +from typing import List import zmq from zmq.utils.win32 import allow_interrupt -def main(addrs): +def main(addrs: List[str]): context = zmq.Context() control = context.socket(zmq.PUB) control.bind('inproc://control') @@ -24,14 +25,14 @@ def interrupt_polling(): with allow_interrupt(interrupt_polling): message = '' while message != 'quit': - message = updates.recv_multipart() - if len(message) < 2: + recvd = updates.recv_multipart() + if len(recvd) < 2: print('Invalid message.') continue - account = message[0] - message = ' '.join(message[1:]) + account = recvd[0].decode("utf8") + message = ' '.join(b.decode("utf8") for b in recvd[1:]) if message == 'quit': - print('Killed by "%s".' % account) + print(f'Killed by {account}.') break print(f'{account}: {message}') diff --git a/examples/win32-interrupt/prompt.py b/examples/win32-interrupt/prompt.py index 89d470aed..588c56f8f 100644 --- a/examples/win32-interrupt/prompt.py +++ b/examples/win32-interrupt/prompt.py @@ -21,7 +21,7 @@ import zmq -def main(addr, account): +def main(addr: str, account: str) -> None: ctx = zmq.Context() socket = ctx.socket(zmq.PUB) diff --git a/mypy_tests/test_context.py b/mypy_tests/test_context.py index ec650dd0b..270574804 100644 --- a/mypy_tests/test_context.py +++ b/mypy_tests/test_context.py @@ -3,7 +3,6 @@ ctx = zmq.Context.instance() s = ctx.socket(zmq.PUSH) s.send(b"buf") - ctx2 = zmq.Context.shadow_pyczmq(123) s2 = ctx2.socket(zmq.PUSH) s.send(b"buf") diff --git a/zmq/__init__.pyi b/zmq/__init__.pyi index 99a714007..c9bc245d3 100644 --- a/zmq/__init__.pyi +++ b/zmq/__init__.pyi @@ -2,6 +2,10 @@ from typing import List from . import backend, sugar +COPY_THRESHOLD: int +DRAFT_API: bool +__version__: str + # mypy doesn't like overwriting symbols with * so be explicit # about what comes from backend, not from sugar # see tools/backend_imports.py to generate this list @@ -21,8 +25,5 @@ from .constants import * from .error import * from .sugar import * -COPY_THRESHOLD: int -DRAFT_API: bool - def get_includes() -> List[str]: ... def get_library_dirs() -> List[str]: ... diff --git a/zmq/_future.py b/zmq/_future.py index 3a1dc01cb..b2b3c2b8b 100644 --- a/zmq/_future.py +++ b/zmq/_future.py @@ -4,14 +4,37 @@ # Distributed under the terms of the Modified BSD License. import warnings +from asyncio import Future from collections import deque, namedtuple from itertools import chain -from typing import Type +from typing import ( + Any, + Awaitable, + Callable, + Dict, + List, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) import zmq as _zmq from zmq import EVENTS, POLLIN, POLLOUT +from zmq._typing import Literal + + +class _FutureEvent(NamedTuple): + future: Future + kind: str + kwargs: Dict + msg: Any + timer: Any -_FutureEvent = namedtuple('_FutureEvent', ('future', 'kind', 'kwargs', 'msg', 'timer')) # These are incomplete classes and need a Mixin for compatibility with an eventloop # defining the following attributes: @@ -25,9 +48,10 @@ class _Async: """Mixin for common async logic""" - _current_loop = None + _current_loop: Any = None + _Future: Type[Future] - def _get_loop(self): + def _get_loop(self) -> Any: """Get event loop Notice if event loop has changed, @@ -44,19 +68,30 @@ def _get_loop(self): self._init_io_state(current_loop) return current_loop - def _default_loop(self): + def _default_loop(self) -> Any: raise NotImplementedError("Must be implemented in a subclass") - def _init_io_state(self, loop=None): + def _init_io_state(self, loop=None) -> None: pass class _AsyncPoller(_Async, _zmq.Poller): """Poller that returns a Future on poll, instead of blocking.""" - _socket_class = None # type: Type[_AsyncSocket] + _socket_class: Type["_AsyncSocket"] + _READ: int + _WRITE: int + raw_sockets: List[Any] - def poll(self, timeout=-1): + def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None: + """Schedule callback for a raw socket""" + raise NotImplementedError() + + def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None: + """Unschedule callback for a raw socket""" + raise NotImplementedError() + + def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]: # type: ignore """Return a Future for a poll event""" future = self._Future() if timeout == 0: @@ -74,7 +109,7 @@ def poll(self, timeout=-1): watcher = self._Future() # watch raw sockets: - raw_sockets = [] + raw_sockets: List[Any] = [] def wake_raw(*args): if not watcher.done(): @@ -155,6 +190,9 @@ def cancel(): pass +T = TypeVar("T", bound="_AsyncSocket") + + class _AsyncSocket(_Async, _zmq.Socket): # Warning : these class variables are only here to allow to call super().__setattr__. @@ -162,18 +200,23 @@ class _AsyncSocket(_Async, _zmq.Socket): _recv_futures = None _send_futures = None _state = 0 - _shadow_sock = None + _shadow_sock: "_zmq.Socket" _poller_class = _AsyncPoller _fd = None - def __init__(self, context=None, socket_type=-1, io_loop=None, **kwargs): + def __init__( + self, + context=None, + socket_type=-1, + io_loop=None, + _from_socket: Optional["_zmq.Socket"] = None, + **kwargs, + ) -> None: if isinstance(context, _zmq.Socket): - context, from_socket = (None, context) - else: - from_socket = kwargs.pop('_from_socket', None) - if from_socket is not None: - super().__init__(shadow=from_socket.underlying) - self._shadow_sock = from_socket + context, _from_socket = (None, context) + if _from_socket is not None: + super().__init__(shadow=_from_socket.underlying) + self._shadow_sock = _from_socket else: super().__init__(context, socket_type, **kwargs) self._shadow_sock = _zmq.Socket.shadow(self.underlying) @@ -191,15 +234,16 @@ def __init__(self, context=None, socket_type=-1, io_loop=None, **kwargs): self._fd = self._shadow_sock.FD @classmethod - def from_socket(cls, socket, io_loop=None): + def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T: """Create an async socket from an existing Socket""" return cls(_from_socket=socket, io_loop=io_loop) - def close(self, linger=None): + def close(self, linger: Optional[int] = None) -> None: if not self.closed and self._fd is not None: - for event in list( + event_list: List[_FutureEvent] = list( chain(self._recv_futures or [], self._send_futures or []) - ): + ) + for event in event_list: if not event.future.done(): try: event.future.cancel() @@ -219,7 +263,33 @@ def get(self, key): get.__doc__ = _zmq.Socket.get.__doc__ - def recv_multipart(self, flags=0, copy=True, track=False): + @overload # type: ignore + def recv_multipart( + self, flags: int = 0, *, track: bool = False + ) -> Awaitable[List[bytes]]: + ... + + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[True], track: bool = False + ) -> Awaitable[List[bytes]]: + ... + + @overload + def recv_multipart( + self, flags: int = 0, *, copy: Literal[False], track: bool = False + ) -> Awaitable[List[_zmq.Frame]]: # type: ignore + ... + + @overload + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]: + ... + + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]: """Receive a complete multipart zmq message. Returns a Future whose result will be a multipart message. @@ -228,7 +298,9 @@ def recv_multipart(self, flags=0, copy=True, track=False): 'recv_multipart', dict(flags=flags, copy=copy, track=track) ) - def recv(self, flags=0, copy=True, track=False): + def recv( # type: ignore + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Awaitable[Union[bytes, _zmq.Frame]]: """Receive a single zmq frame. Returns a Future, whose result will be the received frame. @@ -237,7 +309,9 @@ def recv(self, flags=0, copy=True, track=False): """ return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track)) - def send_multipart(self, msg, flags=0, copy=True, track=False, **kwargs): + def send_multipart( # type: ignore + self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs + ) -> Awaitable[Optional[_zmq.MessageTracker]]: """Send a complete multipart zmq message. Returns a Future that resolves when sending is complete. @@ -245,9 +319,16 @@ def send_multipart(self, msg, flags=0, copy=True, track=False, **kwargs): kwargs['flags'] = flags kwargs['copy'] = copy kwargs['track'] = track - return self._add_send_event('send_multipart', msg=msg, kwargs=kwargs) - - def send(self, msg, flags=0, copy=True, track=False, **kwargs): + return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs) + + def send( # type: ignore + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + **kwargs: Any, + ) -> Awaitable[Optional[_zmq.MessageTracker]]: """Send a single zmq frame. Returns a Future that resolves when sending is complete. @@ -258,7 +339,7 @@ def send(self, msg, flags=0, copy=True, track=False, **kwargs): kwargs['copy'] = copy kwargs['track'] = track kwargs.update(dict(flags=flags, copy=copy, track=track)) - return self._add_send_event('send', msg=msg, kwargs=kwargs) + return self._add_send_event('send', msg=data, kwargs=kwargs) def _deserialize(self, recvd, load): """Deserialize with Futures""" @@ -292,7 +373,7 @@ def _chain_cancel(_): return f - def poll(self, timeout=None, flags=_zmq.POLLIN): + def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore """poll the socket for events returns a Future for the poll results. @@ -303,7 +384,7 @@ def poll(self, timeout=None, flags=_zmq.POLLIN): p = self._poller_class() p.register(self, flags) - f = p.poll(timeout) + f = cast(Future, p.poll(timeout)) future = self._Future() @@ -330,6 +411,13 @@ def unwrap_result(f): f.add_done_callback(unwrap_result) return future + # overrides only necessary for updated types + def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore + return super().recv_string(*args, **kwargs) # type: ignore + + def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]: # type: ignore + return super().send_string(s, flags=flags, encoding=encoding) # type: ignore + def _add_timeout(self, future, timeout): """Add a timeout for a send or recv Future""" diff --git a/zmq/_typing.py b/zmq/_typing.py new file mode 100644 index 000000000..e08013605 --- /dev/null +++ b/zmq/_typing.py @@ -0,0 +1,19 @@ +import sys +from typing import Any, Dict + +if sys.version_info >= (3, 8): + from typing import Literal, TypedDict +else: + # avoid runtime dependency on typing_extensions on py37 + try: + from typing_extensions import Literal, TypedDict # type: ignore + except ImportError: + + class _Literal: + def __getitem__(self, key): + return Any + + Literal = _Literal() # type: ignore + + class TypedDict(Dict): # type: ignore + pass diff --git a/zmq/asyncio.py b/zmq/asyncio.py index 2078636a2..8a0963e5b 100644 --- a/zmq/asyncio.py +++ b/zmq/asyncio.py @@ -151,7 +151,7 @@ def _clear_io_state(self): Poller._socket_class = Socket -class Context(_zmq.Context): +class Context(_zmq.Context[Socket]): """Context for creating asyncio-compatible Sockets""" _socket_class = Socket diff --git a/zmq/auth/asyncio.py b/zmq/auth/asyncio.py index 51a264d5b..b1f3e7bdb 100644 --- a/zmq/auth/asyncio.py +++ b/zmq/auth/asyncio.py @@ -7,6 +7,8 @@ # Distributed under the terms of the Modified BSD License. import asyncio +import warnings +from typing import Any, Optional import zmq from zmq.asyncio import Poller @@ -17,27 +19,34 @@ class AsyncioAuthenticator(Authenticator): """ZAP authentication for use in the asyncio IO loop""" - def __init__(self, context=None, loop=None): + __poller: Optional[Poller] + __task: Any + zap_socket: "zmq.asyncio.Socket" + + def __init__(self, context: Optional["zmq.Context"] = None, loop: Any = None): super().__init__(context) - self.loop = loop or asyncio.get_event_loop() + if loop is not None: + warnings.warn(f"{self.__class__.__name__}(loop) is deprecated and ignored") self.__poller = None self.__task = None - async def __handle_zap(self): + async def __handle_zap(self) -> None: while True: + if self.__poller is None: + break events = await self.__poller.poll() if self.zap_socket in dict(events): msg = await self.zap_socket.recv_multipart() self.handle_zap_message(msg) - def start(self): + def start(self) -> None: """Start ZAP authentication""" super().start() self.__poller = Poller() self.__poller.register(self.zap_socket, zmq.POLLIN) self.__task = asyncio.ensure_future(self.__handle_zap()) - def stop(self): + def stop(self) -> None: """Stop ZAP authentication""" if self.__task: self.__task.cancel() diff --git a/zmq/auth/base.py b/zmq/auth/base.py index 861b0617d..87acdddc0 100644 --- a/zmq/auth/base.py +++ b/zmq/auth/base.py @@ -4,6 +4,8 @@ # Distributed under the terms of the Modified BSD License. import logging +import os +from typing import Any, Dict, List, Optional, Set, Tuple, Union import zmq from zmq.error import _check_version @@ -43,13 +45,29 @@ class Authenticator: - GSSAPI requires no configuration. """ - def __init__(self, context=None, encoding='utf-8', log=None): + context: "zmq.Context" + encoding: str + allow_any: bool + credentials_providers: Dict[str, Any] + zap_socket: "zmq.Socket" + whitelist: Set[str] + blacklist: Set[str] + passwords: Dict[str, Dict[str, str]] + certs: Dict[str, Dict[bytes, Any]] + log: Any + + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + ): _check_version((4, 0), "security") self.context = context or zmq.Context.instance() self.encoding = encoding self.allow_any = False self.credentials_providers = {} - self.zap_socket = None + self.zap_socket = None # type: ignore self.whitelist = set() self.blacklist = set() # passwords is a dict keyed by domain and contains values @@ -60,20 +78,20 @@ def __init__(self, context=None, encoding='utf-8', log=None): self.certs = {} self.log = log or logging.getLogger('zmq.auth') - def start(self): + def start(self) -> None: """Create and bind the ZAP socket""" self.zap_socket = self.context.socket(zmq.REP) self.zap_socket.linger = 1 self.zap_socket.bind("inproc://zeromq.zap.01") self.log.debug("Starting") - def stop(self): + def stop(self) -> None: """Close the ZAP socket""" if self.zap_socket: self.zap_socket.close() - self.zap_socket = None + self.zap_socket = None # type: ignore - def allow(self, *addresses): + def allow(self, *addresses: str) -> None: """Allow (whitelist) IP address(es). Connections from addresses not in the whitelist will be rejected. @@ -88,7 +106,7 @@ def allow(self, *addresses): self.log.debug("Allowing %s", ','.join(addresses)) self.whitelist.update(addresses) - def deny(self, *addresses): + def deny(self, *addresses: str) -> None: """Deny (blacklist) IP address(es). Addresses not in the blacklist will be allowed to continue with authentication. @@ -100,7 +118,9 @@ def deny(self, *addresses): self.log.debug("Denying %s", ','.join(addresses)) self.blacklist.update(addresses) - def configure_plain(self, domain='*', passwords=None): + def configure_plain( + self, domain: str = '*', passwords: Dict[str, str] = None + ) -> None: """Configure PLAIN authentication for a given domain. PLAIN authentication uses a plain-text password file. @@ -111,7 +131,9 @@ def configure_plain(self, domain='*', passwords=None): self.passwords[domain] = passwords self.log.debug("Configure plain: %s", domain) - def configure_curve(self, domain='*', location=None): + def configure_curve( + self, domain: str = '*', location: Union[str, os.PathLike] = "." + ) -> None: """Configure CURVE authentication for a given domain. CURVE authentication uses a directory that holds all public client certificates, @@ -136,7 +158,9 @@ def configure_curve(self, domain='*', location=None): except Exception as e: self.log.error("Failed to load CURVE certs from %s: %s", location, e) - def configure_curve_callback(self, domain='*', credentials_provider=None): + def configure_curve_callback( + self, domain: str = '*', credentials_provider: Any = None + ) -> None: """Configure CURVE authentication for a given domain. CURVE authentication using a callback function validating @@ -170,7 +194,7 @@ def callback(self, domain, key): else: self.log.error("None credentials_provider provided for domain:%s", domain) - def curve_user_id(self, client_public_key): + def curve_user_id(self, client_public_key: bytes) -> str: """Return the User-Id corresponding to a CURVE client's public key Default implementation uses the z85-encoding of the public key. @@ -191,14 +215,16 @@ def curve_user_id(self, client_public_key): """ return z85.encode(client_public_key).decode('ascii') - def configure_gssapi(self, domain='*', location=None): + def configure_gssapi( + self, domain: str = '*', location: Optional[str] = None + ) -> None: """Configure GSSAPI authentication Currently this is a no-op because there is nothing to configure with GSSAPI. """ pass - def handle_zap_message(self, msg): + def handle_zap_message(self, msg: List[bytes]): """Perform ZAP authentication""" if len(msg) < 6: self.log.error("Invalid ZAP message, not enough frames: %r", msg) @@ -298,7 +324,9 @@ def handle_zap_message(self, msg): else: self._send_zap_reply(request_id, b"400", reason) - def _authenticate_plain(self, domain, username, password): + def _authenticate_plain( + self, domain: str, username: str, password: str + ) -> Tuple[bool, bytes]: """PLAIN ZAP authentication""" allowed = False reason = b"" @@ -334,7 +362,7 @@ def _authenticate_plain(self, domain, username, password): return allowed, reason - def _authenticate_curve(self, domain, client_key): + def _authenticate_curve(self, domain: str, client_key: bytes) -> Tuple[bool, bytes]: """CURVE ZAP authentication""" allowed = False reason = b"" @@ -391,14 +419,18 @@ def _authenticate_curve(self, domain, client_key): return allowed, reason - def _authenticate_gssapi(self, domain, principal): + def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]: """Nothing to do for GSSAPI, which has already been handled by an external service.""" self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal) return True, b'OK' def _send_zap_reply( - self, request_id, status_code, status_text, user_id='anonymous' - ): + self, + request_id: bytes, + status_code: bytes, + status_text: bytes, + user_id: str = 'anonymous', + ) -> None: """Send a ZAP reply to finish the authentication.""" user_id = user_id if status_code == b'200' else b'' if isinstance(user_id, unicode): diff --git a/zmq/auth/certs.py b/zmq/auth/certs.py index b56507f52..cda8ba446 100644 --- a/zmq/auth/certs.py +++ b/zmq/auth/certs.py @@ -8,32 +8,34 @@ import glob import io import os +from typing import Dict, Optional, Tuple, Union import zmq -from zmq.utils.strtypes import b, bytes, u, unicode -_cert_secret_banner = u( - """# **** Generated on {0} by pyzmq **** +_cert_secret_banner = """# **** Generated on {0} by pyzmq **** # ZeroMQ CURVE **Secret** Certificate # DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions. """ -) -_cert_public_banner = u( - """# **** Generated on {0} by pyzmq **** + +_cert_public_banner = """# **** Generated on {0} by pyzmq **** # ZeroMQ CURVE Public Certificate # Exchange securely, or use a secure mechanism to verify the contents # of this file after exchange. Store public certificates in your home # directory, in the .curve subdirectory. """ -) def _write_key_file( - key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8' -): + key_filename: Union[str, os.PathLike], + banner: str, + public_key: Union[str, bytes], + secret_key: Optional[Union[str, bytes]] = None, + metadata: Optional[Dict[str, str]] = None, + encoding: str = 'utf-8', +) -> None: """Create a certificate file""" if isinstance(public_key, bytes): public_key = public_key.decode(encoding) @@ -42,23 +44,27 @@ def _write_key_file( with open(key_filename, 'w', encoding='utf8') as f: f.write(banner.format(datetime.datetime.now())) - f.write(u('metadata\n')) + f.write('metadata\n') if metadata: for k, v in metadata.items(): if isinstance(k, bytes): k = k.decode(encoding) if isinstance(v, bytes): v = v.decode(encoding) - f.write(u(" {0} = {1}\n").format(k, v)) + f.write(f" {k} = {v}\n") - f.write(u('curve\n')) - f.write(u(" public-key = \"{0}\"\n").format(public_key)) + f.write('curve\n') + f.write(f" public-key = \"{public_key}\"\n") if secret_key: - f.write(u(" secret-key = \"{0}\"\n").format(secret_key)) + f.write(f" secret-key = \"{secret_key}\"\n") -def create_certificates(key_dir, name, metadata=None): +def create_certificates( + key_dir: Union[str, os.PathLike], + name: str, + metadata: Optional[Dict[str, str]] = None, +) -> Tuple[str, str]: """Create zmq certificates. Returns the file paths to the public and secret certificate files. @@ -82,7 +88,9 @@ def create_certificates(key_dir, name, metadata=None): return public_key_file, secret_key_file -def load_certificate(filename): +def load_certificate( + filename: Union[str, os.PathLike] +) -> Tuple[bytes, Optional[bytes]]: """Load public and secret key from a zmq certificate. Returns (public_key, secret_key) @@ -115,7 +123,7 @@ def load_certificate(filename): return public_key, secret_key -def load_certificates(directory='.'): +def load_certificates(directory: Union[str, os.PathLike] = '.') -> Dict[bytes, bool]: """Load public keys from all certificates in a directory""" certs = {} if not os.path.isdir(directory): diff --git a/zmq/auth/ioloop.py b/zmq/auth/ioloop.py index b063b0cf2..29771cbf2 100644 --- a/zmq/auth/ioloop.py +++ b/zmq/auth/ioloop.py @@ -5,9 +5,11 @@ # Copyright (C) PyZMQ Developers # Distributed under the terms of the Modified BSD License. +from typing import Any, Optional from tornado import ioloop +import zmq from zmq.eventloop import zmqstream from .base import Authenticator @@ -16,12 +18,21 @@ class IOLoopAuthenticator(Authenticator): """ZAP authentication for use in the tornado IOLoop""" - def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None): + zap_stream: zmqstream.ZMQStream + io_loop: ioloop.IOLoop + + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + io_loop: Optional[ioloop.IOLoop] = None, + ): super().__init__(context, encoding, log) - self.zap_stream = None + self.zap_stream = None # type: ignore self.io_loop = io_loop or ioloop.IOLoop.current() - def start(self): + def start(self) -> None: """Start ZAP authentication""" super().start() self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop) diff --git a/zmq/auth/thread.py b/zmq/auth/thread.py index 876ed51ec..0d42f7b9f 100644 --- a/zmq/auth/thread.py +++ b/zmq/auth/thread.py @@ -8,7 +8,9 @@ import logging import sys +from itertools import chain from threading import Event, Thread +from typing import Any, Dict, List, Optional, TypeVar, cast import zmq from zmq.utils import jsonapi @@ -24,14 +26,19 @@ class AuthenticationThread(Thread): """ def __init__( - self, context, endpoint, encoding='utf-8', log=None, authenticator=None - ): + self, + context: "zmq.Context", + endpoint: str, + encoding: str = 'utf-8', + log: Any = None, + authenticator: Optional[Authenticator] = None, + ) -> None: super().__init__() self.context = context or zmq.Context.instance() self.encoding = encoding self.log = log = log or logging.getLogger('zmq.auth') self.started = Event() - self.authenticator = authenticator or Authenticator( + self.authenticator: Authenticator = authenticator or Authenticator( context, encoding=encoding, log=log ) @@ -40,7 +47,7 @@ def __init__( self.pipe.linger = 1 self.pipe.connect(endpoint) - def run(self): + def run(self) -> None: """Start the Authentication Agent thread task""" self.authenticator.start() self.started.set() @@ -75,16 +82,18 @@ def run(self): self.pipe.close() self.authenticator.stop() - def _handle_zap(self): + def _handle_zap(self) -> None: """ Handle a message from the ZAP socket. """ + if self.authenticator.zap_socket is None: + raise RuntimeError("ZAP socket closed") msg = self.authenticator.zap_socket.recv_multipart() if not msg: return self.authenticator.handle_zap_message(msg) - def _handle_pipe(self, msg): + def _handle_pipe(self, msg: List[bytes]) -> bool: """ Handle a message from front-end API. """ @@ -114,7 +123,10 @@ def _handle_pipe(self, msg): elif command == b'PLAIN': domain = u(msg[1], self.encoding) json_passwords = msg[2] - self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords)) + passwords: Dict[str, str] = cast( + Dict[str, str], jsonapi.loads(json_passwords) + ) + self.authenticator.configure_plain(domain, passwords) elif command == b'CURVE': # For now we don't do anything with domains @@ -134,7 +146,10 @@ def _handle_pipe(self, msg): return terminate -def _inherit_docstrings(cls): +T = TypeVar("T", bound=type) + + +def _inherit_docstrings(cls: T) -> T: """inherit docstrings from Authenticator, so we don't duplicate them""" for name, method in cls.__dict__.items(): if name.startswith('_') or not callable(method): @@ -149,59 +164,64 @@ def _inherit_docstrings(cls): class ThreadAuthenticator: """Run ZAP authentication in a background thread""" - context = None - log = None - encoding = None - pipe = None - pipe_endpoint = '' - thread = None - auth = None + context: "zmq.Context" + log: Any + encoding: str + pipe: "zmq.Socket" + pipe_endpoint: str = '' + thread: AuthenticationThread - def __init__(self, context=None, encoding='utf-8', log=None): - self.context = context or zmq.Context.instance() + def __init__( + self, + context: Optional["zmq.Context"] = None, + encoding: str = 'utf-8', + log: Any = None, + ): self.log = log self.encoding = encoding - self.pipe = None + self.pipe = None # type: ignore self.pipe_endpoint = f"inproc://{id(self)}.inproc" - self.thread = None + self.thread = None # type: ignore + self.context = context or zmq.Context.instance() # proxy base Authenticator attributes - def __setattr__(self, key, value): - for obj in [self] + self.__class__.mro(): - if key in obj.__dict__: + def __setattr__(self, key: str, value: Any): + for obj in chain([self], self.__class__.mro()): + if key in obj.__dict__ or (key in getattr(obj, "__annotations__", {})): object.__setattr__(self, key, value) return setattr(self.thread.authenticator, key, value) - def __getattr__(self, key): - try: - object.__getattr__(self, key) - except AttributeError: - return getattr(self.thread.authenticator, key) + def __getattr__(self, key: str): + return getattr(self.thread.authenticator, key) - def allow(self, *addresses): + def allow(self, *addresses: str): self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses]) - def deny(self, *addresses): + def deny(self, *addresses: str): self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses]) - def configure_plain(self, domain='*', passwords=None): + def configure_plain( + self, domain: str = '*', passwords: Optional[Dict[str, str]] = None + ): self.pipe.send_multipart( [b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})] ) - def configure_curve(self, domain='*', location=''): + def configure_curve(self, domain: str = '*', location: str = ''): domain = b(domain, self.encoding) location = b(location, self.encoding) self.pipe.send_multipart([b'CURVE', domain, location]) - def configure_curve_callback(self, domain='*', credentials_provider=None): + def configure_curve_callback( + self, domain: str = '*', credentials_provider: Any = None + ): self.thread.authenticator.configure_curve_callback( domain, credentials_provider=credentials_provider ) - def start(self): + def start(self) -> None: """Start the authentication thread""" # create a socket to communicate with auth thread. self.pipe = self.context.socket(zmq.PAIR) @@ -214,23 +234,23 @@ def start(self): if not self.thread.started.wait(timeout=10): raise RuntimeError("Authenticator thread failed to start") - def stop(self): + def stop(self) -> None: """Stop the authentication thread""" if self.pipe: self.pipe.send(b'TERMINATE') if self.is_alive(): self.thread.join() - self.thread = None + self.thread = None # type: ignore self.pipe.close() - self.pipe = None + self.pipe = None # type: ignore - def is_alive(self): + def is_alive(self) -> bool: """Is the ZAP thread currently running?""" if self.thread and self.thread.is_alive(): return True return False - def __del__(self): + def __del__(self) -> None: self.stop() diff --git a/zmq/backend/__init__.pyi b/zmq/backend/__init__.pyi index 1b141f157..4c699a8fe 100644 --- a/zmq/backend/__init__.pyi +++ b/zmq/backend/__init__.pyi @@ -1,44 +1,92 @@ -from typing import Any, ByteString, List, Optional, Set, Tuple, Union +from typing import Any, List, Optional, Set, Tuple, TypeVar, Union, overload + +from typing_extensions import Literal + +import zmq from .select import select_backend +# avoid collision in Frame.bytes +_bytestr = bytes + +T = TypeVar("T") + class Frame: buffer: Any - bytes: ByteString + bytes: bytes more: bool tracker: Any - def copy_fast(self) -> Frame: ... - def get(self, option: int) -> Union[int, ByteString, str]: ... - def set(self, option: int, value: Union[int, ByteString, str]) -> None: ... + def __init__( + self, + data: Any = None, + track: bool = False, + copy: Optional[bool] = None, + copy_threshold: Optional[int] = None, + ): ... + def copy_fast(self: T) -> T: ... + def get(self, option: int) -> Union[int, _bytestr, str]: ... + def set(self, option: int, value: Union[int, _bytestr, str]) -> None: ... class Socket: underlying: int + context: "zmq.Context" + copy_threshold: int + + # specific option types + FD: int def close(self, linger: Optional[int] = ...) -> None: ... - def get(self, option: int) -> Union[int, ByteString, str]: ... - def set(self, option: int, value: Union[int, ByteString, str]) -> None: ... - def connect(self, url: str) -> None: ... + def get(self, option: int) -> Union[int, bytes, str]: ... + def set(self, option: int, value: Union[int, bytes, str]) -> None: ... + def connect(self, url: str): ... def disconnect(self, url: str) -> None: ... - def bind(self, url: str) -> None: ... + def bind(self, url: str): ... def unbind(self, url: str) -> None: ... def send( self, data: Any, - flags: Optional[int] = ..., - copy: Optional[bool] = ..., - track: Optional[bool] = ..., - ) -> Optional[Frame]: ... + flags: int = ..., + copy: bool = ..., + track: bool = ..., + ) -> Optional["zmq.MessageTracker"]: ... + @overload + def recv( + self, + flags: int = ..., + *, + copy: Literal[False], + track: bool = ..., + ) -> "zmq.Frame": ... + @overload + def recv( + self, + flags: int = ..., + *, + copy: Literal[True], + track: bool = ..., + ) -> bytes: ... + @overload + def recv( + self, + flags: int = ..., + track: bool = False, + ) -> bytes: ... + @overload def recv( self, flags: Optional[int] = ..., - copy: Optional[bool] = ..., - track: Optional[bool] = ..., - ) -> Union[Frame, ByteString]: ... + copy: bool = ..., + track: Optional[bool] = False, + ) -> Union["zmq.Frame", bytes]: ... + def monitor(self, addr: Optional[str], events: int) -> None: ... + # draft methods + def join(self, group: str) -> None: ... + def leave(self, group: str) -> None: ... class Context: underlying: int - def __init__(self, io_threads: int = 1, **kwargs): ... - def get(self, option: int) -> Union[int, ByteString, str]: ... - def set(self, option: int, value: Union[int, ByteString, str]) -> None: ... + def __init__(self, io_threads: int = 1, shadow: Any = None): ... + def get(self, option: int) -> Union[int, bytes, str]: ... + def set(self, option: int, value: Union[int, bytes, str]) -> None: ... def socket(self, socket_type: int) -> Socket: ... def term(self) -> None: ... diff --git a/zmq/backend/cython/constant_enums.pxi b/zmq/backend/cython/constant_enums.pxi index 607f5971b..adf019cb1 100644 --- a/zmq/backend/cython/constant_enums.pxi +++ b/zmq/backend/cython/constant_enums.pxi @@ -96,6 +96,7 @@ cdef extern from "zmq.h" nogil: enum: ZMQ_PLAIN enum: ZMQ_CURVE enum: ZMQ_GSSAPI + enum: ZMQ_HWM enum: ZMQ_AFFINITY enum: ZMQ_ROUTING_ID enum: ZMQ_SUBSCRIBE diff --git a/zmq/backend/cython/context.pyx b/zmq/backend/cython/context.pyx index 004e1cebc..b12cd2473 100644 --- a/zmq/backend/cython/context.pyx +++ b/zmq/backend/cython/context.pyx @@ -28,12 +28,12 @@ cdef class Context: io_threads : int The number of IO threads. """ - + # no-op for the signature def __init__(self, io_threads=1, shadow=0): pass - - def __cinit__(self, int io_threads=1, size_t shadow=0, **kwargs): + + def __cinit__(self, int io_threads=1, size_t shadow=0): self.handle = NULL if shadow: self.handle = shadow @@ -68,12 +68,12 @@ cdef class Context: rc = zmq_ctx_destroy(self.handle) self.handle = NULL return rc - + def term(self): """ctx.term() Close or terminate the context. - + This can be called to close the context by hand. If this is not called, the context will automatically be closed when it is garbage collected. """ @@ -85,9 +85,9 @@ cdef class Context: # ignore interrupted term # see PEP 475 notes about close & EINTR for why pass - + self.closed = True - + def set(self, int option, optval): """ctx.set(option, optval) @@ -95,7 +95,7 @@ cdef class Context: See the 0MQ API documentation for zmq_ctx_set for details on specific options. - + .. versionadded:: libzmq-3.2 .. versionadded:: 13.0 @@ -104,9 +104,9 @@ cdef class Context: option : int The option to set. Available values will depend on your version of libzmq. Examples include:: - + zmq.IO_THREADS, zmq.MAX_SOCKETS - + optval : int The value of the option to set. """ @@ -116,7 +116,7 @@ cdef class Context: if self.closed: raise RuntimeError("Context has been destroyed") - + if not isinstance(optval, int): raise TypeError('expected int, got: %r' % optval) optval_int_c = optval @@ -130,7 +130,7 @@ cdef class Context: See the 0MQ API documentation for zmq_ctx_get for details on specific options. - + .. versionadded:: libzmq-3.2 .. versionadded:: 13.0 @@ -139,9 +139,9 @@ cdef class Context: option : int The option to get. Available values will depend on your version of libzmq. Examples include:: - + zmq.IO_THREADS, zmq.MAX_SOCKETS - + Returns ------- optval : int diff --git a/zmq/constants.py b/zmq/constants.py index 77748fd1d..b76a3bb15 100644 --- a/zmq/constants.py +++ b/zmq/constants.py @@ -121,6 +121,8 @@ class _OptType(Enum): class SocketOption(IntEnum): """Options for Socket.get/set""" + _opt_type: str + def __new__(cls, value, opt_type=_OptType.int): """Attach option type as `._opt_type`""" obj = int.__new__(cls, value) @@ -128,6 +130,7 @@ def __new__(cls, value, opt_type=_OptType.int): obj._opt_type = opt_type return obj + HWM = 1 AFFINITY = 4, _OptType.int64 ROUTING_ID = 5, _OptType.bytes SUBSCRIBE = 6, _OptType.bytes @@ -451,6 +454,7 @@ class DeviceType(IntEnum): PLAIN: int = SecurityMechanism.PLAIN CURVE: int = SecurityMechanism.CURVE GSSAPI: int = SecurityMechanism.GSSAPI +HWM: int = SocketOption.HWM AFFINITY: int = SocketOption.AFFINITY ROUTING_ID: int = SocketOption.ROUTING_ID SUBSCRIBE: int = SocketOption.SUBSCRIBE @@ -686,6 +690,7 @@ class DeviceType(IntEnum): "CURVE", "GSSAPI", "SocketOption", + "HWM", "AFFINITY", "ROUTING_ID", "SUBSCRIBE", diff --git a/zmq/devices/basedevice.py b/zmq/devices/basedevice.py index ee039e14e..aa29f6893 100644 --- a/zmq/devices/basedevice.py +++ b/zmq/devices/basedevice.py @@ -7,8 +7,9 @@ import time from multiprocessing import Process from threading import Thread -from typing import Any +from typing import Any, List, Optional, Tuple +import zmq from zmq import ETERM, QUEUE, REQ, Context, ZMQBindError, ZMQError, device @@ -68,7 +69,24 @@ class Device: depending on whether the device should share the global instance or not. """ - def __init__(self, device_type=QUEUE, in_type=None, out_type=None): + device_type: int + in_type: int + out_type: int + + _in_binds: List[str] + _in_connects: List[str] + _in_sockopts: List[Tuple[int, Any]] + _out_binds: List[str] + _out_connects: List[str] + _out_sockopts: List[Tuple[int, Any]] + _random_addrs: List[str] + + def __init__( + self, + device_type: int = QUEUE, + in_type: Optional[int] = None, + out_type: Optional[int] = None, + ) -> None: self.device_type = device_type if in_type is None: raise TypeError("in_type must be specified") @@ -86,14 +104,14 @@ def __init__(self, device_type=QUEUE, in_type=None, out_type=None): self.daemon = True self.done = False - def bind_in(self, addr): + def bind_in(self, addr: str) -> None: """Enqueue ZMQ address for binding on in_socket. See zmq.Socket.bind for details. """ self._in_binds.append(addr) - def bind_in_to_random_port(self, addr, *args, **kwargs): + def bind_in_to_random_port(self, addr: str, *args, **kwargs) -> int: """Enqueue a random port on the given interface for binding on in_socket. @@ -107,28 +125,28 @@ def bind_in_to_random_port(self, addr, *args, **kwargs): return port - def connect_in(self, addr): + def connect_in(self, addr: str) -> None: """Enqueue ZMQ address for connecting on in_socket. See zmq.Socket.connect for details. """ self._in_connects.append(addr) - def setsockopt_in(self, opt, value): + def setsockopt_in(self, opt: int, value: Any) -> None: """Enqueue setsockopt(opt, value) for in_socket See zmq.Socket.setsockopt for details. """ self._in_sockopts.append((opt, value)) - def bind_out(self, addr): + def bind_out(self, addr: str) -> None: """Enqueue ZMQ address for binding on out_socket. See zmq.Socket.bind for details. """ self._out_binds.append(addr) - def bind_out_to_random_port(self, addr, *args, **kwargs): + def bind_out_to_random_port(self, addr: str, *args, **kwargs) -> int: """Enqueue a random port on the given interface for binding on out_socket. @@ -142,21 +160,21 @@ def bind_out_to_random_port(self, addr, *args, **kwargs): return port - def connect_out(self, addr): + def connect_out(self, addr: str): """Enqueue ZMQ address for connecting on out_socket. See zmq.Socket.connect for details. """ self._out_connects.append(addr) - def setsockopt_out(self, opt, value): + def setsockopt_out(self, opt: int, value: Any): """Enqueue setsockopt(opt, value) for out_socket See zmq.Socket.setsockopt for details. """ self._out_sockopts.append((opt, value)) - def _reserve_random_port(self, addr, *args, **kwargs): + def _reserve_random_port(self, addr: str, *args, **kwargs) -> int: ctx = Context() binder = ctx.socket(REQ) @@ -179,9 +197,8 @@ def _reserve_random_port(self, addr, *args, **kwargs): return port - def _setup_sockets(self): - ctx = self.context_factory() - + def _setup_sockets(self) -> Tuple[zmq.Socket, zmq.Socket]: + ctx: zmq.Context[zmq.Socket] = self.context_factory() # type: ignore self._context = ctx # create the sockets @@ -209,7 +226,7 @@ def _setup_sockets(self): return ins, outs - def run_device(self): + def run_device(self) -> None: """The runner method. Do not call me directly, instead call ``self.start()``, just like a Thread. @@ -217,7 +234,7 @@ def run_device(self): ins, outs = self._setup_sockets() device(self.device_type, ins, outs) - def run(self): + def run(self) -> None: """wrap run_device in try/catch ETERM""" try: self.run_device() @@ -230,11 +247,11 @@ def run(self): finally: self.done = True - def start(self): + def start(self) -> None: """Start the device. Override me in subclass for other launchers.""" return self.run() - def join(self, timeout=None): + def join(self, timeout: Optional[float] = None) -> None: """wait for me to finish, like Thread.join. Reimplemented appropriately by subclasses.""" @@ -251,12 +268,12 @@ class BackgroundDevice(Device): launcher: Any = None _launch_class: Any = None - def start(self): + def start(self) -> None: self.launcher = self._launch_class(target=self.run) self.launcher.daemon = self.daemon return self.launcher.start() - def join(self, timeout=None): + def join(self, timeout: Optional[float] = None) -> None: return self.launcher.join(timeout=timeout) diff --git a/zmq/error.py b/zmq/error.py index f5522bc3c..94603fb8c 100644 --- a/zmq/error.py +++ b/zmq/error.py @@ -4,7 +4,7 @@ # Distributed under the terms of the Modified BSD License. from errno import EINTR -from typing import Optional, Tuple +from typing import Optional, Tuple, Union class ZMQBaseError(Exception): @@ -188,7 +188,10 @@ def __str__(self): ) -def _check_version(min_version_info: Tuple[int, int, int], msg: str = "Feature"): +def _check_version( + min_version_info: Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]], + msg: str = "Feature", +): """Check for libzmq raises ZMQVersionError if current zmq version is not at least min_version diff --git a/zmq/eventloop/future.py b/zmq/eventloop/future.py index 8604fd56b..051ab6580 100644 --- a/zmq/eventloop/future.py +++ b/zmq/eventloop/future.py @@ -9,7 +9,9 @@ # Copyright (c) PyZMQ Developers. # Distributed under the terms of the Modified BSD License. +import asyncio import warnings +from typing import Any, Type from tornado.concurrent import Future from tornado.ioloop import IOLoop @@ -48,7 +50,7 @@ def cancel(self): class _AsyncTornado: - _Future = _TornadoFuture + _Future: Type[asyncio.Future] = _TornadoFuture _READ = IOLoop.READ _WRITE = IOLoop.WRITE @@ -79,7 +81,7 @@ class Socket(_AsyncTornado, _AsyncSocket): Poller._socket_class = Socket -class Context(_zmq.Context): +class Context(_zmq.Context[Socket]): # avoid sharing instance with base Context class _instance = None @@ -90,7 +92,7 @@ class Context(_zmq.Context): def _socket_class(self, socket_type): return Socket(self, socket_type) - def __init__(self, *args, **kwargs): + def __init__(self: "Context", *args: Any, **kwargs: Any) -> None: io_loop = kwargs.pop('io_loop', None) if io_loop is not None: warnings.warn( diff --git a/zmq/eventloop/zmqstream.py b/zmq/eventloop/zmqstream.py index 7f76f4d7c..2e14e6106 100644 --- a/zmq/eventloop/zmqstream.py +++ b/zmq/eventloop/zmqstream.py @@ -27,13 +27,16 @@ import sys import warnings from queue import Queue +from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload import zmq +from zmq._typing import Literal from zmq.utils import jsonapi from .ioloop import IOLoop, gen_log try: + import tornado.ioloop from tornado.stack_context import wrap as stack_context_wrap # type: ignore except ImportError: if "zmq.eventloop.minitornado" in sys.modules: @@ -84,23 +87,25 @@ class ZMQStream: """ - socket = None - io_loop = None - poller = None - _send_queue = None - _recv_callback = None - _send_callback = None - _close_callback = None - _state = 0 - _flushed = False - _recv_copy = False - _fd = None - - def __init__(self, socket, io_loop=None): + socket: zmq.Socket + io_loop: "tornado.ioloop.IOLoop" + poller: zmq.Poller + _send_queue: Queue + _recv_callback: Optional[Callable] + _send_callback: Optional[Callable] + _close_callback = Optional[Callable] + _state: int = 0 + _flushed: bool = False + _recv_copy: bool = False + _fd: int + + def __init__( + self, socket: "zmq.Socket", io_loop: Optional["tornado.ioloop.IOLoop"] = None + ): self.socket = socket self.io_loop = io_loop or IOLoop.current() self.poller = zmq.Poller() - self._fd = self.socket.FD + self._fd = cast(int, self.socket.FD) self._send_queue = Queue() self._recv_callback = None @@ -135,11 +140,52 @@ def stop_on_err(self): """DEPRECATED, does nothing""" gen_log.warn("on_err does nothing, and will be removed") - def on_err(self, callback): + def on_err(self, callback: Callable): """DEPRECATED, does nothing""" gen_log.warn("on_err does nothing, and will be removed") - def on_recv(self, callback, copy=True): + @overload + def on_recv( + self, + callback: Callable[[List[bytes]], Any], + ) -> None: + ... + + @overload + def on_recv( + self, + callback: Callable[[List[bytes]], Any], + copy: Literal[True], + ) -> None: + ... + + @overload + def on_recv( + self, + callback: Callable[[List[zmq.Frame]], Any], + copy: Literal[False], + ) -> None: + ... + + @overload + def on_recv( + self, + callback: Union[ + Callable[[List[zmq.Frame]], Any], + Callable[[List[bytes]], Any], + ], + copy: bool = ..., + ): + ... + + def on_recv( + self, + callback: Union[ + Callable[[List[zmq.Frame]], Any], + Callable[[List[bytes]], Any], + ], + copy: bool = True, + ) -> None: """Register a callback for when a message is ready to recv. There can be only one callback registered at a time, so each @@ -174,7 +220,48 @@ def on_recv(self, callback, copy=True): else: self._add_io_state(zmq.POLLIN) - def on_recv_stream(self, callback, copy=True): + @overload + def on_recv_stream( + self, + callback: Callable[["ZMQStream", List[bytes]], Any], + ) -> None: + ... + + @overload + def on_recv_stream( + self, + callback: Callable[["ZMQStream", List[bytes]], Any], + copy: Literal[True], + ) -> None: + ... + + @overload + def on_recv_stream( + self, + callback: Callable[["ZMQStream", List[zmq.Frame]], Any], + copy: Literal[False], + ) -> None: + ... + + @overload + def on_recv_stream( + self, + callback: Union[ + Callable[["ZMQStream", List[zmq.Frame]], Any], + Callable[["ZMQStream", List[bytes]], Any], + ], + copy: bool = ..., + ): + ... + + def on_recv_stream( + self, + callback: Union[ + Callable[["ZMQStream", List[zmq.Frame]], Any], + Callable[["ZMQStream", List[bytes]], Any], + ], + copy: bool = True, + ): """Same as on_recv, but callback will get this stream as first argument callback must take exactly two arguments, as it will be called as:: @@ -186,9 +273,15 @@ def on_recv_stream(self, callback, copy=True): if callback is None: self.stop_on_recv() else: - self.on_recv(lambda msg: callback(self, msg), copy=copy) - def on_send(self, callback): + def stream_callback(msg): + return callback(self, msg) + + self.on_recv(stream_callback, copy=copy) + + def on_send( + self, callback: Callable[[Sequence[Any], Optional[zmq.MessageTracker]], Any] + ): """Register a callback to be called on each send There will be two arguments:: @@ -228,7 +321,12 @@ def on_send(self, callback): assert callback is None or callable(callback) self._send_callback = stack_context_wrap(callback) - def on_send_stream(self, callback): + def on_send_stream( + self, + callback: Callable[ + ["ZMQStream", Sequence[Any], Optional[zmq.MessageTracker]], Any + ], + ): """Same as on_send, but callback will get this stream as first argument Callback will be passed three arguments:: @@ -251,8 +349,14 @@ def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs): ) def send_multipart( - self, msg, flags=0, copy=True, track=False, callback=None, **kwargs - ): + self, + msg: Sequence[Any], + flags: int = 0, + copy: bool = True, + track: bool = False, + callback: Callable = None, + **kwargs: Any + ) -> None: """Send a multipart message, optionally also register a new callback for sends. See zmq.socket.send_multipart for details. """ @@ -266,7 +370,14 @@ def send_multipart( self.on_send(lambda *args: None) self._add_io_state(zmq.POLLOUT) - def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs): + def send_string( + self, + u: str, + flags: int = 0, + encoding: str = 'utf-8', + callback: Optional[Callable] = None, + **kwargs: Any + ): """Send a unicode message with an encoding. See zmq.socket.send_unicode for details. """ @@ -276,14 +387,27 @@ def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs): send_unicode = send_string - def send_json(self, obj, flags=0, callback=None, **kwargs): + def send_json( + self, + obj: Any, + flags: int = 0, + callback: Optional[Callable] = None, + **kwargs: Any + ): """Send json-serialized version of an object. See zmq.socket.send_json for details. """ msg = jsonapi.dumps(obj) return self.send(msg, flags=flags, callback=callback, **kwargs) - def send_pyobj(self, obj, flags=0, protocol=-1, callback=None, **kwargs): + def send_pyobj( + self, + obj: Any, + flags: int = 0, + protocol: int = -1, + callback: Optional[Callable] = None, + **kwargs: Any + ): """Send a Python object as a message using pickle to serialize. See zmq.socket.send_json for details. @@ -295,7 +419,7 @@ def _finish_flush(self): """callback for unsetting _flushed flag.""" self._flushed = False - def flush(self, flag=zmq.POLLIN | zmq.POLLOUT, limit=None): + def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: Optional[int] = None): """Flush pending messages. This method safely handles all pending incoming and/or outgoing messages, @@ -379,11 +503,11 @@ def update_flag(): self._rebuild_io_state() return count - def set_close_callback(self, callback): + def set_close_callback(self, callback: Optional[Callable]): """Call the given callback when the stream is closed.""" self._close_callback = stack_context_wrap(callback) - def close(self, linger=None): + def close(self, linger: Optional[int] = None) -> None: """Close this stream.""" if self.socket is not None: if self.socket.closed: @@ -401,19 +525,19 @@ def close(self, linger=None): else: self.io_loop.remove_handler(self.socket) self.socket.close(linger) - self.socket = None + self.socket = None # type: ignore if self._close_callback: self._run_callback(self._close_callback) - def receiving(self): + def receiving(self) -> bool: """Returns True if we are currently receiving from the stream.""" return self._recv_callback is not None - def sending(self): + def sending(self) -> bool: """Returns True if we are currently sending to the stream.""" return not self._send_queue.empty() - def closed(self): + def closed(self) -> bool: if self.socket is None: return True if self.socket.closed: @@ -421,6 +545,7 @@ def closed(self): # trigger our cleanup self.close() return True + return False def _run_callback(self, callback, *args, **kwargs): """Wrap running callbacks in try/except to allow us to diff --git a/zmq/green/core.py b/zmq/green/core.py index 4d865e83f..d8358c4a0 100644 --- a/zmq/green/core.py +++ b/zmq/green/core.py @@ -307,7 +307,7 @@ def set(self, opt, val): return super().set(opt, val) -class _Context(_original_Context): +class _Context(_original_Context[_Socket]): """Replacement for :class:`zmq.Context` Ensures that the greened Socket above is used in calls to `socket`. diff --git a/zmq/log/handlers.py b/zmq/log/handlers.py index c42ffe97d..5b2cba888 100644 --- a/zmq/log/handlers.py +++ b/zmq/log/handlers.py @@ -20,15 +20,13 @@ http://github.com/jtriley/StarCluster/blob/master/starcluster/logger.py """ +import logging + # Copyright (C) PyZMQ Developers # Distributed under the terms of the Modified BSD License. - - -import logging -from logging import DEBUG, ERROR, FATAL, INFO, WARN +from typing import Optional, Union import zmq -from zmq.utils.strtypes import bytes, cast_bytes, unicode TOPIC_DELIM = "::" # delimiter for splitting topics on the receiving end. @@ -56,11 +54,17 @@ class PUBHandler(logging.Handler): message by: log.debug("subtopic.subsub::the real message") """ - socket = None + ctx: zmq.Context + socket: zmq.Socket - def __init__(self, interface_or_socket, context=None, root_topic=''): + def __init__( + self, + interface_or_socket: Union[str, zmq.Socket], + context: Optional[zmq.Context] = None, + root_topic: str = '', + ) -> None: logging.Handler.__init__(self) - self._root_topic = root_topic + self.root_topic = root_topic self.formatters = { logging.DEBUG: logging.Formatter( "%(levelname)s %(filename)s:%(lineno)d - %(message)s\n" @@ -85,14 +89,14 @@ def __init__(self, interface_or_socket, context=None, root_topic=''): self.socket.bind(interface_or_socket) @property - def root_topic(self): + def root_topic(self) -> str: return self._root_topic @root_topic.setter - def root_topic(self, value): + def root_topic(self, value: str): self.setRootTopic(value) - def setRootTopic(self, root_topic): + def setRootTopic(self, root_topic: str): """Set the root topic for this handler. This value is prepended to all messages published by this handler, and it @@ -105,6 +109,8 @@ def setRootTopic(self, root_topic): the binary representation of the log level string (INFO, WARN, etc.). Note that ZMQ SUB sockets can have multiple subscriptions. """ + if isinstance(root_topic, bytes): + root_topic = root_topic.decode("utf8") self._root_topic = root_topic def setFormatter(self, fmt, level=logging.NOTSET): @@ -125,12 +131,13 @@ def format(self, record): def emit(self, record): """Emit a log message on my socket.""" + try: topic, record.msg = record.msg.split(TOPIC_DELIM, 1) - except Exception: + except ValueError: topic = "" try: - bmsg = cast_bytes(self.format(record)) + bmsg = self.format(record).encode("utf8") except Exception: self.handleError(record) return @@ -145,7 +152,7 @@ def emit(self, record): if topic: topic_list.append(topic) - btopic = b'.'.join(cast_bytes(t) for t in topic_list) + btopic = '.'.join(topic_list).encode("utf8") self.socket.send_multipart([btopic, bmsg]) diff --git a/zmq/ssh/tunnel.py b/zmq/ssh/tunnel.py index bb708028c..876e1c035 100644 --- a/zmq/ssh/tunnel.py +++ b/zmq/ssh/tunnel.py @@ -17,7 +17,6 @@ import warnings from getpass import getpass, getuser from multiprocessing import Process -from typing import Type try: with warnings.catch_warnings(): diff --git a/zmq/sugar/attrsettr.py b/zmq/sugar/attrsettr.py index 7f5d11ac7..5f7f232e9 100644 --- a/zmq/sugar/attrsettr.py +++ b/zmq/sugar/attrsettr.py @@ -4,12 +4,16 @@ # Distributed under the terms of the Modified BSD License. import errno +from typing import Generic, TypeVar, Union from .. import constants +T = TypeVar("T") +OptValT = Union[str, bytes, int] -class AttributeSetter: - def __setattr__(self, key, value): + +class AttributeSetter(Generic[T]): + def __setattr__(self, key: str, value: OptValT) -> None: """set zmq options by attribute""" if key in self.__dict__: @@ -31,11 +35,11 @@ def __setattr__(self, key, value): else: self._set_attr_opt(upper_key, opt, value) - def _set_attr_opt(self, name, opt, value): + def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None: """override if setattr should do something other than call self.set""" self.set(opt, value) - def __getattr__(self, key): + def __getattr__(self, key: str) -> OptValT: """get zmq options by attribute""" upper_key = key.upper() try: @@ -58,9 +62,15 @@ def __getattr__(self, key): else: raise - def _get_attr_opt(self, name, opt): + def _get_attr_opt(self, name, opt) -> OptValT: """override if getattr should do something other than call self.get""" return self.get(opt) + def get(self, opt: Union[T, int]) -> OptValT: + pass + + def set(self, opt: Union[T, int], val: OptValT) -> None: + pass + __all__ = ['AttributeSetter'] diff --git a/zmq/sugar/context.py b/zmq/sugar/context.py index 7356d946c..6821b0059 100644 --- a/zmq/sugar/context.py +++ b/zmq/sugar/context.py @@ -7,32 +7,32 @@ import os import warnings from threading import Lock -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar from weakref import WeakSet from zmq.backend import Context as ContextBase from zmq.constants import ContextOption, Errno, SocketOption from zmq.error import ZMQError -from .attrsettr import AttributeSetter +from .attrsettr import AttributeSetter, OptValT from .socket import Socket # notice when exiting, to avoid triggering term on exit _exiting = False -def _notice_atexit(): +def _notice_atexit() -> None: global _exiting _exiting = True atexit.register(_notice_atexit) - T = TypeVar('T', bound='Context') +ST = TypeVar('ST', bound='Socket', covariant=True) -class Context(ContextBase, AttributeSetter): +class Context(ContextBase, AttributeSetter, Generic[ST]): """Create a zmq Context A zmq Context creates sockets via its ``ctx.socket`` method. @@ -44,8 +44,10 @@ class Context(ContextBase, AttributeSetter): _instance_pid: Optional[int] = None _shadow = False _sockets: WeakSet + # mypy doesn't like a default value here + _socket_class: Type[ST] = Socket # type: ignore - def __init__(self, io_threads: int = 1, **kwargs): + def __init__(self: "Context[Socket]", io_threads: int = 1, **kwargs: Any) -> None: super().__init__(io_threads=io_threads, **kwargs) if kwargs.get('shadow', False): self._shadow = True @@ -54,7 +56,7 @@ def __init__(self, io_threads: int = 1, **kwargs): self.sockopts = {} self._sockets = WeakSet() - def __del__(self): + def __del__(self) -> None: """deleting a Context should terminate it, without trying non-threadsafe destroy""" # Calling locals() here conceals issue #1167 on Windows CPython 3.5.4. @@ -71,7 +73,7 @@ def __del__(self): _repr_cls = "zmq.Context" - def __repr__(self): + def __repr__(self) -> str: cls = self.__class__ # look up _repr_cls on exact class, not inherited _repr_cls = cls.__dict__.get("_repr_cls", None) @@ -87,20 +89,20 @@ def __repr__(self): sockets = "" return f"<{_repr_cls}({sockets}) at {hex(id(self))}{closed}>" - def __enter__(self): + def __enter__(self: T) -> T: return self - def __exit__(self, *args, **kwargs): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: self.term() - def __copy__(self, memo=None): + def __copy__(self: T, memo: Any = None) -> T: """Copying a Context creates a shadow copy""" return self.__class__.shadow(self.underlying) __deepcopy__ = __copy__ @classmethod - def shadow(cls, address): + def shadow(cls: Type[T], address: int) -> T: """Shadow an existing libzmq context address is the integer address of the libzmq context @@ -131,7 +133,7 @@ def shadow_pyczmq(cls: Type[T], ctx: Any) -> T: # static method copied from tornado IOLoop.instance @classmethod - def instance(cls: Type[T], io_threads=1) -> T: + def instance(cls: Type[T], io_threads: int = 1) -> T: """Returns a global Context instance. Most single-threaded applications have a single, global Context. @@ -193,7 +195,7 @@ def term(self) -> None: # Hooks for ctxopt completion # ------------------------------------------------------------------------- - def __dir__(self): + def __dir__(self) -> List[str]: keys = dir(self.__class__) keys.extend(ContextOption.__members__) return keys @@ -202,17 +204,17 @@ def __dir__(self): # Creating Sockets # ------------------------------------------------------------------------- - def _add_socket(self, socket: Any): + def _add_socket(self, socket: Any) -> None: """Add a weakref to a socket for Context.destroy / reference counting""" self._sockets.add(socket) - def _rm_socket(self, socket: Any): + def _rm_socket(self, socket: Any) -> None: """Remove a socket for Context.destroy / reference counting""" # allow _sockets to be None in case of process teardown if getattr(self, "_sockets", None) is not None: self._sockets.discard(socket) - def destroy(self, linger: Optional[float] = None): + def destroy(self, linger: Optional[float] = None) -> None: """Close all sockets associated with this context and then terminate the context. @@ -240,11 +242,7 @@ def destroy(self, linger: Optional[float] = None): self.term() - @property - def _socket_class(self): - return Socket - - def socket(self, socket_type: int, **kwargs): + def socket(self: T, socket_type: int, **kwargs: Any) -> ST: """Create a Socket associated with this Context. Parameters @@ -258,7 +256,7 @@ def socket(self, socket_type: int, **kwargs): """ if self.closed: raise ZMQError(Errno.ENOTSUP) - s = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame + s: ST = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame self, socket_type, **kwargs ) for opt, value in self.sockopts.items(): @@ -272,21 +270,21 @@ def socket(self, socket_type: int, **kwargs): self._add_socket(s) return s - def setsockopt(self, opt: int, value): + def setsockopt(self, opt: int, value: Any) -> None: """set default socket options for new sockets created by this Context .. versionadded:: 13.0 """ self.sockopts[opt] = value - def getsockopt(self, opt: int): + def getsockopt(self, opt: int) -> OptValT: """get default socket options for new sockets created by this Context .. versionadded:: 13.0 """ return self.sockopts[opt] - def _set_attr_opt(self, name: str, opt: int, value): + def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None: """set default sockopts as attributes""" if name in ContextOption.__members__: return self.set(opt, value) @@ -295,7 +293,7 @@ def _set_attr_opt(self, name: str, opt: int, value): else: raise AttributeError(f"No such context or socket option: {name}") - def _get_attr_opt(self, name: str, opt: int): + def _get_attr_opt(self, name: str, opt: int) -> OptValT: """get default sockopts as attributes""" if name in ContextOption.__members__: return self.get(opt) @@ -305,7 +303,7 @@ def _get_attr_opt(self, name: str, opt: int): else: return self.sockopts[opt] - def __delattr__(self, key: str): + def __delattr__(self, key: str) -> None: """delete default sockopts as attributes""" key = key.upper() try: diff --git a/zmq/sugar/poll.py b/zmq/sugar/poll.py index 37fd99d96..0c76dc084 100644 --- a/zmq/sugar/poll.py +++ b/zmq/sugar/poll.py @@ -3,9 +3,8 @@ # Copyright (C) PyZMQ Developers # Distributed under the terms of the Modified BSD License. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple -import zmq from zmq.backend import zmq_poll from zmq.constants import POLLERR, POLLIN, POLLOUT @@ -20,11 +19,11 @@ class Poller: sockets: List[Tuple[Any, int]] _map: Dict - def __init__(self): + def __init__(self) -> None: self.sockets = [] self._map = {} - def __contains__(self, socket: Any): + def __contains__(self, socket: Any) -> bool: return socket in self._map def register(self, socket: Any, flags: int = POLLIN | POLLOUT): @@ -76,7 +75,7 @@ def unregister(self, socket: Any): for socket, flags in self.sockets[idx:]: self._map[socket] -= 1 - def poll(self, timeout: Optional[int] = None): + def poll(self, timeout: Optional[int] = None) -> List[Tuple[Any, int]]: """Poll the registered 0MQ or native fds for I/O. If there are currently events ready to be processed, this function will return immediately. diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py index 92363bf13..0dd2d68a7 100644 --- a/zmq/sugar/socket.py +++ b/zmq/sugar/socket.py @@ -9,8 +9,22 @@ import random import sys import warnings +from typing import ( + Any, + Dict, + Generic, + List, + Optional, + Sequence, + Type, + TypeVar, + Union, + cast, + overload, +) import zmq +from zmq._typing import Literal from zmq.backend import Socket as SocketBase from zmq.error import ZMQBindError, ZMQError from zmq.utils import jsonapi @@ -25,20 +39,26 @@ except AttributeError: DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL +T = TypeVar("T", bound="Socket") -class _SocketContext: + +class _SocketContext(Generic[T]): """Context Manager for socket bind/unbind""" + socket: T + kind: str + addr: str + def __repr__(self): return f"" - def __init__(self, socket, kind, addr): + def __init__(self: "_SocketContext[T]", socket: T, kind: str, addr: str): assert kind in {"bind", "connect"} self.socket = socket self.kind = kind self.addr = addr - def __enter__(self): + def __enter__(self: "_SocketContext[T]") -> T: return self.socket def __exit__(self, *args): @@ -109,7 +129,7 @@ def __repr__(self): return f"<{_repr_cls}(zmq.{self._type_name}) at {hex(id(self))}{closed}>" # socket as context manager: - def __enter__(self): + def __enter__(self: T) -> T: """Sockets are context managers .. versionadded:: 14.4 @@ -123,14 +143,14 @@ def __exit__(self, *args, **kwargs): # Socket creation # ------------------------------------------------------------------------- - def __copy__(self, memo=None): + def __copy__(self: T, memo=None) -> T: """Copying a Socket creates a shadow copy""" return self.__class__.shadow(self.underlying) __deepcopy__ = __copy__ @classmethod - def shadow(cls, address): + def shadow(cls: Type[T], address: int) -> T: """Shadow an existing libzmq socket address is the integer address of the libzmq socket @@ -143,7 +163,7 @@ def shadow(cls, address): address = cast_int_addr(address) return cls(shadow=address) - def close(self, linger=None): + def close(self, linger=None) -> None: """ Close the socket. @@ -167,21 +187,21 @@ def close(self, linger=None): # Connect/Bind context managers # ------------------------------------------------------------------------- - def _connect_cm(self, addr): + def _connect_cm(self: T, addr: str) -> _SocketContext[T]: """Context manager to disconnect on exit .. versionadded:: 20.0 """ return _SocketContext(self, 'connect', addr) - def _bind_cm(self, addr): + def _bind_cm(self: T, addr: str) -> _SocketContext[T]: """Context manager to unbind on exit .. versionadded:: 20.0 """ return _SocketContext(self, 'bind', addr) - def bind(self, addr): + def bind(self: T, addr: str) -> _SocketContext[T]: """s.bind(addr) Bind the socket to an address. @@ -207,7 +227,7 @@ def bind(self, addr): super().bind(addr) return self._bind_cm(addr) - def connect(self, addr): + def connect(self: T, addr: str) -> _SocketContext[T]: """s.connect(addr) Connect to a remote 0MQ socket. @@ -238,7 +258,7 @@ def socket_type(self) -> int: warnings.warn( "Socket.socket_type is deprecated, use Socket.type", DeprecationWarning ) - return self.type + return cast(int, self.type) # ------------------------------------------------------------------------- # Hooks for sockopt completion @@ -283,7 +303,7 @@ def fileno(self): """ return self.FD - def subscribe(self, topic): + def subscribe(self, topic: Union[str, bytes]) -> None: """Subscribe to a topic Only for SUB sockets. @@ -294,7 +314,7 @@ def subscribe(self, topic): topic = topic.encode('utf8') self.set(zmq.SUBSCRIBE, topic) - def unsubscribe(self, topic): + def unsubscribe(self, topic: Union[str, bytes]) -> None: """Unsubscribe from a topic Only for SUB sockets. @@ -305,7 +325,7 @@ def unsubscribe(self, topic): topic = topic.encode('utf8') self.set(zmq.UNSUBSCRIBE, topic) - def set_string(self, option, optval, encoding='utf-8'): + def set_string(self, option: int, optval: str, encoding='utf-8') -> None: """Set socket options with a unicode object. This is simply a wrapper for setsockopt to protect from encoding ambiguity. @@ -328,7 +348,7 @@ def set_string(self, option, optval, encoding='utf-8'): setsockopt_unicode = setsockopt_string = set_string - def get_string(self, option, encoding='utf-8'): + def get_string(self, option: int, encoding='utf-8') -> str: """Get the value of a socket option. See the 0MQ documentation for details on specific options. @@ -346,11 +366,17 @@ def get_string(self, option, encoding='utf-8'): if SocketOption(option)._opt_type != _OptType.bytes: raise TypeError(f"option {option} will not return a string to be decoded") - return self.getsockopt(option).decode(encoding) + return cast(bytes, self.get(option)).decode(encoding) getsockopt_unicode = getsockopt_string = get_string - def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100): + def bind_to_random_port( + self: T, + addr: str, + min_port: int = 49152, + max_port: int = 65536, + max_tries: int = 100, + ) -> int: """Bind this socket to a random port in a range. If the port range is unspecified, the system will choose the port. @@ -384,7 +410,7 @@ def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=10 # if LAST_ENDPOINT is supported, and min_port / max_port weren't specified, # we can bind to port 0 and let the OS do the work self.bind("%s:*" % addr) - url = self.last_endpoint.decode('ascii', 'replace') + url = cast(bytes, self.last_endpoint).decode('ascii', 'replace') _, port_s = url.rsplit(':', 1) return int(port_s) @@ -404,7 +430,7 @@ def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=10 return port raise ZMQBindError("Could not bind socket to random port.") - def get_hwm(self): + def get_hwm(self) -> int: """Get the High Water Mark. On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM @@ -413,15 +439,15 @@ def get_hwm(self): if major >= 3: # return sndhwm, fallback on rcvhwm try: - return self.getsockopt(zmq.SNDHWM) + return cast(int, self.get(zmq.SNDHWM)) except zmq.ZMQError: pass - return self.getsockopt(zmq.RCVHWM) + return cast(int, self.get(zmq.RCVHWM)) else: - return self.getsockopt(zmq.HWM) + return cast(int, self.get(zmq.HWM)) - def set_hwm(self, value): + def set_hwm(self, value: int) -> None: """Set the High Water Mark. On libzmq ≥ 3, this sets both SNDHWM and RCVHWM @@ -447,7 +473,7 @@ def set_hwm(self, value): if raised: raise raised else: - return self.setsockopt(zmq.HWM, value) + self.set(zmq.HWM, value) hwm = property( get_hwm, @@ -464,7 +490,65 @@ def set_hwm(self, value): # Sending and receiving messages # ------------------------------------------------------------------------- - def send(self, data, flags=0, copy=True, track=False, routing_id=None, group=None): + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + *, + track: Literal[True], + routing_id: Optional[int] = ..., + group: Optional[str] = ..., + ) -> "zmq.MessageTracker": + ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + *, + track: Literal[False], + routing_id: Optional[int] = ..., + group: Optional[str] = ..., + ) -> None: + ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + *, + copy: bool = ..., + routing_id: Optional[int] = ..., + group: Optional[str] = ..., + ) -> None: + ... + + @overload + def send( + self, + data: Any, + flags: int = ..., + copy: bool = ..., + track: bool = ..., + routing_id: Optional[int] = ..., + group: Optional[str] = ..., + ) -> Optional["zmq.MessageTracker"]: + ... + + def send( + self, + data: Any, + flags: int = 0, + copy: bool = True, + track: bool = False, + routing_id: Optional[int] = None, + group: Optional[str] = None, + ) -> Optional["zmq.MessageTracker"]: """Send a single zmq message frame on this socket. This queues the message to be sent by the IO thread at a later time. @@ -533,7 +617,14 @@ def send(self, data, flags=0, copy=True, track=False, routing_id=None, group=Non data.group = group return super().send(data, flags=flags, copy=copy, track=track) - def send_multipart(self, msg_parts, flags=0, copy=True, track=False, **kwargs): + def send_multipart( + self, + msg_parts: Sequence, + flags: int = 0, + copy: bool = True, + track: bool = False, + **kwargs, + ): """Send a sequence of buffers as a multipart message. The zmq.SNDMORE flag is added to all msg parts before the last. @@ -583,7 +674,31 @@ def send_multipart(self, msg_parts, flags=0, copy=True, track=False, **kwargs): # Send the last part without the extra SNDMORE flag. return self.send(msg_parts[-1], flags, copy=copy, track=track) - def recv_multipart(self, flags=0, copy=True, track=False): + @overload + def recv_multipart( + self, flags: int = ..., *, copy: Literal[True], track: bool = ... + ) -> List[bytes]: + ... + + @overload + def recv_multipart( + self, flags: int = ..., *, copy: Literal[False], track: bool = ... + ) -> List[zmq.Frame]: + ... + + @overload + def recv_multipart(self, flags: int = ..., *, track: bool = ...) -> List[bytes]: + ... + + @overload + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Union[List[zmq.Frame], List[bytes]]: + ... + + def recv_multipart( + self, flags: int = 0, copy: bool = True, track: bool = False + ) -> Union[List[zmq.Frame], List[bytes]]: """Receive a multipart message as a list of bytes or Frame objects Parameters @@ -614,8 +729,9 @@ def recv_multipart(self, flags=0, copy=True, track=False): while self.getsockopt(zmq.RCVMORE): part = self.recv(flags, copy=copy, track=track) parts.append(part) - - return parts + # cast List[Union] to Union[List] + # how do we get mypy to recognize that return type is invariant on `copy`? + return cast(Union[List[zmq.Frame], List[bytes]], parts) def _deserialize(self, recvd, load): """Deserialize a received message @@ -685,7 +801,14 @@ def recv_serialized(self, deserialize, flags=0, copy=True): frames = self.recv_multipart(flags=flags, copy=copy) return self._deserialize(frames, deserialize) - def send_string(self, u, flags=0, copy=True, encoding='utf-8', **kwargs): + def send_string( + self, + u: str, + flags: int = 0, + copy: bool = True, + encoding: str = 'utf-8', + **kwargs, + ) -> Optional["zmq.Frame"]: """Send a Python unicode string as a message with an encoding. 0MQ communicates with raw bytes, so you must encode/decode @@ -706,7 +829,7 @@ def send_string(self, u, flags=0, copy=True, encoding='utf-8', **kwargs): send_unicode = send_string - def recv_string(self, flags=0, encoding='utf-8'): + def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str: """Receive a unicode string, as sent by send_string. Parameters @@ -731,7 +854,9 @@ def recv_string(self, flags=0, encoding='utf-8'): recv_unicode = recv_string - def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL, **kwargs): + def send_pyobj( + self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs + ) -> Optional[zmq.Frame]: """Send a Python object as a message using pickle to serialize. Parameters @@ -747,7 +872,7 @@ def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL, **kwargs): msg = pickle.dumps(obj, protocol) return self.send(msg, flags=flags, **kwargs) - def recv_pyobj(self, flags=0): + def recv_pyobj(self, flags: int = 0) -> Any: """Receive a Python object as a message using pickle to serialize. Parameters @@ -768,7 +893,7 @@ def recv_pyobj(self, flags=0): msg = self.recv(flags) return self._deserialize(msg, pickle.loads) - def send_json(self, obj, flags=0, **kwargs): + def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None: """Send a Python object as a message using json to serialize. Keyword arguments are passed on to json.dumps @@ -787,7 +912,7 @@ def send_json(self, obj, flags=0, **kwargs): msg = jsonapi.dumps(obj, **kwargs) return self.send(msg, flags=flags, **send_kwargs) - def recv_json(self, flags=0, **kwargs): + def recv_json(self, flags: int = 0, **kwargs) -> Union[List, str, int, float, Dict]: """Receive a Python object as a message using json to serialize. Keyword arguments are passed on to json.loads @@ -812,7 +937,7 @@ def recv_json(self, flags=0, **kwargs): _poller_class = Poller - def poll(self, timeout=None, flags=zmq.POLLIN): + def poll(self, timeout=None, flags=zmq.POLLIN) -> int: """Poll the socket for events. See :class:`Poller` to wait for multiple sockets at once. @@ -840,7 +965,9 @@ def poll(self, timeout=None, flags=zmq.POLLIN): # return 0 if no events, otherwise return event bitfield return evts.get(self, 0) - def get_monitor_socket(self, events=None, addr=None): + def get_monitor_socket( + self: T, events: Optional[int] = None, addr: Optional[str] = None + ) -> T: """Return a connected PAIR socket ready to receive the event notifications. .. versionadded:: libzmq-4.0 @@ -873,7 +1000,7 @@ def get_monitor_socket(self, events=None, addr=None): if addr is None: # create endpoint name from internal fd - addr = "inproc://monitor.s-%d" % self.FD + addr = f"inproc://monitor.s-{self.FD}" if events is None: # use all events events = zmq.EVENT_ALL @@ -884,7 +1011,7 @@ def get_monitor_socket(self, events=None, addr=None): self._monitor_socket.connect(addr) return self._monitor_socket - def disable_monitor(self): + def disable_monitor(self) -> None: """Shutdown the PAIR socket (created using get_monitor_socket) that is serving socket events. diff --git a/zmq/sugar/version.py b/zmq/sugar/version.py index 050c94134..87192c360 100644 --- a/zmq/sugar/version.py +++ b/zmq/sugar/version.py @@ -11,7 +11,7 @@ VERSION_MINOR = 3 VERSION_PATCH = 0 VERSION_EXTRA = "" -__version__ = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) +__version__: str = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) version_info: Union[Tuple[int, int, int], Tuple[int, int, int, float]] = ( VERSION_MAJOR, @@ -28,7 +28,7 @@ float('inf'), ) -__revision__ = '' +__revision__: str = '' def pyzmq_version() -> str: diff --git a/zmq/tests/test_mypy.py b/zmq/tests/test_mypy.py index 85f59b632..026463404 100644 --- a/zmq/tests/test_mypy.py +++ b/zmq/tests/test_mypy.py @@ -34,12 +34,14 @@ def resolve_repo_dir(path): mypy_dir = resolve_repo_dir("mypy_tests") -def run_mypy(path): +def run_mypy(*mypy_args): """Run mypy for a path Captures output and reports it on errors """ - p = Popen([sys.executable, "-m", "mypy", path], stdout=PIPE, stderr=STDOUT) + p = Popen( + [sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT + ) o, _ = p.communicate() out = o.decode("utf8", "replace") print(out) @@ -60,7 +62,7 @@ def run_mypy(path): @pytest.mark.parametrize("example", examples) def test_mypy_example(example): example_dir = os.path.join(examples_dir, example) - run_mypy(example_dir) + run_mypy("--disallow-untyped-calls", example_dir) if os.path.exists(mypy_dir): @@ -69,4 +71,4 @@ def test_mypy_example(example): @pytest.mark.parametrize("filename", mypy_tests) def test_mypy(filename): - run_mypy(os.path.join(mypy_dir, filename)) + run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename)) diff --git a/zmq/tests/test_socket.py b/zmq/tests/test_socket.py index 5544477a5..c9d795f9a 100644 --- a/zmq/tests/test_socket.py +++ b/zmq/tests/test_socket.py @@ -211,6 +211,7 @@ def test_int_sockopts(self): continue if opt.name.startswith( ( + 'HWM', 'ROUTER', 'XPUB', 'TCP', diff --git a/zmq/utils/monitor.py b/zmq/utils/monitor.py index fb12125b5..e901a6c1d 100644 --- a/zmq/utils/monitor.py +++ b/zmq/utils/monitor.py @@ -4,12 +4,20 @@ # Distributed under the terms of the Modified BSD License. import struct +from typing import Any, List, Tuple, cast import zmq +from zmq._typing import TypedDict from zmq.error import _check_version -def parse_monitor_message(msg): +class _MonitorMessage(TypedDict): + event: int + value: int + endpoint: bytes + + +def parse_monitor_message(msg: List[bytes]) -> _MonitorMessage: """decode zmq_monitor event messages. Parameters @@ -30,18 +38,18 @@ def parse_monitor_message(msg): event : dict event description as dict with the keys `event`, `value`, and `endpoint`. """ - if len(msg) != 2 or len(msg[0]) != 6: raise RuntimeError("Invalid event message format: %s" % msg) - event = { - 'event': struct.unpack("=hi", msg[0])[0], - 'value': struct.unpack("=hi", msg[0])[1], + event_id, value = struct.unpack("=hi", msg[0]) + event: _MonitorMessage = { + 'event': event_id, + 'value': value, 'endpoint': msg[1], } return event -def recv_monitor_message(socket, flags=0): +def recv_monitor_message(socket: zmq.Socket, flags: int = 0) -> _MonitorMessage: """Receive and decode the given raw message from the monitoring socket and return a dict. Requires libzmq ≥ 4.0