diff --git a/examples/asyncio/coroutines.py b/examples/asyncio/coroutines.py
index 3d1ad7d6a..f4355664c 100644
--- a/examples/asyncio/coroutines.py
+++ b/examples/asyncio/coroutines.py
@@ -13,14 +13,14 @@
ctx = Context.instance()
-async def ping():
+async def ping() -> None:
"""print dots to indicate idleness"""
while True:
await asyncio.sleep(0.5)
print('.')
-async def receiver():
+async def receiver() -> None:
"""receive messages with polling"""
pull = ctx.socket(zmq.PULL)
pull.connect(url)
@@ -34,7 +34,7 @@ async def receiver():
print('recvd', msg)
-async def sender():
+async def sender() -> None:
"""send a message every second"""
tic = time.time()
push = ctx.socket(zmq.PUSH)
diff --git a/examples/asyncio/helloworld_pubsub_dealerrouter.py b/examples/asyncio/helloworld_pubsub_dealerrouter.py
index dd6006754..f9ec53e41 100644
--- a/examples/asyncio/helloworld_pubsub_dealerrouter.py
+++ b/examples/asyncio/helloworld_pubsub_dealerrouter.py
@@ -12,17 +12,18 @@
import logging
import traceback
+import zmq
import zmq.asyncio
from zmq.asyncio import Context
# set message based on language
class HelloWorld:
- def __init__(self):
+ def __init__(self) -> None:
self.lang = 'eng'
self.msg = "Hello World"
- def change_language(self):
+ def change_language(self) -> None:
if self.lang == 'eng':
self.lang = 'jap'
self.msg = "Hello Sekai"
@@ -31,7 +32,7 @@ def change_language(self):
self.lang = 'eng'
self.msg = "Hello World"
- def msg_pub(self):
+ def msg_pub(self) -> str:
return self.msg
@@ -39,13 +40,13 @@ def msg_pub(self):
# changes "World" to "Sekai" and returns message 'sekai'
class HelloWorldPrinter:
# process received message
- def msg_sub(self, msg):
+ def msg_sub(self, msg: str) -> None:
print(f"message received world: {msg}")
# manages message flow between publishers and subscribers
class HelloWorldMessage:
- def __init__(self, url='127.0.0.1', port='5555'):
+ def __init__(self, url: str = '127.0.0.1', port: int = 5555):
# get ZeroMQ version
print("Current libzmq version is %s" % zmq.zmq_version())
print("Current pyzmq version is %s" % zmq.__version__)
@@ -57,6 +58,8 @@ def __init__(self, url='127.0.0.1', port='5555'):
# init hello world publisher obj
self.hello_world = HelloWorld()
+ def main(self) -> None:
+
# activate publishers / subscribers
asyncio.get_event_loop().run_until_complete(
asyncio.wait(
@@ -70,7 +73,7 @@ def __init__(self, url='127.0.0.1', port='5555'):
)
# generates message "Hello World" and publish to topic 'world'
- async def hello_world_pub(self):
+ async def hello_world_pub(self) -> None:
pub = self.ctx.socket(zmq.PUB)
pub.connect(self.url)
@@ -106,7 +109,7 @@ async def hello_world_pub(self):
pass
# processes message topic 'world'; "Hello World" or "Hello Sekai"
- async def hello_world_sub(self):
+ async def hello_world_sub(self) -> None:
print("Setting up world sub")
obj = HelloWorldPrinter()
# setup subscriber
@@ -120,7 +123,7 @@ async def hello_world_sub(self):
# keep listening to all published message on topic 'world'
while True:
[topic, msg] = await sub.recv_multipart()
- print(f"world sub; topic: {topic}\tmessage: {msg}")
+ print(f"world sub; topic: {topic.decode()}\tmessage: {msg.decode()}")
# process message
obj.msg_sub(msg.decode('utf-8'))
@@ -141,7 +144,7 @@ async def hello_world_sub(self):
pass
# Deal a message to topic 'lang' that language should be changed
- async def lang_changer_dealer(self):
+ async def lang_changer_dealer(self) -> None:
# setup dealer
deal = self.ctx.socket(zmq.DEALER)
deal.setsockopt(zmq.IDENTITY, b'lang_dealer')
@@ -176,7 +179,7 @@ async def lang_changer_dealer(self):
pass
# changes Hello xxx message when a command is received from topic 'lang'; keeps listening for commands
- async def lang_changer_router(self):
+ async def lang_changer_router(self) -> None:
# setup router
rout = self.ctx.socket(zmq.ROUTER)
rout.bind(self.url[:-1] + f"{int(self.url[-1]) + 1}")
@@ -188,7 +191,9 @@ async def lang_changer_router(self):
# keep listening to all published message on topic 'world'
while True:
[id_dealer, msg] = await rout.recv_multipart()
- print(f"Command rout; Sender ID: {id_dealer};\tmessage: {msg}")
+ print(
+ f"Command rout; Sender ID: {id_dealer!r};\tmessage: {msg.decode()}"
+ )
self.hello_world.change_language()
print(
@@ -208,5 +213,10 @@ async def lang_changer_router(self):
pass
+def main() -> None:
+ hello_world = HelloWorldMessage()
+ hello_world.main()
+
+
if __name__ == '__main__':
- HelloWorldMessage()
+ main()
diff --git a/examples/asyncio/tornado_asyncio.py b/examples/asyncio/tornado_asyncio.py
index fd3132b52..d3dbed277 100644
--- a/examples/asyncio/tornado_asyncio.py
+++ b/examples/asyncio/tornado_asyncio.py
@@ -5,35 +5,27 @@
import asyncio
from tornado.ioloop import IOLoop
-from tornado.platform.asyncio import AsyncIOMainLoop
import zmq.asyncio
-# Tell tornado to use asyncio
-AsyncIOMainLoop().install()
-# This must be instantiated after the installing the IOLoop
-queue = asyncio.Queue() # type: ignore
-ctx = zmq.asyncio.Context()
-
-
-async def pushing():
- server = ctx.socket(zmq.PUSH)
+async def pushing() -> None:
+ server = zmq.asyncio.Context.instance().socket(zmq.PUSH)
server.bind('tcp://*:9000')
while True:
await server.send(b"Hello")
await asyncio.sleep(1)
-async def pulling():
- client = ctx.socket(zmq.PULL)
+async def pulling() -> None:
+ client = zmq.asyncio.Context.instance().socket(zmq.PULL)
client.connect('tcp://127.0.0.1:9000')
while True:
greeting = await client.recv()
print(greeting)
-def zmq_tornado_loop():
+def main() -> None:
loop = IOLoop.current()
loop.spawn_callback(pushing)
loop.spawn_callback(pulling)
@@ -41,4 +33,4 @@ def zmq_tornado_loop():
if __name__ == '__main__':
- zmq_tornado_loop()
+ main()
diff --git a/examples/chat/display.py b/examples/chat/display.py
index dc4aad9a0..028904ae4 100644
--- a/examples/chat/display.py
+++ b/examples/chat/display.py
@@ -17,11 +17,12 @@
#
# You should have received a copy of the Lesser GNU General Public License
# along with this program. If not, see .
+from typing import List
import zmq
-def main(addrs):
+def main(addrs: List[str]):
context = zmq.Context()
socket = context.socket(zmq.SUB)
diff --git a/examples/chat/prompt.py b/examples/chat/prompt.py
index 9c540ad1d..707d02dc4 100644
--- a/examples/chat/prompt.py
+++ b/examples/chat/prompt.py
@@ -21,7 +21,7 @@
import zmq
-def main(addr, who):
+def main(addr: str, who: str):
ctx = zmq.Context()
socket = ctx.socket(zmq.PUB)
diff --git a/examples/cython/example.py b/examples/cython/example.py
index 111e29ba3..128d6ee51 100644
--- a/examples/cython/example.py
+++ b/examples/cython/example.py
@@ -7,7 +7,7 @@
import zmq
-def python_sender(url, n):
+def python_sender(url: str, n: int) -> None:
"""Use entirely high-level Python APIs to send messages"""
ctx = zmq.Context()
s = ctx.socket(zmq.PUSH)
@@ -23,7 +23,7 @@ def python_sender(url, n):
s.send(buf)
-def main():
+def main() -> None:
import argparse
parser = argparse.ArgumentParser(description="send & recv messages with Cython")
diff --git a/examples/draft/radio-dish.py b/examples/draft/radio-dish.py
index 5d8afa7f1..51d093aef 100644
--- a/examples/draft/radio-dish.py
+++ b/examples/draft/radio-dish.py
@@ -13,13 +13,13 @@
for i in range(10):
time.sleep(0.1)
- radio.send(b'%03i' % i, group='numbers')
+ radio.send(f'{i:03}'.encode('ascii'), group='numbers')
try:
msg = dish.recv(copy=False)
except zmq.Again:
print('missed a message')
continue
- print("Received {}:{}".format(msg.group, msg.bytes.decode('utf8')))
+ print(f"Received {msg.group}:{msg.bytes.decode('utf8')}")
dish.close()
radio.close()
diff --git a/examples/eventloop/asyncweb.py b/examples/eventloop/asyncweb.py
index 286fc63aa..305c73391 100644
--- a/examples/eventloop/asyncweb.py
+++ b/examples/eventloop/asyncweb.py
@@ -20,7 +20,7 @@
from zmq.eventloop.future import Context as FutureContext
-def slow_responder():
+def slow_responder() -> None:
"""thread for slowly responding to replies."""
ctx = zmq.Context()
socket = ctx.socket(zmq.ROUTER)
@@ -35,14 +35,14 @@ def slow_responder():
i += 1
-def dot():
+def dot() -> None:
"""callback for showing that IOLoop is still responsive while we wait"""
sys.stdout.write('.')
sys.stdout.flush()
class TestHandler(web.RequestHandler):
- async def get(self):
+ async def get(self) -> None:
ctx = FutureContext.instance()
s = ctx.socket(zmq.DEALER)
@@ -56,7 +56,7 @@ async def get(self):
self.write(reply)
-def main():
+def main() -> None:
worker = threading.Thread(target=slow_responder)
worker.daemon = True
worker.start()
diff --git a/examples/eventloop/echostream.py b/examples/eventloop/echostream.py
index 310d7652a..28e31844c 100644
--- a/examples/eventloop/echostream.py
+++ b/examples/eventloop/echostream.py
@@ -1,19 +1,22 @@
#!/usr/bin/env python
"""Adapted echo.py to put the send in the event loop using a ZMQStream.
"""
+from typing import List
+
+from tornado import ioloop
import zmq
-from zmq.eventloop import ioloop, zmqstream
+from zmq.eventloop import zmqstream
-loop = ioloop.IOLoop.instance()
+loop = ioloop.IOLoop.current()
ctx = zmq.Context()
s = ctx.socket(zmq.ROUTER)
s.bind('tcp://127.0.0.1:5555')
-stream = zmqstream.ZMQStream(s, loop)
+stream = zmqstream.ZMQStream(s)
-def echo(msg):
+def echo(msg: List[bytes]):
print(" ".join(map(repr, msg)))
stream.send_multipart(msg)
diff --git a/examples/gevent/poll.py b/examples/gevent/poll.py
index 722bcfc5a..09701ba02 100644
--- a/examples/gevent/poll.py
+++ b/examples/gevent/poll.py
@@ -34,11 +34,11 @@ def sender():
while msgcnt < 10:
socks = dict(poller.poll())
if receiver1 in socks and socks[receiver1] == zmq.POLLIN:
- print("Message from receiver1: %s" % receiver1.recv())
+ print(f"Message from receiver1: {receiver1.recv()!r}")
msgcnt += 1
if receiver2 in socks and socks[receiver2] == zmq.POLLIN:
- print("Message from receiver2: %s" % receiver2.recv())
+ print(f"Message from receiver2: {receiver2.recv()!r}")
msgcnt += 1
-print("%d messages received" % msgcnt)
+print(f"{msgcnt} messages received")
diff --git a/examples/gevent/simple.py b/examples/gevent/simple.py
index 1425eb7e7..1996b2000 100644
--- a/examples/gevent/simple.py
+++ b/examples/gevent/simple.py
@@ -1,3 +1,5 @@
+from typing import Optional
+
from gevent import spawn, spawn_later
import zmq.green as zmq
@@ -23,7 +25,7 @@
sock.connect('ipc:///tmp/zmqtest')
-def get_objs(sock):
+def get_objs(sock: zmq.Socket):
while True:
o = sock.recv_pyobj()
print('received python object:', o)
@@ -32,7 +34,7 @@ def get_objs(sock):
break
-def print_every(s, t=None):
+def print_every(s: str, t: Optional[float] = None):
print(s)
if t:
spawn_later(t, print_every, s, t)
diff --git a/examples/heartbeat/heartbeater.py b/examples/heartbeat/heartbeater.py
index d5df19c9a..d6a8032b1 100644
--- a/examples/heartbeat/heartbeater.py
+++ b/examples/heartbeat/heartbeater.py
@@ -14,9 +14,12 @@
"""
import time
+from typing import Set
+
+from tornado import ioloop
import zmq
-from zmq.eventloop import ioloop, zmqstream
+from zmq.eventloop import zmqstream
class HeartBeater:
@@ -24,7 +27,13 @@ class HeartBeater:
pingstream: a PUB stream
pongstream: an ROUTER stream"""
- def __init__(self, loop, pingstream, pongstream, period=1000):
+ def __init__(
+ self,
+ loop: ioloop.IOLoop,
+ pingstream: zmqstream.ZMQStream,
+ pongstream: zmqstream.ZMQStream,
+ period: int = 1000,
+ ):
self.loop = loop
self.period = period
@@ -32,16 +41,16 @@ def __init__(self, loop, pingstream, pongstream, period=1000):
self.pongstream = pongstream
self.pongstream.on_recv(self.handle_pong)
- self.hearts = set()
- self.responses = set()
+ self.hearts: Set = set()
+ self.responses: Set = set()
self.lifetime = 0
- self.tic = time.time()
+ self.tic = time.monotonic()
- self.caller = ioloop.PeriodicCallback(self.beat, period, self.loop)
+ self.caller = ioloop.PeriodicCallback(self.beat, period)
self.caller.start()
def beat(self):
- toc = time.time()
+ toc = time.monotonic()
self.lifetime += toc - self.tic
self.tic = toc
print(self.lifetime)
@@ -50,18 +59,20 @@ def beat(self):
heartfailures = self.hearts.difference(goodhearts)
newhearts = self.responses.difference(goodhearts)
# print(newhearts, goodhearts, heartfailures)
- map(self.handle_new_heart, newhearts)
- map(self.handle_heart_failure, heartfailures)
+ for heart in newhearts:
+ self.handle_new_heart(heart)
+ for heart in heartfailures:
+ self.handle_heart_failure(heart)
self.responses = set()
- print("%i beating hearts: %s" % (len(self.hearts), self.hearts))
+ print(f"{len(self.hearts)} beating hearts: {self.hearts}")
self.pingstream.send(str(self.lifetime))
def handle_new_heart(self, heart):
- print("yay, got new heart %s!" % heart)
+ print(f"yay, got new heart {heart}!")
self.hearts.add(heart)
def handle_heart_failure(self, heart):
- print("Heart %s failed :(" % heart)
+ print(f"Heart {heart} failed :(")
self.hearts.remove(heart)
def handle_pong(self, msg):
diff --git a/examples/heartbeat/ping.py b/examples/heartbeat/ping.py
index 49d8a9d1e..d5e3832f3 100644
--- a/examples/heartbeat/ping.py
+++ b/examples/heartbeat/ping.py
@@ -1,7 +1,7 @@
#!/usr/bin/env python
"""For use with pong.py
-This script simply pings a process started by pong.py or tspong.py, to
+This script simply pings a process started by pong.py or tspong.py, to
demonstrate that zmq remains responsive while Python blocks.
Authors
@@ -9,7 +9,6 @@
* MinRK
"""
-import sys
import time
import numpy
@@ -25,7 +24,8 @@
time.sleep(1)
n = 0
while True:
- time.sleep(numpy.random.random())
+ t: float = numpy.random.random()
+ time.sleep(t)
for i in range(4):
n += 1
msg = 'ping %i' % n
diff --git a/examples/logger/zmqlogger.py b/examples/logger/zmqlogger.py
index 08acee1b7..196379f42 100644
--- a/examples/logger/zmqlogger.py
+++ b/examples/logger/zmqlogger.py
@@ -25,7 +25,7 @@
)
-def sub_logger(port, level=logging.DEBUG):
+def sub_logger(port: int, level: int = logging.DEBUG) -> None:
ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.bind('tcp://127.0.0.1:%i' % port)
@@ -33,23 +33,24 @@ def sub_logger(port, level=logging.DEBUG):
logging.basicConfig(level=level)
while True:
- level, message = sub.recv_multipart()
+ level_name, message = sub.recv_multipart()
+ level_name = level_name.decode('ascii').lower()
message = message.decode('ascii')
if message.endswith('\n'):
# trim trailing newline, which will get appended again
message = message[:-1]
- log = getattr(logging, level.lower().decode('ascii'))
+ log = getattr(logging, level_name)
log(message)
-def log_worker(port, interval=1, level=logging.DEBUG):
+def log_worker(port: int, interval: float = 1, level: int = logging.DEBUG) -> None:
ctx = zmq.Context()
pub = ctx.socket(zmq.PUB)
pub.connect('tcp://127.0.0.1:%i' % port)
logger = logging.getLogger(str(os.getpid()))
logger.setLevel(level)
- handler = PUBHandler(pub)
+ handler: PUBHandler = PUBHandler(pub)
logger.addHandler(handler)
print("starting logger at %i with level=%s" % (os.getpid(), level))
diff --git a/examples/mongodb/client.py b/examples/mongodb/client.py
index 39a5dbb40..3456e2980 100644
--- a/examples/mongodb/client.py
+++ b/examples/mongodb/client.py
@@ -6,6 +6,7 @@
# -----------------------------------------------------------------------------
import json
+from typing import Any, Dict, List
import zmq
@@ -15,26 +16,26 @@ class MongoZMQClient:
Client that connects with MongoZMQ server to add/fetch docs
"""
- def __init__(self, connect_addr='tcp://127.0.0.1:5000'):
+ def __init__(self, connect_addr: str = 'tcp://127.0.0.1:5000'):
self._context = zmq.Context()
self._socket = self._context.socket(zmq.DEALER)
self._socket.connect(connect_addr)
- def _send_recv_msg(self, msg):
+ def _send_recv_msg(self, msg: List[bytes]) -> str:
self._socket.send_multipart(msg)
- return self._socket.recv_multipart()[0]
+ return self._socket.recv_multipart()[0].decode("utf8")
- def get_doc(self, keys):
- msg = ['get', json.dumps(keys)]
+ def get_doc(self, keys: Dict[str, Any]) -> Dict:
+ msg = [b'get', json.dumps(keys).encode("utf8")]
json_str = self._send_recv_msg(msg)
return json.loads(json_str)
- def add_doc(self, doc):
- msg = ['add', json.dumps(doc)]
+ def add_doc(self, doc: Dict) -> str:
+ msg = [b'add', json.dumps(doc).encode("utf8")]
return self._send_recv_msg(msg)
-def main():
+def main() -> None:
client = MongoZMQClient()
for i in range(10):
doc = {'job': str(i)}
diff --git a/examples/mongodb/controller.py b/examples/mongodb/controller.py
index 5c16cc9a1..4d51a0a24 100644
--- a/examples/mongodb/controller.py
+++ b/examples/mongodb/controller.py
@@ -7,6 +7,7 @@
import json
import sys
+from typing import Any, Dict, Optional, Union
import pymongo
import pymongo.json_util
@@ -21,7 +22,9 @@ class MongoZMQ:
NOTE: mongod must be started before using this class
"""
- def __init__(self, db_name, table_name, bind_addr="tcp://127.0.0.1:5000"):
+ def __init__(
+ self, db_name: str, table_name: str, bind_addr: str = "tcp://127.0.0.1:5000"
+ ):
"""
bind_addr: address to bind zmq socket on
db_name: name of database to write to (created if doesn't exist)
@@ -34,20 +37,21 @@ def __init__(self, db_name, table_name, bind_addr="tcp://127.0.0.1:5000"):
self._db = self._conn[self._db_name]
self._table = self._db[self._table_name]
- def _doc_to_json(self, doc):
+ def _doc_to_json(self, doc: Any) -> str:
return json.dumps(doc, default=pymongo.json_util.default)
- def add_document(self, doc):
+ def add_document(self, doc: Dict) -> Optional[str]:
"""
Inserts a document (dictionary) into mongo database table
"""
- print('adding docment %s' % (doc))
+ print(f'adding document {doc}')
try:
self._table.insert(doc)
except Exception as e:
return 'Error: %s' % e
+ return None
- def get_document_by_keys(self, keys):
+ def get_document_by_keys(self, keys: Dict[str, Any]) -> Union[Dict, str]:
"""
Attempts to return a single document from database table that matches
each key/value in keys dictionary.
@@ -58,7 +62,7 @@ def get_document_by_keys(self, keys):
except Exception as e:
return 'Error: %s' % e
- def start(self):
+ def start(self) -> None:
context = zmq.Context()
socket = context.socket(zmq.ROUTER)
socket.bind(self._bind_addr)
@@ -88,7 +92,7 @@ def start(self):
socket.send_multipart(reply)
-def main():
+def main() -> None:
MongoZMQ('ipcontroller', 'jobs').start()
diff --git a/examples/monitoring/simple_monitor.py b/examples/monitoring/simple_monitor.py
index 890051695..597b2a3dd 100644
--- a/examples/monitoring/simple_monitor.py
+++ b/examples/monitoring/simple_monitor.py
@@ -10,11 +10,14 @@
import threading
import time
+from typing import Any, Dict
import zmq
from zmq.utils.monitor import recv_monitor_message
-line = lambda: print('-' * 40)
+
+def line() -> None:
+ print('-' * 40)
print("libzmq-%s" % zmq.zmq_version())
@@ -30,10 +33,12 @@
EVENT_MAP[value] = name
-def event_monitor(monitor):
+def event_monitor(monitor: zmq.Socket) -> None:
while monitor.poll():
- evt = recv_monitor_message(monitor)
- evt.update({'description': EVENT_MAP[evt['event']]})
+ evt: Dict[str, Any] = {}
+ mon_evt = recv_monitor_message(monitor)
+ evt.update(mon_evt)
+ evt['description'] = EVENT_MAP[evt['event']]
print(f"Event: {evt}")
if evt['event'] == zmq.EVENT_MONITOR_STOPPED:
break
diff --git a/examples/pubsub/publisher.py b/examples/pubsub/publisher.py
index 7f51f462c..fbfc6d4ff 100644
--- a/examples/pubsub/publisher.py
+++ b/examples/pubsub/publisher.py
@@ -18,7 +18,7 @@
import zmq
-def sync(bind_to):
+def sync(bind_to: str) -> None:
# use bind socket + 1
sync_with = ':'.join(
bind_to.split(':')[:-1] + [str(int(bind_to.split(':')[-1]) + 1)]
@@ -32,7 +32,7 @@ def sync(bind_to):
s.send(b'GO')
-def main():
+def main() -> None:
if len(sys.argv) != 4:
print('usage: publisher ')
sys.exit(1)
diff --git a/examples/pubsub/subscriber.py b/examples/pubsub/subscriber.py
index 6d7e9d9e6..061a9e525 100644
--- a/examples/pubsub/subscriber.py
+++ b/examples/pubsub/subscriber.py
@@ -19,7 +19,7 @@
import zmq
-def sync(connect_to):
+def sync(connect_to: str) -> None:
# use connect socket + 1
sync_with = ':'.join(
connect_to.split(':')[:-1] + [str(int(connect_to.split(':')[-1]) + 1)]
@@ -31,7 +31,7 @@ def sync(connect_to):
s.recv()
-def main():
+def main() -> None:
if len(sys.argv) != 3:
print('usage: subscriber ')
sys.exit(1)
diff --git a/examples/pubsub/topics_pub.py b/examples/pubsub/topics_pub.py
index 9f9cfafaf..dfad578b5 100755
--- a/examples/pubsub/topics_pub.py
+++ b/examples/pubsub/topics_pub.py
@@ -24,7 +24,7 @@
import zmq
-def main():
+def main() -> None:
if len(sys.argv) != 2:
print('usage: publisher ')
sys.exit(1)
@@ -46,7 +46,7 @@ def main():
s.bind(bind_to)
print("Starting broadcast on topics:")
- print(" %s" % all_topics)
+ print(f" {all_topics}")
print("Hit Ctrl-C to stop broadcasting.")
print("Waiting so subscriber sockets can connect...")
print("")
@@ -55,9 +55,9 @@ def main():
msg_counter = itertools.count()
try:
for topic in itertools.cycle(all_topics):
- msg_body = str(next(msg_counter)).encode('utf-8')
- print(f' Topic: {topic}, msg:{msg_body}')
- s.send_multipart([topic, msg_body])
+ msg_body = str(next(msg_counter))
+ print(f" Topic: {topic.decode('utf8')}, msg:{msg_body}")
+ s.send_multipart([topic, msg_body.encode("utf8")])
# short wait so we don't hog the cpu
time.sleep(0.1)
except KeyboardInterrupt:
diff --git a/examples/pubsub/topics_sub.py b/examples/pubsub/topics_sub.py
index 360f2b8ee..3915b4cb0 100755
--- a/examples/pubsub/topics_sub.py
+++ b/examples/pubsub/topics_sub.py
@@ -25,7 +25,7 @@
import zmq
-def main():
+def main() -> None:
if len(sys.argv) < 2:
print('usage: subscriber [topic topic ...]')
sys.exit(1)
diff --git a/examples/security/asyncio-ironhouse.py b/examples/security/asyncio-ironhouse.py
index 435bcf54c..69fbca752 100644
--- a/examples/security/asyncio-ironhouse.py
+++ b/examples/security/asyncio-ironhouse.py
@@ -23,7 +23,7 @@
from zmq.auth.asyncio import AsyncioAuthenticator
-async def run():
+async def run() -> None:
'''Run Ironhouse example'''
# These directories are generated by the generate_certificates script
@@ -82,7 +82,7 @@ async def run():
# use copy=False to allow access to message properties via the zmq.Frame API
# default recv(copy=True) returns only bytes, discarding properties
identity, msg = await server.recv_multipart(copy=False)
- logging.info(f"Received {msg.bytes} from {msg['User-Id']!r}")
+ logging.info(f"Received {msg.bytes!r} from {msg['User-Id']!r}")
if msg.bytes == b"Hello":
logging.info("Ironhouse test OK")
else:
diff --git a/examples/security/generate_certificates.py b/examples/security/generate_certificates.py
index 627c858af..2b1511989 100644
--- a/examples/security/generate_certificates.py
+++ b/examples/security/generate_certificates.py
@@ -12,11 +12,12 @@
import os
import shutil
+from typing import Union
import zmq.auth
-def generate_certificates(base_dir):
+def generate_certificates(base_dir: Union[str, os.PathLike]) -> None:
'''Generate client and server CURVE certificate files'''
keys_dir = os.path.join(base_dir, 'certificates')
public_keys_dir = os.path.join(base_dir, 'public_keys')
diff --git a/examples/security/ioloop-ironhouse.py b/examples/security/ioloop-ironhouse.py
index e318c861e..e5c3793c6 100644
--- a/examples/security/ioloop-ironhouse.py
+++ b/examples/security/ioloop-ironhouse.py
@@ -16,21 +16,24 @@
import logging
import os
import sys
+from typing import List
+
+from tornado import ioloop
import zmq
import zmq.auth
from zmq.auth.ioloop import IOLoopAuthenticator
-from zmq.eventloop import ioloop, zmqstream
+from zmq.eventloop import zmqstream
-def echo(server, msg):
+def echo(server: zmqstream.ZMQStream, msg: List[bytes]) -> None:
logging.debug("server recvd %s", msg)
reply = msg + [b'World']
logging.debug("server sending %s", reply)
server.send_multipart(reply)
-def setup_server(server_secret_file, endpoint='tcp://127.0.0.1:9000'):
+def setup_server(server_secret_file: str, endpoint: str = 'tcp://127.0.0.1:9000'):
"""setup a simple echo server with CURVE auth"""
server = zmq.Context.instance().socket(zmq.ROUTER)
@@ -46,7 +49,7 @@ def setup_server(server_secret_file, endpoint='tcp://127.0.0.1:9000'):
return server_stream
-def client_msg_recvd(msg):
+def client_msg_recvd(msg: List[bytes]):
logging.debug("client recvd %s", msg)
logging.info("Ironhouse test OK")
# stop the loop when we get the reply
@@ -54,7 +57,9 @@ def client_msg_recvd(msg):
def setup_client(
- client_secret_file, server_public_file, endpoint='tcp://127.0.0.1:9000'
+ client_secret_file: str,
+ server_public_file: str,
+ endpoint: str = 'tcp://127.0.0.1:9000',
):
"""setup a simple client with CURVE auth"""
@@ -77,7 +82,7 @@ def setup_client(
return client_stream
-def run():
+def run() -> None:
'''Run Ironhouse example'''
# These direcotries are generated by the generate_certificates script
diff --git a/examples/security/ironhouse.py b/examples/security/ironhouse.py
index 8d683210f..02912b117 100644
--- a/examples/security/ironhouse.py
+++ b/examples/security/ironhouse.py
@@ -20,7 +20,7 @@
from zmq.auth.thread import ThreadAuthenticator
-def run():
+def run() -> None:
'''Run Ironhouse example'''
# These directories are generated by the generate_certificates script
diff --git a/examples/security/stonehouse.py b/examples/security/stonehouse.py
index ce730b084..2e4f0eb02 100644
--- a/examples/security/stonehouse.py
+++ b/examples/security/stonehouse.py
@@ -21,7 +21,7 @@
from zmq.auth.thread import ThreadAuthenticator
-def run():
+def run() -> None:
'''Run Stonehouse example'''
# These directories are generated by the generate_certificates script
diff --git a/examples/security/strawhouse.py b/examples/security/strawhouse.py
index 0cc6057e3..034c681e3 100644
--- a/examples/security/strawhouse.py
+++ b/examples/security/strawhouse.py
@@ -19,7 +19,7 @@
from zmq.auth.thread import ThreadAuthenticator
-def run():
+def run() -> None:
'''Run strawhouse client'''
allow_test_pass = False
diff --git a/examples/security/woodhouse.py b/examples/security/woodhouse.py
index c1186cea6..2a5f6599d 100644
--- a/examples/security/woodhouse.py
+++ b/examples/security/woodhouse.py
@@ -18,7 +18,7 @@
from zmq.auth.thread import ThreadAuthenticator
-def run():
+def run() -> None:
'''Run woodhouse example'''
valid_client_test_pass = False
diff --git a/examples/serialization/serialsocket.py b/examples/serialization/serialsocket.py
index 823c82890..9a070581c 100644
--- a/examples/serialization/serialsocket.py
+++ b/examples/serialization/serialsocket.py
@@ -2,6 +2,7 @@
import pickle
import zlib
+from typing import Any, Dict, cast
import numpy
@@ -18,20 +19,24 @@ class SerializingSocket(zmq.Socket):
for reconstructing the array on the other side (dtype,shape).
"""
- def send_zipped_pickle(self, obj, flags=0, protocol=-1):
+ def send_zipped_pickle(
+ self, obj: Any, flags: int = 0, protocol: int = pickle.HIGHEST_PROTOCOL
+ ) -> None:
"""pack and compress an object with pickle and zlib."""
pobj = pickle.dumps(obj, protocol)
zobj = zlib.compress(pobj)
print('zipped pickle is %i bytes' % len(zobj))
return self.send(zobj, flags=flags)
- def recv_zipped_pickle(self, flags=0):
+ def recv_zipped_pickle(self, flags: int = 0) -> Any:
"""reconstruct a Python object sent with zipped_pickle"""
zobj = self.recv(flags)
pobj = zlib.decompress(zobj)
return pickle.loads(pobj)
- def send_array(self, A, flags=0, copy=True, track=False):
+ def send_array(
+ self, A: numpy.ndarray, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Any:
"""send a numpy array with metadata"""
md = dict(
dtype=str(A.dtype),
@@ -40,19 +45,21 @@ def send_array(self, A, flags=0, copy=True, track=False):
self.send_json(md, flags | zmq.SNDMORE)
return self.send(A, flags, copy=copy, track=track)
- def recv_array(self, flags=0, copy=True, track=False):
+ def recv_array(
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> numpy.ndarray:
"""recv a numpy array"""
- md = self.recv_json(flags=flags)
+ md = cast(Dict[str, Any], self.recv_json(flags=flags))
msg = self.recv(flags=flags, copy=copy, track=track)
A = numpy.frombuffer(msg, dtype=md['dtype'])
return A.reshape(md['shape'])
-class SerializingContext(zmq.Context):
+class SerializingContext(zmq.Context[SerializingSocket]):
_socket_class = SerializingSocket
-def main():
+def main() -> None:
ctx = SerializingContext()
req = ctx.socket(zmq.REQ)
rep = ctx.socket(zmq.REP)
diff --git a/examples/win32-interrupt/display.py b/examples/win32-interrupt/display.py
index 5a026c1fc..ee41ef9ca 100644
--- a/examples/win32-interrupt/display.py
+++ b/examples/win32-interrupt/display.py
@@ -1,12 +1,13 @@
"""The display part of a simply two process chat app."""
# This file has been placed in the public domain.
+from typing import List
import zmq
from zmq.utils.win32 import allow_interrupt
-def main(addrs):
+def main(addrs: List[str]):
context = zmq.Context()
control = context.socket(zmq.PUB)
control.bind('inproc://control')
@@ -24,14 +25,14 @@ def interrupt_polling():
with allow_interrupt(interrupt_polling):
message = ''
while message != 'quit':
- message = updates.recv_multipart()
- if len(message) < 2:
+ recvd = updates.recv_multipart()
+ if len(recvd) < 2:
print('Invalid message.')
continue
- account = message[0]
- message = ' '.join(message[1:])
+ account = recvd[0].decode("utf8")
+ message = ' '.join(b.decode("utf8") for b in recvd[1:])
if message == 'quit':
- print('Killed by "%s".' % account)
+ print(f'Killed by {account}.')
break
print(f'{account}: {message}')
diff --git a/examples/win32-interrupt/prompt.py b/examples/win32-interrupt/prompt.py
index 89d470aed..588c56f8f 100644
--- a/examples/win32-interrupt/prompt.py
+++ b/examples/win32-interrupt/prompt.py
@@ -21,7 +21,7 @@
import zmq
-def main(addr, account):
+def main(addr: str, account: str) -> None:
ctx = zmq.Context()
socket = ctx.socket(zmq.PUB)
diff --git a/mypy_tests/test_context.py b/mypy_tests/test_context.py
index ec650dd0b..270574804 100644
--- a/mypy_tests/test_context.py
+++ b/mypy_tests/test_context.py
@@ -3,7 +3,6 @@
ctx = zmq.Context.instance()
s = ctx.socket(zmq.PUSH)
s.send(b"buf")
-
ctx2 = zmq.Context.shadow_pyczmq(123)
s2 = ctx2.socket(zmq.PUSH)
s.send(b"buf")
diff --git a/zmq/__init__.pyi b/zmq/__init__.pyi
index 99a714007..c9bc245d3 100644
--- a/zmq/__init__.pyi
+++ b/zmq/__init__.pyi
@@ -2,6 +2,10 @@ from typing import List
from . import backend, sugar
+COPY_THRESHOLD: int
+DRAFT_API: bool
+__version__: str
+
# mypy doesn't like overwriting symbols with * so be explicit
# about what comes from backend, not from sugar
# see tools/backend_imports.py to generate this list
@@ -21,8 +25,5 @@ from .constants import *
from .error import *
from .sugar import *
-COPY_THRESHOLD: int
-DRAFT_API: bool
-
def get_includes() -> List[str]: ...
def get_library_dirs() -> List[str]: ...
diff --git a/zmq/_future.py b/zmq/_future.py
index 3a1dc01cb..b2b3c2b8b 100644
--- a/zmq/_future.py
+++ b/zmq/_future.py
@@ -4,14 +4,37 @@
# Distributed under the terms of the Modified BSD License.
import warnings
+from asyncio import Future
from collections import deque, namedtuple
from itertools import chain
-from typing import Type
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
import zmq as _zmq
from zmq import EVENTS, POLLIN, POLLOUT
+from zmq._typing import Literal
+
+
+class _FutureEvent(NamedTuple):
+ future: Future
+ kind: str
+ kwargs: Dict
+ msg: Any
+ timer: Any
-_FutureEvent = namedtuple('_FutureEvent', ('future', 'kind', 'kwargs', 'msg', 'timer'))
# These are incomplete classes and need a Mixin for compatibility with an eventloop
# defining the following attributes:
@@ -25,9 +48,10 @@
class _Async:
"""Mixin for common async logic"""
- _current_loop = None
+ _current_loop: Any = None
+ _Future: Type[Future]
- def _get_loop(self):
+ def _get_loop(self) -> Any:
"""Get event loop
Notice if event loop has changed,
@@ -44,19 +68,30 @@ def _get_loop(self):
self._init_io_state(current_loop)
return current_loop
- def _default_loop(self):
+ def _default_loop(self) -> Any:
raise NotImplementedError("Must be implemented in a subclass")
- def _init_io_state(self, loop=None):
+ def _init_io_state(self, loop=None) -> None:
pass
class _AsyncPoller(_Async, _zmq.Poller):
"""Poller that returns a Future on poll, instead of blocking."""
- _socket_class = None # type: Type[_AsyncSocket]
+ _socket_class: Type["_AsyncSocket"]
+ _READ: int
+ _WRITE: int
+ raw_sockets: List[Any]
- def poll(self, timeout=-1):
+ def _watch_raw_socket(self, loop: Any, socket: Any, evt: int, f: Callable) -> None:
+ """Schedule callback for a raw socket"""
+ raise NotImplementedError()
+
+ def _unwatch_raw_sockets(self, loop: Any, *sockets: Any) -> None:
+ """Unschedule callback for a raw socket"""
+ raise NotImplementedError()
+
+ def poll(self, timeout=-1) -> Awaitable[List[Tuple[Any, int]]]: # type: ignore
"""Return a Future for a poll event"""
future = self._Future()
if timeout == 0:
@@ -74,7 +109,7 @@ def poll(self, timeout=-1):
watcher = self._Future()
# watch raw sockets:
- raw_sockets = []
+ raw_sockets: List[Any] = []
def wake_raw(*args):
if not watcher.done():
@@ -155,6 +190,9 @@ def cancel():
pass
+T = TypeVar("T", bound="_AsyncSocket")
+
+
class _AsyncSocket(_Async, _zmq.Socket):
# Warning : these class variables are only here to allow to call super().__setattr__.
@@ -162,18 +200,23 @@ class _AsyncSocket(_Async, _zmq.Socket):
_recv_futures = None
_send_futures = None
_state = 0
- _shadow_sock = None
+ _shadow_sock: "_zmq.Socket"
_poller_class = _AsyncPoller
_fd = None
- def __init__(self, context=None, socket_type=-1, io_loop=None, **kwargs):
+ def __init__(
+ self,
+ context=None,
+ socket_type=-1,
+ io_loop=None,
+ _from_socket: Optional["_zmq.Socket"] = None,
+ **kwargs,
+ ) -> None:
if isinstance(context, _zmq.Socket):
- context, from_socket = (None, context)
- else:
- from_socket = kwargs.pop('_from_socket', None)
- if from_socket is not None:
- super().__init__(shadow=from_socket.underlying)
- self._shadow_sock = from_socket
+ context, _from_socket = (None, context)
+ if _from_socket is not None:
+ super().__init__(shadow=_from_socket.underlying)
+ self._shadow_sock = _from_socket
else:
super().__init__(context, socket_type, **kwargs)
self._shadow_sock = _zmq.Socket.shadow(self.underlying)
@@ -191,15 +234,16 @@ def __init__(self, context=None, socket_type=-1, io_loop=None, **kwargs):
self._fd = self._shadow_sock.FD
@classmethod
- def from_socket(cls, socket, io_loop=None):
+ def from_socket(cls: Type[T], socket: "_zmq.Socket", io_loop: Any = None) -> T:
"""Create an async socket from an existing Socket"""
return cls(_from_socket=socket, io_loop=io_loop)
- def close(self, linger=None):
+ def close(self, linger: Optional[int] = None) -> None:
if not self.closed and self._fd is not None:
- for event in list(
+ event_list: List[_FutureEvent] = list(
chain(self._recv_futures or [], self._send_futures or [])
- ):
+ )
+ for event in event_list:
if not event.future.done():
try:
event.future.cancel()
@@ -219,7 +263,33 @@ def get(self, key):
get.__doc__ = _zmq.Socket.get.__doc__
- def recv_multipart(self, flags=0, copy=True, track=False):
+ @overload # type: ignore
+ def recv_multipart(
+ self, flags: int = 0, *, track: bool = False
+ ) -> Awaitable[List[bytes]]:
+ ...
+
+ @overload
+ def recv_multipart(
+ self, flags: int = 0, *, copy: Literal[True], track: bool = False
+ ) -> Awaitable[List[bytes]]:
+ ...
+
+ @overload
+ def recv_multipart(
+ self, flags: int = 0, *, copy: Literal[False], track: bool = False
+ ) -> Awaitable[List[_zmq.Frame]]: # type: ignore
+ ...
+
+ @overload
+ def recv_multipart(
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
+ ...
+
+ def recv_multipart(
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Awaitable[Union[List[bytes], List[_zmq.Frame]]]:
"""Receive a complete multipart zmq message.
Returns a Future whose result will be a multipart message.
@@ -228,7 +298,9 @@ def recv_multipart(self, flags=0, copy=True, track=False):
'recv_multipart', dict(flags=flags, copy=copy, track=track)
)
- def recv(self, flags=0, copy=True, track=False):
+ def recv( # type: ignore
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Awaitable[Union[bytes, _zmq.Frame]]:
"""Receive a single zmq frame.
Returns a Future, whose result will be the received frame.
@@ -237,7 +309,9 @@ def recv(self, flags=0, copy=True, track=False):
"""
return self._add_recv_event('recv', dict(flags=flags, copy=copy, track=track))
- def send_multipart(self, msg, flags=0, copy=True, track=False, **kwargs):
+ def send_multipart( # type: ignore
+ self, msg_parts: Any, flags: int = 0, copy: bool = True, track=False, **kwargs
+ ) -> Awaitable[Optional[_zmq.MessageTracker]]:
"""Send a complete multipart zmq message.
Returns a Future that resolves when sending is complete.
@@ -245,9 +319,16 @@ def send_multipart(self, msg, flags=0, copy=True, track=False, **kwargs):
kwargs['flags'] = flags
kwargs['copy'] = copy
kwargs['track'] = track
- return self._add_send_event('send_multipart', msg=msg, kwargs=kwargs)
-
- def send(self, msg, flags=0, copy=True, track=False, **kwargs):
+ return self._add_send_event('send_multipart', msg=msg_parts, kwargs=kwargs)
+
+ def send( # type: ignore
+ self,
+ data: Any,
+ flags: int = 0,
+ copy: bool = True,
+ track: bool = False,
+ **kwargs: Any,
+ ) -> Awaitable[Optional[_zmq.MessageTracker]]:
"""Send a single zmq frame.
Returns a Future that resolves when sending is complete.
@@ -258,7 +339,7 @@ def send(self, msg, flags=0, copy=True, track=False, **kwargs):
kwargs['copy'] = copy
kwargs['track'] = track
kwargs.update(dict(flags=flags, copy=copy, track=track))
- return self._add_send_event('send', msg=msg, kwargs=kwargs)
+ return self._add_send_event('send', msg=data, kwargs=kwargs)
def _deserialize(self, recvd, load):
"""Deserialize with Futures"""
@@ -292,7 +373,7 @@ def _chain_cancel(_):
return f
- def poll(self, timeout=None, flags=_zmq.POLLIN):
+ def poll(self, timeout=None, flags=_zmq.POLLIN) -> Awaitable[int]: # type: ignore
"""poll the socket for events
returns a Future for the poll results.
@@ -303,7 +384,7 @@ def poll(self, timeout=None, flags=_zmq.POLLIN):
p = self._poller_class()
p.register(self, flags)
- f = p.poll(timeout)
+ f = cast(Future, p.poll(timeout))
future = self._Future()
@@ -330,6 +411,13 @@ def unwrap_result(f):
f.add_done_callback(unwrap_result)
return future
+ # overrides only necessary for updated types
+ def recv_string(self, *args, **kwargs) -> Awaitable[str]: # type: ignore
+ return super().recv_string(*args, **kwargs) # type: ignore
+
+ def send_string(self, s: str, flags: int = 0, encoding: str = 'utf-8') -> Awaitable[None]: # type: ignore
+ return super().send_string(s, flags=flags, encoding=encoding) # type: ignore
+
def _add_timeout(self, future, timeout):
"""Add a timeout for a send or recv Future"""
diff --git a/zmq/_typing.py b/zmq/_typing.py
new file mode 100644
index 000000000..e08013605
--- /dev/null
+++ b/zmq/_typing.py
@@ -0,0 +1,19 @@
+import sys
+from typing import Any, Dict
+
+if sys.version_info >= (3, 8):
+ from typing import Literal, TypedDict
+else:
+ # avoid runtime dependency on typing_extensions on py37
+ try:
+ from typing_extensions import Literal, TypedDict # type: ignore
+ except ImportError:
+
+ class _Literal:
+ def __getitem__(self, key):
+ return Any
+
+ Literal = _Literal() # type: ignore
+
+ class TypedDict(Dict): # type: ignore
+ pass
diff --git a/zmq/asyncio.py b/zmq/asyncio.py
index 2078636a2..8a0963e5b 100644
--- a/zmq/asyncio.py
+++ b/zmq/asyncio.py
@@ -151,7 +151,7 @@ def _clear_io_state(self):
Poller._socket_class = Socket
-class Context(_zmq.Context):
+class Context(_zmq.Context[Socket]):
"""Context for creating asyncio-compatible Sockets"""
_socket_class = Socket
diff --git a/zmq/auth/asyncio.py b/zmq/auth/asyncio.py
index 51a264d5b..b1f3e7bdb 100644
--- a/zmq/auth/asyncio.py
+++ b/zmq/auth/asyncio.py
@@ -7,6 +7,8 @@
# Distributed under the terms of the Modified BSD License.
import asyncio
+import warnings
+from typing import Any, Optional
import zmq
from zmq.asyncio import Poller
@@ -17,27 +19,34 @@
class AsyncioAuthenticator(Authenticator):
"""ZAP authentication for use in the asyncio IO loop"""
- def __init__(self, context=None, loop=None):
+ __poller: Optional[Poller]
+ __task: Any
+ zap_socket: "zmq.asyncio.Socket"
+
+ def __init__(self, context: Optional["zmq.Context"] = None, loop: Any = None):
super().__init__(context)
- self.loop = loop or asyncio.get_event_loop()
+ if loop is not None:
+ warnings.warn(f"{self.__class__.__name__}(loop) is deprecated and ignored")
self.__poller = None
self.__task = None
- async def __handle_zap(self):
+ async def __handle_zap(self) -> None:
while True:
+ if self.__poller is None:
+ break
events = await self.__poller.poll()
if self.zap_socket in dict(events):
msg = await self.zap_socket.recv_multipart()
self.handle_zap_message(msg)
- def start(self):
+ def start(self) -> None:
"""Start ZAP authentication"""
super().start()
self.__poller = Poller()
self.__poller.register(self.zap_socket, zmq.POLLIN)
self.__task = asyncio.ensure_future(self.__handle_zap())
- def stop(self):
+ def stop(self) -> None:
"""Stop ZAP authentication"""
if self.__task:
self.__task.cancel()
diff --git a/zmq/auth/base.py b/zmq/auth/base.py
index 861b0617d..87acdddc0 100644
--- a/zmq/auth/base.py
+++ b/zmq/auth/base.py
@@ -4,6 +4,8 @@
# Distributed under the terms of the Modified BSD License.
import logging
+import os
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
import zmq
from zmq.error import _check_version
@@ -43,13 +45,29 @@ class Authenticator:
- GSSAPI requires no configuration.
"""
- def __init__(self, context=None, encoding='utf-8', log=None):
+ context: "zmq.Context"
+ encoding: str
+ allow_any: bool
+ credentials_providers: Dict[str, Any]
+ zap_socket: "zmq.Socket"
+ whitelist: Set[str]
+ blacklist: Set[str]
+ passwords: Dict[str, Dict[str, str]]
+ certs: Dict[str, Dict[bytes, Any]]
+ log: Any
+
+ def __init__(
+ self,
+ context: Optional["zmq.Context"] = None,
+ encoding: str = 'utf-8',
+ log: Any = None,
+ ):
_check_version((4, 0), "security")
self.context = context or zmq.Context.instance()
self.encoding = encoding
self.allow_any = False
self.credentials_providers = {}
- self.zap_socket = None
+ self.zap_socket = None # type: ignore
self.whitelist = set()
self.blacklist = set()
# passwords is a dict keyed by domain and contains values
@@ -60,20 +78,20 @@ def __init__(self, context=None, encoding='utf-8', log=None):
self.certs = {}
self.log = log or logging.getLogger('zmq.auth')
- def start(self):
+ def start(self) -> None:
"""Create and bind the ZAP socket"""
self.zap_socket = self.context.socket(zmq.REP)
self.zap_socket.linger = 1
self.zap_socket.bind("inproc://zeromq.zap.01")
self.log.debug("Starting")
- def stop(self):
+ def stop(self) -> None:
"""Close the ZAP socket"""
if self.zap_socket:
self.zap_socket.close()
- self.zap_socket = None
+ self.zap_socket = None # type: ignore
- def allow(self, *addresses):
+ def allow(self, *addresses: str) -> None:
"""Allow (whitelist) IP address(es).
Connections from addresses not in the whitelist will be rejected.
@@ -88,7 +106,7 @@ def allow(self, *addresses):
self.log.debug("Allowing %s", ','.join(addresses))
self.whitelist.update(addresses)
- def deny(self, *addresses):
+ def deny(self, *addresses: str) -> None:
"""Deny (blacklist) IP address(es).
Addresses not in the blacklist will be allowed to continue with authentication.
@@ -100,7 +118,9 @@ def deny(self, *addresses):
self.log.debug("Denying %s", ','.join(addresses))
self.blacklist.update(addresses)
- def configure_plain(self, domain='*', passwords=None):
+ def configure_plain(
+ self, domain: str = '*', passwords: Dict[str, str] = None
+ ) -> None:
"""Configure PLAIN authentication for a given domain.
PLAIN authentication uses a plain-text password file.
@@ -111,7 +131,9 @@ def configure_plain(self, domain='*', passwords=None):
self.passwords[domain] = passwords
self.log.debug("Configure plain: %s", domain)
- def configure_curve(self, domain='*', location=None):
+ def configure_curve(
+ self, domain: str = '*', location: Union[str, os.PathLike] = "."
+ ) -> None:
"""Configure CURVE authentication for a given domain.
CURVE authentication uses a directory that holds all public client certificates,
@@ -136,7 +158,9 @@ def configure_curve(self, domain='*', location=None):
except Exception as e:
self.log.error("Failed to load CURVE certs from %s: %s", location, e)
- def configure_curve_callback(self, domain='*', credentials_provider=None):
+ def configure_curve_callback(
+ self, domain: str = '*', credentials_provider: Any = None
+ ) -> None:
"""Configure CURVE authentication for a given domain.
CURVE authentication using a callback function validating
@@ -170,7 +194,7 @@ def callback(self, domain, key):
else:
self.log.error("None credentials_provider provided for domain:%s", domain)
- def curve_user_id(self, client_public_key):
+ def curve_user_id(self, client_public_key: bytes) -> str:
"""Return the User-Id corresponding to a CURVE client's public key
Default implementation uses the z85-encoding of the public key.
@@ -191,14 +215,16 @@ def curve_user_id(self, client_public_key):
"""
return z85.encode(client_public_key).decode('ascii')
- def configure_gssapi(self, domain='*', location=None):
+ def configure_gssapi(
+ self, domain: str = '*', location: Optional[str] = None
+ ) -> None:
"""Configure GSSAPI authentication
Currently this is a no-op because there is nothing to configure with GSSAPI.
"""
pass
- def handle_zap_message(self, msg):
+ def handle_zap_message(self, msg: List[bytes]):
"""Perform ZAP authentication"""
if len(msg) < 6:
self.log.error("Invalid ZAP message, not enough frames: %r", msg)
@@ -298,7 +324,9 @@ def handle_zap_message(self, msg):
else:
self._send_zap_reply(request_id, b"400", reason)
- def _authenticate_plain(self, domain, username, password):
+ def _authenticate_plain(
+ self, domain: str, username: str, password: str
+ ) -> Tuple[bool, bytes]:
"""PLAIN ZAP authentication"""
allowed = False
reason = b""
@@ -334,7 +362,7 @@ def _authenticate_plain(self, domain, username, password):
return allowed, reason
- def _authenticate_curve(self, domain, client_key):
+ def _authenticate_curve(self, domain: str, client_key: bytes) -> Tuple[bool, bytes]:
"""CURVE ZAP authentication"""
allowed = False
reason = b""
@@ -391,14 +419,18 @@ def _authenticate_curve(self, domain, client_key):
return allowed, reason
- def _authenticate_gssapi(self, domain, principal):
+ def _authenticate_gssapi(self, domain: str, principal: bytes) -> Tuple[bool, bytes]:
"""Nothing to do for GSSAPI, which has already been handled by an external service."""
self.log.debug("ALLOWED (GSSAPI) domain=%s principal=%s", domain, principal)
return True, b'OK'
def _send_zap_reply(
- self, request_id, status_code, status_text, user_id='anonymous'
- ):
+ self,
+ request_id: bytes,
+ status_code: bytes,
+ status_text: bytes,
+ user_id: str = 'anonymous',
+ ) -> None:
"""Send a ZAP reply to finish the authentication."""
user_id = user_id if status_code == b'200' else b''
if isinstance(user_id, unicode):
diff --git a/zmq/auth/certs.py b/zmq/auth/certs.py
index b56507f52..cda8ba446 100644
--- a/zmq/auth/certs.py
+++ b/zmq/auth/certs.py
@@ -8,32 +8,34 @@
import glob
import io
import os
+from typing import Dict, Optional, Tuple, Union
import zmq
-from zmq.utils.strtypes import b, bytes, u, unicode
-_cert_secret_banner = u(
- """# **** Generated on {0} by pyzmq ****
+_cert_secret_banner = """# **** Generated on {0} by pyzmq ****
# ZeroMQ CURVE **Secret** Certificate
# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.
"""
-)
-_cert_public_banner = u(
- """# **** Generated on {0} by pyzmq ****
+
+_cert_public_banner = """# **** Generated on {0} by pyzmq ****
# ZeroMQ CURVE Public Certificate
# Exchange securely, or use a secure mechanism to verify the contents
# of this file after exchange. Store public certificates in your home
# directory, in the .curve subdirectory.
"""
-)
def _write_key_file(
- key_filename, banner, public_key, secret_key=None, metadata=None, encoding='utf-8'
-):
+ key_filename: Union[str, os.PathLike],
+ banner: str,
+ public_key: Union[str, bytes],
+ secret_key: Optional[Union[str, bytes]] = None,
+ metadata: Optional[Dict[str, str]] = None,
+ encoding: str = 'utf-8',
+) -> None:
"""Create a certificate file"""
if isinstance(public_key, bytes):
public_key = public_key.decode(encoding)
@@ -42,23 +44,27 @@ def _write_key_file(
with open(key_filename, 'w', encoding='utf8') as f:
f.write(banner.format(datetime.datetime.now()))
- f.write(u('metadata\n'))
+ f.write('metadata\n')
if metadata:
for k, v in metadata.items():
if isinstance(k, bytes):
k = k.decode(encoding)
if isinstance(v, bytes):
v = v.decode(encoding)
- f.write(u(" {0} = {1}\n").format(k, v))
+ f.write(f" {k} = {v}\n")
- f.write(u('curve\n'))
- f.write(u(" public-key = \"{0}\"\n").format(public_key))
+ f.write('curve\n')
+ f.write(f" public-key = \"{public_key}\"\n")
if secret_key:
- f.write(u(" secret-key = \"{0}\"\n").format(secret_key))
+ f.write(f" secret-key = \"{secret_key}\"\n")
-def create_certificates(key_dir, name, metadata=None):
+def create_certificates(
+ key_dir: Union[str, os.PathLike],
+ name: str,
+ metadata: Optional[Dict[str, str]] = None,
+) -> Tuple[str, str]:
"""Create zmq certificates.
Returns the file paths to the public and secret certificate files.
@@ -82,7 +88,9 @@ def create_certificates(key_dir, name, metadata=None):
return public_key_file, secret_key_file
-def load_certificate(filename):
+def load_certificate(
+ filename: Union[str, os.PathLike]
+) -> Tuple[bytes, Optional[bytes]]:
"""Load public and secret key from a zmq certificate.
Returns (public_key, secret_key)
@@ -115,7 +123,7 @@ def load_certificate(filename):
return public_key, secret_key
-def load_certificates(directory='.'):
+def load_certificates(directory: Union[str, os.PathLike] = '.') -> Dict[bytes, bool]:
"""Load public keys from all certificates in a directory"""
certs = {}
if not os.path.isdir(directory):
diff --git a/zmq/auth/ioloop.py b/zmq/auth/ioloop.py
index b063b0cf2..29771cbf2 100644
--- a/zmq/auth/ioloop.py
+++ b/zmq/auth/ioloop.py
@@ -5,9 +5,11 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
+from typing import Any, Optional
from tornado import ioloop
+import zmq
from zmq.eventloop import zmqstream
from .base import Authenticator
@@ -16,12 +18,21 @@
class IOLoopAuthenticator(Authenticator):
"""ZAP authentication for use in the tornado IOLoop"""
- def __init__(self, context=None, encoding='utf-8', log=None, io_loop=None):
+ zap_stream: zmqstream.ZMQStream
+ io_loop: ioloop.IOLoop
+
+ def __init__(
+ self,
+ context: Optional["zmq.Context"] = None,
+ encoding: str = 'utf-8',
+ log: Any = None,
+ io_loop: Optional[ioloop.IOLoop] = None,
+ ):
super().__init__(context, encoding, log)
- self.zap_stream = None
+ self.zap_stream = None # type: ignore
self.io_loop = io_loop or ioloop.IOLoop.current()
- def start(self):
+ def start(self) -> None:
"""Start ZAP authentication"""
super().start()
self.zap_stream = zmqstream.ZMQStream(self.zap_socket, self.io_loop)
diff --git a/zmq/auth/thread.py b/zmq/auth/thread.py
index 876ed51ec..0d42f7b9f 100644
--- a/zmq/auth/thread.py
+++ b/zmq/auth/thread.py
@@ -8,7 +8,9 @@
import logging
import sys
+from itertools import chain
from threading import Event, Thread
+from typing import Any, Dict, List, Optional, TypeVar, cast
import zmq
from zmq.utils import jsonapi
@@ -24,14 +26,19 @@ class AuthenticationThread(Thread):
"""
def __init__(
- self, context, endpoint, encoding='utf-8', log=None, authenticator=None
- ):
+ self,
+ context: "zmq.Context",
+ endpoint: str,
+ encoding: str = 'utf-8',
+ log: Any = None,
+ authenticator: Optional[Authenticator] = None,
+ ) -> None:
super().__init__()
self.context = context or zmq.Context.instance()
self.encoding = encoding
self.log = log = log or logging.getLogger('zmq.auth')
self.started = Event()
- self.authenticator = authenticator or Authenticator(
+ self.authenticator: Authenticator = authenticator or Authenticator(
context, encoding=encoding, log=log
)
@@ -40,7 +47,7 @@ def __init__(
self.pipe.linger = 1
self.pipe.connect(endpoint)
- def run(self):
+ def run(self) -> None:
"""Start the Authentication Agent thread task"""
self.authenticator.start()
self.started.set()
@@ -75,16 +82,18 @@ def run(self):
self.pipe.close()
self.authenticator.stop()
- def _handle_zap(self):
+ def _handle_zap(self) -> None:
"""
Handle a message from the ZAP socket.
"""
+ if self.authenticator.zap_socket is None:
+ raise RuntimeError("ZAP socket closed")
msg = self.authenticator.zap_socket.recv_multipart()
if not msg:
return
self.authenticator.handle_zap_message(msg)
- def _handle_pipe(self, msg):
+ def _handle_pipe(self, msg: List[bytes]) -> bool:
"""
Handle a message from front-end API.
"""
@@ -114,7 +123,10 @@ def _handle_pipe(self, msg):
elif command == b'PLAIN':
domain = u(msg[1], self.encoding)
json_passwords = msg[2]
- self.authenticator.configure_plain(domain, jsonapi.loads(json_passwords))
+ passwords: Dict[str, str] = cast(
+ Dict[str, str], jsonapi.loads(json_passwords)
+ )
+ self.authenticator.configure_plain(domain, passwords)
elif command == b'CURVE':
# For now we don't do anything with domains
@@ -134,7 +146,10 @@ def _handle_pipe(self, msg):
return terminate
-def _inherit_docstrings(cls):
+T = TypeVar("T", bound=type)
+
+
+def _inherit_docstrings(cls: T) -> T:
"""inherit docstrings from Authenticator, so we don't duplicate them"""
for name, method in cls.__dict__.items():
if name.startswith('_') or not callable(method):
@@ -149,59 +164,64 @@ def _inherit_docstrings(cls):
class ThreadAuthenticator:
"""Run ZAP authentication in a background thread"""
- context = None
- log = None
- encoding = None
- pipe = None
- pipe_endpoint = ''
- thread = None
- auth = None
+ context: "zmq.Context"
+ log: Any
+ encoding: str
+ pipe: "zmq.Socket"
+ pipe_endpoint: str = ''
+ thread: AuthenticationThread
- def __init__(self, context=None, encoding='utf-8', log=None):
- self.context = context or zmq.Context.instance()
+ def __init__(
+ self,
+ context: Optional["zmq.Context"] = None,
+ encoding: str = 'utf-8',
+ log: Any = None,
+ ):
self.log = log
self.encoding = encoding
- self.pipe = None
+ self.pipe = None # type: ignore
self.pipe_endpoint = f"inproc://{id(self)}.inproc"
- self.thread = None
+ self.thread = None # type: ignore
+ self.context = context or zmq.Context.instance()
# proxy base Authenticator attributes
- def __setattr__(self, key, value):
- for obj in [self] + self.__class__.mro():
- if key in obj.__dict__:
+ def __setattr__(self, key: str, value: Any):
+ for obj in chain([self], self.__class__.mro()):
+ if key in obj.__dict__ or (key in getattr(obj, "__annotations__", {})):
object.__setattr__(self, key, value)
return
setattr(self.thread.authenticator, key, value)
- def __getattr__(self, key):
- try:
- object.__getattr__(self, key)
- except AttributeError:
- return getattr(self.thread.authenticator, key)
+ def __getattr__(self, key: str):
+ return getattr(self.thread.authenticator, key)
- def allow(self, *addresses):
+ def allow(self, *addresses: str):
self.pipe.send_multipart([b'ALLOW'] + [b(a, self.encoding) for a in addresses])
- def deny(self, *addresses):
+ def deny(self, *addresses: str):
self.pipe.send_multipart([b'DENY'] + [b(a, self.encoding) for a in addresses])
- def configure_plain(self, domain='*', passwords=None):
+ def configure_plain(
+ self, domain: str = '*', passwords: Optional[Dict[str, str]] = None
+ ):
self.pipe.send_multipart(
[b'PLAIN', b(domain, self.encoding), jsonapi.dumps(passwords or {})]
)
- def configure_curve(self, domain='*', location=''):
+ def configure_curve(self, domain: str = '*', location: str = ''):
domain = b(domain, self.encoding)
location = b(location, self.encoding)
self.pipe.send_multipart([b'CURVE', domain, location])
- def configure_curve_callback(self, domain='*', credentials_provider=None):
+ def configure_curve_callback(
+ self, domain: str = '*', credentials_provider: Any = None
+ ):
self.thread.authenticator.configure_curve_callback(
domain, credentials_provider=credentials_provider
)
- def start(self):
+ def start(self) -> None:
"""Start the authentication thread"""
# create a socket to communicate with auth thread.
self.pipe = self.context.socket(zmq.PAIR)
@@ -214,23 +234,23 @@ def start(self):
if not self.thread.started.wait(timeout=10):
raise RuntimeError("Authenticator thread failed to start")
- def stop(self):
+ def stop(self) -> None:
"""Stop the authentication thread"""
if self.pipe:
self.pipe.send(b'TERMINATE')
if self.is_alive():
self.thread.join()
- self.thread = None
+ self.thread = None # type: ignore
self.pipe.close()
- self.pipe = None
+ self.pipe = None # type: ignore
- def is_alive(self):
+ def is_alive(self) -> bool:
"""Is the ZAP thread currently running?"""
if self.thread and self.thread.is_alive():
return True
return False
- def __del__(self):
+ def __del__(self) -> None:
self.stop()
diff --git a/zmq/backend/__init__.pyi b/zmq/backend/__init__.pyi
index 1b141f157..4c699a8fe 100644
--- a/zmq/backend/__init__.pyi
+++ b/zmq/backend/__init__.pyi
@@ -1,44 +1,92 @@
-from typing import Any, ByteString, List, Optional, Set, Tuple, Union
+from typing import Any, List, Optional, Set, Tuple, TypeVar, Union, overload
+
+from typing_extensions import Literal
+
+import zmq
from .select import select_backend
+# avoid collision in Frame.bytes
+_bytestr = bytes
+
+T = TypeVar("T")
+
class Frame:
buffer: Any
- bytes: ByteString
+ bytes: bytes
more: bool
tracker: Any
- def copy_fast(self) -> Frame: ...
- def get(self, option: int) -> Union[int, ByteString, str]: ...
- def set(self, option: int, value: Union[int, ByteString, str]) -> None: ...
+ def __init__(
+ self,
+ data: Any = None,
+ track: bool = False,
+ copy: Optional[bool] = None,
+ copy_threshold: Optional[int] = None,
+ ): ...
+ def copy_fast(self: T) -> T: ...
+ def get(self, option: int) -> Union[int, _bytestr, str]: ...
+ def set(self, option: int, value: Union[int, _bytestr, str]) -> None: ...
class Socket:
underlying: int
+ context: "zmq.Context"
+ copy_threshold: int
+
+ # specific option types
+ FD: int
def close(self, linger: Optional[int] = ...) -> None: ...
- def get(self, option: int) -> Union[int, ByteString, str]: ...
- def set(self, option: int, value: Union[int, ByteString, str]) -> None: ...
- def connect(self, url: str) -> None: ...
+ def get(self, option: int) -> Union[int, bytes, str]: ...
+ def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
+ def connect(self, url: str): ...
def disconnect(self, url: str) -> None: ...
- def bind(self, url: str) -> None: ...
+ def bind(self, url: str): ...
def unbind(self, url: str) -> None: ...
def send(
self,
data: Any,
- flags: Optional[int] = ...,
- copy: Optional[bool] = ...,
- track: Optional[bool] = ...,
- ) -> Optional[Frame]: ...
+ flags: int = ...,
+ copy: bool = ...,
+ track: bool = ...,
+ ) -> Optional["zmq.MessageTracker"]: ...
+ @overload
+ def recv(
+ self,
+ flags: int = ...,
+ *,
+ copy: Literal[False],
+ track: bool = ...,
+ ) -> "zmq.Frame": ...
+ @overload
+ def recv(
+ self,
+ flags: int = ...,
+ *,
+ copy: Literal[True],
+ track: bool = ...,
+ ) -> bytes: ...
+ @overload
+ def recv(
+ self,
+ flags: int = ...,
+ track: bool = False,
+ ) -> bytes: ...
+ @overload
def recv(
self,
flags: Optional[int] = ...,
- copy: Optional[bool] = ...,
- track: Optional[bool] = ...,
- ) -> Union[Frame, ByteString]: ...
+ copy: bool = ...,
+ track: Optional[bool] = False,
+ ) -> Union["zmq.Frame", bytes]: ...
+ def monitor(self, addr: Optional[str], events: int) -> None: ...
+ # draft methods
+ def join(self, group: str) -> None: ...
+ def leave(self, group: str) -> None: ...
class Context:
underlying: int
- def __init__(self, io_threads: int = 1, **kwargs): ...
- def get(self, option: int) -> Union[int, ByteString, str]: ...
- def set(self, option: int, value: Union[int, ByteString, str]) -> None: ...
+ def __init__(self, io_threads: int = 1, shadow: Any = None): ...
+ def get(self, option: int) -> Union[int, bytes, str]: ...
+ def set(self, option: int, value: Union[int, bytes, str]) -> None: ...
def socket(self, socket_type: int) -> Socket: ...
def term(self) -> None: ...
diff --git a/zmq/backend/cython/constant_enums.pxi b/zmq/backend/cython/constant_enums.pxi
index 607f5971b..adf019cb1 100644
--- a/zmq/backend/cython/constant_enums.pxi
+++ b/zmq/backend/cython/constant_enums.pxi
@@ -96,6 +96,7 @@ cdef extern from "zmq.h" nogil:
enum: ZMQ_PLAIN
enum: ZMQ_CURVE
enum: ZMQ_GSSAPI
+ enum: ZMQ_HWM
enum: ZMQ_AFFINITY
enum: ZMQ_ROUTING_ID
enum: ZMQ_SUBSCRIBE
diff --git a/zmq/backend/cython/context.pyx b/zmq/backend/cython/context.pyx
index 004e1cebc..b12cd2473 100644
--- a/zmq/backend/cython/context.pyx
+++ b/zmq/backend/cython/context.pyx
@@ -28,12 +28,12 @@ cdef class Context:
io_threads : int
The number of IO threads.
"""
-
+
# no-op for the signature
def __init__(self, io_threads=1, shadow=0):
pass
-
- def __cinit__(self, int io_threads=1, size_t shadow=0, **kwargs):
+
+ def __cinit__(self, int io_threads=1, size_t shadow=0):
self.handle = NULL
if shadow:
self.handle = shadow
@@ -68,12 +68,12 @@ cdef class Context:
rc = zmq_ctx_destroy(self.handle)
self.handle = NULL
return rc
-
+
def term(self):
"""ctx.term()
Close or terminate the context.
-
+
This can be called to close the context by hand. If this is not called,
the context will automatically be closed when it is garbage collected.
"""
@@ -85,9 +85,9 @@ cdef class Context:
# ignore interrupted term
# see PEP 475 notes about close & EINTR for why
pass
-
+
self.closed = True
-
+
def set(self, int option, optval):
"""ctx.set(option, optval)
@@ -95,7 +95,7 @@ cdef class Context:
See the 0MQ API documentation for zmq_ctx_set
for details on specific options.
-
+
.. versionadded:: libzmq-3.2
.. versionadded:: 13.0
@@ -104,9 +104,9 @@ cdef class Context:
option : int
The option to set. Available values will depend on your
version of libzmq. Examples include::
-
+
zmq.IO_THREADS, zmq.MAX_SOCKETS
-
+
optval : int
The value of the option to set.
"""
@@ -116,7 +116,7 @@ cdef class Context:
if self.closed:
raise RuntimeError("Context has been destroyed")
-
+
if not isinstance(optval, int):
raise TypeError('expected int, got: %r' % optval)
optval_int_c = optval
@@ -130,7 +130,7 @@ cdef class Context:
See the 0MQ API documentation for zmq_ctx_get
for details on specific options.
-
+
.. versionadded:: libzmq-3.2
.. versionadded:: 13.0
@@ -139,9 +139,9 @@ cdef class Context:
option : int
The option to get. Available values will depend on your
version of libzmq. Examples include::
-
+
zmq.IO_THREADS, zmq.MAX_SOCKETS
-
+
Returns
-------
optval : int
diff --git a/zmq/constants.py b/zmq/constants.py
index 77748fd1d..b76a3bb15 100644
--- a/zmq/constants.py
+++ b/zmq/constants.py
@@ -121,6 +121,8 @@ class _OptType(Enum):
class SocketOption(IntEnum):
"""Options for Socket.get/set"""
+ _opt_type: str
+
def __new__(cls, value, opt_type=_OptType.int):
"""Attach option type as `._opt_type`"""
obj = int.__new__(cls, value)
@@ -128,6 +130,7 @@ def __new__(cls, value, opt_type=_OptType.int):
obj._opt_type = opt_type
return obj
+ HWM = 1
AFFINITY = 4, _OptType.int64
ROUTING_ID = 5, _OptType.bytes
SUBSCRIBE = 6, _OptType.bytes
@@ -451,6 +454,7 @@ class DeviceType(IntEnum):
PLAIN: int = SecurityMechanism.PLAIN
CURVE: int = SecurityMechanism.CURVE
GSSAPI: int = SecurityMechanism.GSSAPI
+HWM: int = SocketOption.HWM
AFFINITY: int = SocketOption.AFFINITY
ROUTING_ID: int = SocketOption.ROUTING_ID
SUBSCRIBE: int = SocketOption.SUBSCRIBE
@@ -686,6 +690,7 @@ class DeviceType(IntEnum):
"CURVE",
"GSSAPI",
"SocketOption",
+ "HWM",
"AFFINITY",
"ROUTING_ID",
"SUBSCRIBE",
diff --git a/zmq/devices/basedevice.py b/zmq/devices/basedevice.py
index ee039e14e..aa29f6893 100644
--- a/zmq/devices/basedevice.py
+++ b/zmq/devices/basedevice.py
@@ -7,8 +7,9 @@
import time
from multiprocessing import Process
from threading import Thread
-from typing import Any
+from typing import Any, List, Optional, Tuple
+import zmq
from zmq import ETERM, QUEUE, REQ, Context, ZMQBindError, ZMQError, device
@@ -68,7 +69,24 @@ class Device:
depending on whether the device should share the global instance or not.
"""
- def __init__(self, device_type=QUEUE, in_type=None, out_type=None):
+ device_type: int
+ in_type: int
+ out_type: int
+
+ _in_binds: List[str]
+ _in_connects: List[str]
+ _in_sockopts: List[Tuple[int, Any]]
+ _out_binds: List[str]
+ _out_connects: List[str]
+ _out_sockopts: List[Tuple[int, Any]]
+ _random_addrs: List[str]
+
+ def __init__(
+ self,
+ device_type: int = QUEUE,
+ in_type: Optional[int] = None,
+ out_type: Optional[int] = None,
+ ) -> None:
self.device_type = device_type
if in_type is None:
raise TypeError("in_type must be specified")
@@ -86,14 +104,14 @@ def __init__(self, device_type=QUEUE, in_type=None, out_type=None):
self.daemon = True
self.done = False
- def bind_in(self, addr):
+ def bind_in(self, addr: str) -> None:
"""Enqueue ZMQ address for binding on in_socket.
See zmq.Socket.bind for details.
"""
self._in_binds.append(addr)
- def bind_in_to_random_port(self, addr, *args, **kwargs):
+ def bind_in_to_random_port(self, addr: str, *args, **kwargs) -> int:
"""Enqueue a random port on the given interface for binding on
in_socket.
@@ -107,28 +125,28 @@ def bind_in_to_random_port(self, addr, *args, **kwargs):
return port
- def connect_in(self, addr):
+ def connect_in(self, addr: str) -> None:
"""Enqueue ZMQ address for connecting on in_socket.
See zmq.Socket.connect for details.
"""
self._in_connects.append(addr)
- def setsockopt_in(self, opt, value):
+ def setsockopt_in(self, opt: int, value: Any) -> None:
"""Enqueue setsockopt(opt, value) for in_socket
See zmq.Socket.setsockopt for details.
"""
self._in_sockopts.append((opt, value))
- def bind_out(self, addr):
+ def bind_out(self, addr: str) -> None:
"""Enqueue ZMQ address for binding on out_socket.
See zmq.Socket.bind for details.
"""
self._out_binds.append(addr)
- def bind_out_to_random_port(self, addr, *args, **kwargs):
+ def bind_out_to_random_port(self, addr: str, *args, **kwargs) -> int:
"""Enqueue a random port on the given interface for binding on
out_socket.
@@ -142,21 +160,21 @@ def bind_out_to_random_port(self, addr, *args, **kwargs):
return port
- def connect_out(self, addr):
+ def connect_out(self, addr: str):
"""Enqueue ZMQ address for connecting on out_socket.
See zmq.Socket.connect for details.
"""
self._out_connects.append(addr)
- def setsockopt_out(self, opt, value):
+ def setsockopt_out(self, opt: int, value: Any):
"""Enqueue setsockopt(opt, value) for out_socket
See zmq.Socket.setsockopt for details.
"""
self._out_sockopts.append((opt, value))
- def _reserve_random_port(self, addr, *args, **kwargs):
+ def _reserve_random_port(self, addr: str, *args, **kwargs) -> int:
ctx = Context()
binder = ctx.socket(REQ)
@@ -179,9 +197,8 @@ def _reserve_random_port(self, addr, *args, **kwargs):
return port
- def _setup_sockets(self):
- ctx = self.context_factory()
-
+ def _setup_sockets(self) -> Tuple[zmq.Socket, zmq.Socket]:
+ ctx: zmq.Context[zmq.Socket] = self.context_factory() # type: ignore
self._context = ctx
# create the sockets
@@ -209,7 +226,7 @@ def _setup_sockets(self):
return ins, outs
- def run_device(self):
+ def run_device(self) -> None:
"""The runner method.
Do not call me directly, instead call ``self.start()``, just like a Thread.
@@ -217,7 +234,7 @@ def run_device(self):
ins, outs = self._setup_sockets()
device(self.device_type, ins, outs)
- def run(self):
+ def run(self) -> None:
"""wrap run_device in try/catch ETERM"""
try:
self.run_device()
@@ -230,11 +247,11 @@ def run(self):
finally:
self.done = True
- def start(self):
+ def start(self) -> None:
"""Start the device. Override me in subclass for other launchers."""
return self.run()
- def join(self, timeout=None):
+ def join(self, timeout: Optional[float] = None) -> None:
"""wait for me to finish, like Thread.join.
Reimplemented appropriately by subclasses."""
@@ -251,12 +268,12 @@ class BackgroundDevice(Device):
launcher: Any = None
_launch_class: Any = None
- def start(self):
+ def start(self) -> None:
self.launcher = self._launch_class(target=self.run)
self.launcher.daemon = self.daemon
return self.launcher.start()
- def join(self, timeout=None):
+ def join(self, timeout: Optional[float] = None) -> None:
return self.launcher.join(timeout=timeout)
diff --git a/zmq/error.py b/zmq/error.py
index f5522bc3c..94603fb8c 100644
--- a/zmq/error.py
+++ b/zmq/error.py
@@ -4,7 +4,7 @@
# Distributed under the terms of the Modified BSD License.
from errno import EINTR
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
class ZMQBaseError(Exception):
@@ -188,7 +188,10 @@ def __str__(self):
)
-def _check_version(min_version_info: Tuple[int, int, int], msg: str = "Feature"):
+def _check_version(
+ min_version_info: Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]],
+ msg: str = "Feature",
+):
"""Check for libzmq
raises ZMQVersionError if current zmq version is not at least min_version
diff --git a/zmq/eventloop/future.py b/zmq/eventloop/future.py
index 8604fd56b..051ab6580 100644
--- a/zmq/eventloop/future.py
+++ b/zmq/eventloop/future.py
@@ -9,7 +9,9 @@
# Copyright (c) PyZMQ Developers.
# Distributed under the terms of the Modified BSD License.
+import asyncio
import warnings
+from typing import Any, Type
from tornado.concurrent import Future
from tornado.ioloop import IOLoop
@@ -48,7 +50,7 @@ def cancel(self):
class _AsyncTornado:
- _Future = _TornadoFuture
+ _Future: Type[asyncio.Future] = _TornadoFuture
_READ = IOLoop.READ
_WRITE = IOLoop.WRITE
@@ -79,7 +81,7 @@ class Socket(_AsyncTornado, _AsyncSocket):
Poller._socket_class = Socket
-class Context(_zmq.Context):
+class Context(_zmq.Context[Socket]):
# avoid sharing instance with base Context class
_instance = None
@@ -90,7 +92,7 @@ class Context(_zmq.Context):
def _socket_class(self, socket_type):
return Socket(self, socket_type)
- def __init__(self, *args, **kwargs):
+ def __init__(self: "Context", *args: Any, **kwargs: Any) -> None:
io_loop = kwargs.pop('io_loop', None)
if io_loop is not None:
warnings.warn(
diff --git a/zmq/eventloop/zmqstream.py b/zmq/eventloop/zmqstream.py
index 7f76f4d7c..2e14e6106 100644
--- a/zmq/eventloop/zmqstream.py
+++ b/zmq/eventloop/zmqstream.py
@@ -27,13 +27,16 @@
import sys
import warnings
from queue import Queue
+from typing import Any, Callable, List, Optional, Sequence, Union, cast, overload
import zmq
+from zmq._typing import Literal
from zmq.utils import jsonapi
from .ioloop import IOLoop, gen_log
try:
+ import tornado.ioloop
from tornado.stack_context import wrap as stack_context_wrap # type: ignore
except ImportError:
if "zmq.eventloop.minitornado" in sys.modules:
@@ -84,23 +87,25 @@ class ZMQStream:
"""
- socket = None
- io_loop = None
- poller = None
- _send_queue = None
- _recv_callback = None
- _send_callback = None
- _close_callback = None
- _state = 0
- _flushed = False
- _recv_copy = False
- _fd = None
-
- def __init__(self, socket, io_loop=None):
+ socket: zmq.Socket
+ io_loop: "tornado.ioloop.IOLoop"
+ poller: zmq.Poller
+ _send_queue: Queue
+ _recv_callback: Optional[Callable]
+ _send_callback: Optional[Callable]
+ _close_callback = Optional[Callable]
+ _state: int = 0
+ _flushed: bool = False
+ _recv_copy: bool = False
+ _fd: int
+
+ def __init__(
+ self, socket: "zmq.Socket", io_loop: Optional["tornado.ioloop.IOLoop"] = None
+ ):
self.socket = socket
self.io_loop = io_loop or IOLoop.current()
self.poller = zmq.Poller()
- self._fd = self.socket.FD
+ self._fd = cast(int, self.socket.FD)
self._send_queue = Queue()
self._recv_callback = None
@@ -135,11 +140,52 @@ def stop_on_err(self):
"""DEPRECATED, does nothing"""
gen_log.warn("on_err does nothing, and will be removed")
- def on_err(self, callback):
+ def on_err(self, callback: Callable):
"""DEPRECATED, does nothing"""
gen_log.warn("on_err does nothing, and will be removed")
- def on_recv(self, callback, copy=True):
+ @overload
+ def on_recv(
+ self,
+ callback: Callable[[List[bytes]], Any],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv(
+ self,
+ callback: Callable[[List[bytes]], Any],
+ copy: Literal[True],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv(
+ self,
+ callback: Callable[[List[zmq.Frame]], Any],
+ copy: Literal[False],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv(
+ self,
+ callback: Union[
+ Callable[[List[zmq.Frame]], Any],
+ Callable[[List[bytes]], Any],
+ ],
+ copy: bool = ...,
+ ):
+ ...
+
+ def on_recv(
+ self,
+ callback: Union[
+ Callable[[List[zmq.Frame]], Any],
+ Callable[[List[bytes]], Any],
+ ],
+ copy: bool = True,
+ ) -> None:
"""Register a callback for when a message is ready to recv.
There can be only one callback registered at a time, so each
@@ -174,7 +220,48 @@ def on_recv(self, callback, copy=True):
else:
self._add_io_state(zmq.POLLIN)
- def on_recv_stream(self, callback, copy=True):
+ @overload
+ def on_recv_stream(
+ self,
+ callback: Callable[["ZMQStream", List[bytes]], Any],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv_stream(
+ self,
+ callback: Callable[["ZMQStream", List[bytes]], Any],
+ copy: Literal[True],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv_stream(
+ self,
+ callback: Callable[["ZMQStream", List[zmq.Frame]], Any],
+ copy: Literal[False],
+ ) -> None:
+ ...
+
+ @overload
+ def on_recv_stream(
+ self,
+ callback: Union[
+ Callable[["ZMQStream", List[zmq.Frame]], Any],
+ Callable[["ZMQStream", List[bytes]], Any],
+ ],
+ copy: bool = ...,
+ ):
+ ...
+
+ def on_recv_stream(
+ self,
+ callback: Union[
+ Callable[["ZMQStream", List[zmq.Frame]], Any],
+ Callable[["ZMQStream", List[bytes]], Any],
+ ],
+ copy: bool = True,
+ ):
"""Same as on_recv, but callback will get this stream as first argument
callback must take exactly two arguments, as it will be called as::
@@ -186,9 +273,15 @@ def on_recv_stream(self, callback, copy=True):
if callback is None:
self.stop_on_recv()
else:
- self.on_recv(lambda msg: callback(self, msg), copy=copy)
- def on_send(self, callback):
+ def stream_callback(msg):
+ return callback(self, msg)
+
+ self.on_recv(stream_callback, copy=copy)
+
+ def on_send(
+ self, callback: Callable[[Sequence[Any], Optional[zmq.MessageTracker]], Any]
+ ):
"""Register a callback to be called on each send
There will be two arguments::
@@ -228,7 +321,12 @@ def on_send(self, callback):
assert callback is None or callable(callback)
self._send_callback = stack_context_wrap(callback)
- def on_send_stream(self, callback):
+ def on_send_stream(
+ self,
+ callback: Callable[
+ ["ZMQStream", Sequence[Any], Optional[zmq.MessageTracker]], Any
+ ],
+ ):
"""Same as on_send, but callback will get this stream as first argument
Callback will be passed three arguments::
@@ -251,8 +349,14 @@ def send(self, msg, flags=0, copy=True, track=False, callback=None, **kwargs):
)
def send_multipart(
- self, msg, flags=0, copy=True, track=False, callback=None, **kwargs
- ):
+ self,
+ msg: Sequence[Any],
+ flags: int = 0,
+ copy: bool = True,
+ track: bool = False,
+ callback: Callable = None,
+ **kwargs: Any
+ ) -> None:
"""Send a multipart message, optionally also register a new callback for sends.
See zmq.socket.send_multipart for details.
"""
@@ -266,7 +370,14 @@ def send_multipart(
self.on_send(lambda *args: None)
self._add_io_state(zmq.POLLOUT)
- def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs):
+ def send_string(
+ self,
+ u: str,
+ flags: int = 0,
+ encoding: str = 'utf-8',
+ callback: Optional[Callable] = None,
+ **kwargs: Any
+ ):
"""Send a unicode message with an encoding.
See zmq.socket.send_unicode for details.
"""
@@ -276,14 +387,27 @@ def send_string(self, u, flags=0, encoding='utf-8', callback=None, **kwargs):
send_unicode = send_string
- def send_json(self, obj, flags=0, callback=None, **kwargs):
+ def send_json(
+ self,
+ obj: Any,
+ flags: int = 0,
+ callback: Optional[Callable] = None,
+ **kwargs: Any
+ ):
"""Send json-serialized version of an object.
See zmq.socket.send_json for details.
"""
msg = jsonapi.dumps(obj)
return self.send(msg, flags=flags, callback=callback, **kwargs)
- def send_pyobj(self, obj, flags=0, protocol=-1, callback=None, **kwargs):
+ def send_pyobj(
+ self,
+ obj: Any,
+ flags: int = 0,
+ protocol: int = -1,
+ callback: Optional[Callable] = None,
+ **kwargs: Any
+ ):
"""Send a Python object as a message using pickle to serialize.
See zmq.socket.send_json for details.
@@ -295,7 +419,7 @@ def _finish_flush(self):
"""callback for unsetting _flushed flag."""
self._flushed = False
- def flush(self, flag=zmq.POLLIN | zmq.POLLOUT, limit=None):
+ def flush(self, flag: int = zmq.POLLIN | zmq.POLLOUT, limit: Optional[int] = None):
"""Flush pending messages.
This method safely handles all pending incoming and/or outgoing messages,
@@ -379,11 +503,11 @@ def update_flag():
self._rebuild_io_state()
return count
- def set_close_callback(self, callback):
+ def set_close_callback(self, callback: Optional[Callable]):
"""Call the given callback when the stream is closed."""
self._close_callback = stack_context_wrap(callback)
- def close(self, linger=None):
+ def close(self, linger: Optional[int] = None) -> None:
"""Close this stream."""
if self.socket is not None:
if self.socket.closed:
@@ -401,19 +525,19 @@ def close(self, linger=None):
else:
self.io_loop.remove_handler(self.socket)
self.socket.close(linger)
- self.socket = None
+ self.socket = None # type: ignore
if self._close_callback:
self._run_callback(self._close_callback)
- def receiving(self):
+ def receiving(self) -> bool:
"""Returns True if we are currently receiving from the stream."""
return self._recv_callback is not None
- def sending(self):
+ def sending(self) -> bool:
"""Returns True if we are currently sending to the stream."""
return not self._send_queue.empty()
- def closed(self):
+ def closed(self) -> bool:
if self.socket is None:
return True
if self.socket.closed:
@@ -421,6 +545,7 @@ def closed(self):
# trigger our cleanup
self.close()
return True
+ return False
def _run_callback(self, callback, *args, **kwargs):
"""Wrap running callbacks in try/except to allow us to
diff --git a/zmq/green/core.py b/zmq/green/core.py
index 4d865e83f..d8358c4a0 100644
--- a/zmq/green/core.py
+++ b/zmq/green/core.py
@@ -307,7 +307,7 @@ def set(self, opt, val):
return super().set(opt, val)
-class _Context(_original_Context):
+class _Context(_original_Context[_Socket]):
"""Replacement for :class:`zmq.Context`
Ensures that the greened Socket above is used in calls to `socket`.
diff --git a/zmq/log/handlers.py b/zmq/log/handlers.py
index c42ffe97d..5b2cba888 100644
--- a/zmq/log/handlers.py
+++ b/zmq/log/handlers.py
@@ -20,15 +20,13 @@
http://github.com/jtriley/StarCluster/blob/master/starcluster/logger.py
"""
+import logging
+
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
-
-
-import logging
-from logging import DEBUG, ERROR, FATAL, INFO, WARN
+from typing import Optional, Union
import zmq
-from zmq.utils.strtypes import bytes, cast_bytes, unicode
TOPIC_DELIM = "::" # delimiter for splitting topics on the receiving end.
@@ -56,11 +54,17 @@ class PUBHandler(logging.Handler):
message by: log.debug("subtopic.subsub::the real message")
"""
- socket = None
+ ctx: zmq.Context
+ socket: zmq.Socket
- def __init__(self, interface_or_socket, context=None, root_topic=''):
+ def __init__(
+ self,
+ interface_or_socket: Union[str, zmq.Socket],
+ context: Optional[zmq.Context] = None,
+ root_topic: str = '',
+ ) -> None:
logging.Handler.__init__(self)
- self._root_topic = root_topic
+ self.root_topic = root_topic
self.formatters = {
logging.DEBUG: logging.Formatter(
"%(levelname)s %(filename)s:%(lineno)d - %(message)s\n"
@@ -85,14 +89,14 @@ def __init__(self, interface_or_socket, context=None, root_topic=''):
self.socket.bind(interface_or_socket)
@property
- def root_topic(self):
+ def root_topic(self) -> str:
return self._root_topic
@root_topic.setter
- def root_topic(self, value):
+ def root_topic(self, value: str):
self.setRootTopic(value)
- def setRootTopic(self, root_topic):
+ def setRootTopic(self, root_topic: str):
"""Set the root topic for this handler.
This value is prepended to all messages published by this handler, and it
@@ -105,6 +109,8 @@ def setRootTopic(self, root_topic):
the binary representation of the log level string (INFO, WARN, etc.).
Note that ZMQ SUB sockets can have multiple subscriptions.
"""
+ if isinstance(root_topic, bytes):
+ root_topic = root_topic.decode("utf8")
self._root_topic = root_topic
def setFormatter(self, fmt, level=logging.NOTSET):
@@ -125,12 +131,13 @@ def format(self, record):
def emit(self, record):
"""Emit a log message on my socket."""
+
try:
topic, record.msg = record.msg.split(TOPIC_DELIM, 1)
- except Exception:
+ except ValueError:
topic = ""
try:
- bmsg = cast_bytes(self.format(record))
+ bmsg = self.format(record).encode("utf8")
except Exception:
self.handleError(record)
return
@@ -145,7 +152,7 @@ def emit(self, record):
if topic:
topic_list.append(topic)
- btopic = b'.'.join(cast_bytes(t) for t in topic_list)
+ btopic = '.'.join(topic_list).encode("utf8")
self.socket.send_multipart([btopic, bmsg])
diff --git a/zmq/ssh/tunnel.py b/zmq/ssh/tunnel.py
index bb708028c..876e1c035 100644
--- a/zmq/ssh/tunnel.py
+++ b/zmq/ssh/tunnel.py
@@ -17,7 +17,6 @@
import warnings
from getpass import getpass, getuser
from multiprocessing import Process
-from typing import Type
try:
with warnings.catch_warnings():
diff --git a/zmq/sugar/attrsettr.py b/zmq/sugar/attrsettr.py
index 7f5d11ac7..5f7f232e9 100644
--- a/zmq/sugar/attrsettr.py
+++ b/zmq/sugar/attrsettr.py
@@ -4,12 +4,16 @@
# Distributed under the terms of the Modified BSD License.
import errno
+from typing import Generic, TypeVar, Union
from .. import constants
+T = TypeVar("T")
+OptValT = Union[str, bytes, int]
-class AttributeSetter:
- def __setattr__(self, key, value):
+
+class AttributeSetter(Generic[T]):
+ def __setattr__(self, key: str, value: OptValT) -> None:
"""set zmq options by attribute"""
if key in self.__dict__:
@@ -31,11 +35,11 @@ def __setattr__(self, key, value):
else:
self._set_attr_opt(upper_key, opt, value)
- def _set_attr_opt(self, name, opt, value):
+ def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None:
"""override if setattr should do something other than call self.set"""
self.set(opt, value)
- def __getattr__(self, key):
+ def __getattr__(self, key: str) -> OptValT:
"""get zmq options by attribute"""
upper_key = key.upper()
try:
@@ -58,9 +62,15 @@ def __getattr__(self, key):
else:
raise
- def _get_attr_opt(self, name, opt):
+ def _get_attr_opt(self, name, opt) -> OptValT:
"""override if getattr should do something other than call self.get"""
return self.get(opt)
+ def get(self, opt: Union[T, int]) -> OptValT:
+ pass
+
+ def set(self, opt: Union[T, int], val: OptValT) -> None:
+ pass
+
__all__ = ['AttributeSetter']
diff --git a/zmq/sugar/context.py b/zmq/sugar/context.py
index 7356d946c..6821b0059 100644
--- a/zmq/sugar/context.py
+++ b/zmq/sugar/context.py
@@ -7,32 +7,32 @@
import os
import warnings
from threading import Lock
-from typing import Any, Dict, Optional, Type, TypeVar
+from typing import Any, Dict, Generic, List, Optional, Type, TypeVar
from weakref import WeakSet
from zmq.backend import Context as ContextBase
from zmq.constants import ContextOption, Errno, SocketOption
from zmq.error import ZMQError
-from .attrsettr import AttributeSetter
+from .attrsettr import AttributeSetter, OptValT
from .socket import Socket
# notice when exiting, to avoid triggering term on exit
_exiting = False
-def _notice_atexit():
+def _notice_atexit() -> None:
global _exiting
_exiting = True
atexit.register(_notice_atexit)
-
T = TypeVar('T', bound='Context')
+ST = TypeVar('ST', bound='Socket', covariant=True)
-class Context(ContextBase, AttributeSetter):
+class Context(ContextBase, AttributeSetter, Generic[ST]):
"""Create a zmq Context
A zmq Context creates sockets via its ``ctx.socket`` method.
@@ -44,8 +44,10 @@ class Context(ContextBase, AttributeSetter):
_instance_pid: Optional[int] = None
_shadow = False
_sockets: WeakSet
+ # mypy doesn't like a default value here
+ _socket_class: Type[ST] = Socket # type: ignore
- def __init__(self, io_threads: int = 1, **kwargs):
+ def __init__(self: "Context[Socket]", io_threads: int = 1, **kwargs: Any) -> None:
super().__init__(io_threads=io_threads, **kwargs)
if kwargs.get('shadow', False):
self._shadow = True
@@ -54,7 +56,7 @@ def __init__(self, io_threads: int = 1, **kwargs):
self.sockopts = {}
self._sockets = WeakSet()
- def __del__(self):
+ def __del__(self) -> None:
"""deleting a Context should terminate it, without trying non-threadsafe destroy"""
# Calling locals() here conceals issue #1167 on Windows CPython 3.5.4.
@@ -71,7 +73,7 @@ def __del__(self):
_repr_cls = "zmq.Context"
- def __repr__(self):
+ def __repr__(self) -> str:
cls = self.__class__
# look up _repr_cls on exact class, not inherited
_repr_cls = cls.__dict__.get("_repr_cls", None)
@@ -87,20 +89,20 @@ def __repr__(self):
sockets = ""
return f"<{_repr_cls}({sockets}) at {hex(id(self))}{closed}>"
- def __enter__(self):
+ def __enter__(self: T) -> T:
return self
- def __exit__(self, *args, **kwargs):
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
self.term()
- def __copy__(self, memo=None):
+ def __copy__(self: T, memo: Any = None) -> T:
"""Copying a Context creates a shadow copy"""
return self.__class__.shadow(self.underlying)
__deepcopy__ = __copy__
@classmethod
- def shadow(cls, address):
+ def shadow(cls: Type[T], address: int) -> T:
"""Shadow an existing libzmq context
address is the integer address of the libzmq context
@@ -131,7 +133,7 @@ def shadow_pyczmq(cls: Type[T], ctx: Any) -> T:
# static method copied from tornado IOLoop.instance
@classmethod
- def instance(cls: Type[T], io_threads=1) -> T:
+ def instance(cls: Type[T], io_threads: int = 1) -> T:
"""Returns a global Context instance.
Most single-threaded applications have a single, global Context.
@@ -193,7 +195,7 @@ def term(self) -> None:
# Hooks for ctxopt completion
# -------------------------------------------------------------------------
- def __dir__(self):
+ def __dir__(self) -> List[str]:
keys = dir(self.__class__)
keys.extend(ContextOption.__members__)
return keys
@@ -202,17 +204,17 @@ def __dir__(self):
# Creating Sockets
# -------------------------------------------------------------------------
- def _add_socket(self, socket: Any):
+ def _add_socket(self, socket: Any) -> None:
"""Add a weakref to a socket for Context.destroy / reference counting"""
self._sockets.add(socket)
- def _rm_socket(self, socket: Any):
+ def _rm_socket(self, socket: Any) -> None:
"""Remove a socket for Context.destroy / reference counting"""
# allow _sockets to be None in case of process teardown
if getattr(self, "_sockets", None) is not None:
self._sockets.discard(socket)
- def destroy(self, linger: Optional[float] = None):
+ def destroy(self, linger: Optional[float] = None) -> None:
"""Close all sockets associated with this context and then terminate
the context.
@@ -240,11 +242,7 @@ def destroy(self, linger: Optional[float] = None):
self.term()
- @property
- def _socket_class(self):
- return Socket
-
- def socket(self, socket_type: int, **kwargs):
+ def socket(self: T, socket_type: int, **kwargs: Any) -> ST:
"""Create a Socket associated with this Context.
Parameters
@@ -258,7 +256,7 @@ def socket(self, socket_type: int, **kwargs):
"""
if self.closed:
raise ZMQError(Errno.ENOTSUP)
- s = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame
+ s: ST = self._socket_class( # set PYTHONTRACEMALLOC=2 to get the calling frame
self, socket_type, **kwargs
)
for opt, value in self.sockopts.items():
@@ -272,21 +270,21 @@ def socket(self, socket_type: int, **kwargs):
self._add_socket(s)
return s
- def setsockopt(self, opt: int, value):
+ def setsockopt(self, opt: int, value: Any) -> None:
"""set default socket options for new sockets created by this Context
.. versionadded:: 13.0
"""
self.sockopts[opt] = value
- def getsockopt(self, opt: int):
+ def getsockopt(self, opt: int) -> OptValT:
"""get default socket options for new sockets created by this Context
.. versionadded:: 13.0
"""
return self.sockopts[opt]
- def _set_attr_opt(self, name: str, opt: int, value):
+ def _set_attr_opt(self, name: str, opt: int, value: OptValT) -> None:
"""set default sockopts as attributes"""
if name in ContextOption.__members__:
return self.set(opt, value)
@@ -295,7 +293,7 @@ def _set_attr_opt(self, name: str, opt: int, value):
else:
raise AttributeError(f"No such context or socket option: {name}")
- def _get_attr_opt(self, name: str, opt: int):
+ def _get_attr_opt(self, name: str, opt: int) -> OptValT:
"""get default sockopts as attributes"""
if name in ContextOption.__members__:
return self.get(opt)
@@ -305,7 +303,7 @@ def _get_attr_opt(self, name: str, opt: int):
else:
return self.sockopts[opt]
- def __delattr__(self, key: str):
+ def __delattr__(self, key: str) -> None:
"""delete default sockopts as attributes"""
key = key.upper()
try:
diff --git a/zmq/sugar/poll.py b/zmq/sugar/poll.py
index 37fd99d96..0c76dc084 100644
--- a/zmq/sugar/poll.py
+++ b/zmq/sugar/poll.py
@@ -3,9 +3,8 @@
# Copyright (C) PyZMQ Developers
# Distributed under the terms of the Modified BSD License.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
-import zmq
from zmq.backend import zmq_poll
from zmq.constants import POLLERR, POLLIN, POLLOUT
@@ -20,11 +19,11 @@ class Poller:
sockets: List[Tuple[Any, int]]
_map: Dict
- def __init__(self):
+ def __init__(self) -> None:
self.sockets = []
self._map = {}
- def __contains__(self, socket: Any):
+ def __contains__(self, socket: Any) -> bool:
return socket in self._map
def register(self, socket: Any, flags: int = POLLIN | POLLOUT):
@@ -76,7 +75,7 @@ def unregister(self, socket: Any):
for socket, flags in self.sockets[idx:]:
self._map[socket] -= 1
- def poll(self, timeout: Optional[int] = None):
+ def poll(self, timeout: Optional[int] = None) -> List[Tuple[Any, int]]:
"""Poll the registered 0MQ or native fds for I/O.
If there are currently events ready to be processed, this function will return immediately.
diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py
index 92363bf13..0dd2d68a7 100644
--- a/zmq/sugar/socket.py
+++ b/zmq/sugar/socket.py
@@ -9,8 +9,22 @@
import random
import sys
import warnings
+from typing import (
+ Any,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Sequence,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
import zmq
+from zmq._typing import Literal
from zmq.backend import Socket as SocketBase
from zmq.error import ZMQBindError, ZMQError
from zmq.utils import jsonapi
@@ -25,20 +39,26 @@
except AttributeError:
DEFAULT_PROTOCOL = pickle.HIGHEST_PROTOCOL
+T = TypeVar("T", bound="Socket")
-class _SocketContext:
+
+class _SocketContext(Generic[T]):
"""Context Manager for socket bind/unbind"""
+ socket: T
+ kind: str
+ addr: str
+
def __repr__(self):
return f""
- def __init__(self, socket, kind, addr):
+ def __init__(self: "_SocketContext[T]", socket: T, kind: str, addr: str):
assert kind in {"bind", "connect"}
self.socket = socket
self.kind = kind
self.addr = addr
- def __enter__(self):
+ def __enter__(self: "_SocketContext[T]") -> T:
return self.socket
def __exit__(self, *args):
@@ -109,7 +129,7 @@ def __repr__(self):
return f"<{_repr_cls}(zmq.{self._type_name}) at {hex(id(self))}{closed}>"
# socket as context manager:
- def __enter__(self):
+ def __enter__(self: T) -> T:
"""Sockets are context managers
.. versionadded:: 14.4
@@ -123,14 +143,14 @@ def __exit__(self, *args, **kwargs):
# Socket creation
# -------------------------------------------------------------------------
- def __copy__(self, memo=None):
+ def __copy__(self: T, memo=None) -> T:
"""Copying a Socket creates a shadow copy"""
return self.__class__.shadow(self.underlying)
__deepcopy__ = __copy__
@classmethod
- def shadow(cls, address):
+ def shadow(cls: Type[T], address: int) -> T:
"""Shadow an existing libzmq socket
address is the integer address of the libzmq socket
@@ -143,7 +163,7 @@ def shadow(cls, address):
address = cast_int_addr(address)
return cls(shadow=address)
- def close(self, linger=None):
+ def close(self, linger=None) -> None:
"""
Close the socket.
@@ -167,21 +187,21 @@ def close(self, linger=None):
# Connect/Bind context managers
# -------------------------------------------------------------------------
- def _connect_cm(self, addr):
+ def _connect_cm(self: T, addr: str) -> _SocketContext[T]:
"""Context manager to disconnect on exit
.. versionadded:: 20.0
"""
return _SocketContext(self, 'connect', addr)
- def _bind_cm(self, addr):
+ def _bind_cm(self: T, addr: str) -> _SocketContext[T]:
"""Context manager to unbind on exit
.. versionadded:: 20.0
"""
return _SocketContext(self, 'bind', addr)
- def bind(self, addr):
+ def bind(self: T, addr: str) -> _SocketContext[T]:
"""s.bind(addr)
Bind the socket to an address.
@@ -207,7 +227,7 @@ def bind(self, addr):
super().bind(addr)
return self._bind_cm(addr)
- def connect(self, addr):
+ def connect(self: T, addr: str) -> _SocketContext[T]:
"""s.connect(addr)
Connect to a remote 0MQ socket.
@@ -238,7 +258,7 @@ def socket_type(self) -> int:
warnings.warn(
"Socket.socket_type is deprecated, use Socket.type", DeprecationWarning
)
- return self.type
+ return cast(int, self.type)
# -------------------------------------------------------------------------
# Hooks for sockopt completion
@@ -283,7 +303,7 @@ def fileno(self):
"""
return self.FD
- def subscribe(self, topic):
+ def subscribe(self, topic: Union[str, bytes]) -> None:
"""Subscribe to a topic
Only for SUB sockets.
@@ -294,7 +314,7 @@ def subscribe(self, topic):
topic = topic.encode('utf8')
self.set(zmq.SUBSCRIBE, topic)
- def unsubscribe(self, topic):
+ def unsubscribe(self, topic: Union[str, bytes]) -> None:
"""Unsubscribe from a topic
Only for SUB sockets.
@@ -305,7 +325,7 @@ def unsubscribe(self, topic):
topic = topic.encode('utf8')
self.set(zmq.UNSUBSCRIBE, topic)
- def set_string(self, option, optval, encoding='utf-8'):
+ def set_string(self, option: int, optval: str, encoding='utf-8') -> None:
"""Set socket options with a unicode object.
This is simply a wrapper for setsockopt to protect from encoding ambiguity.
@@ -328,7 +348,7 @@ def set_string(self, option, optval, encoding='utf-8'):
setsockopt_unicode = setsockopt_string = set_string
- def get_string(self, option, encoding='utf-8'):
+ def get_string(self, option: int, encoding='utf-8') -> str:
"""Get the value of a socket option.
See the 0MQ documentation for details on specific options.
@@ -346,11 +366,17 @@ def get_string(self, option, encoding='utf-8'):
if SocketOption(option)._opt_type != _OptType.bytes:
raise TypeError(f"option {option} will not return a string to be decoded")
- return self.getsockopt(option).decode(encoding)
+ return cast(bytes, self.get(option)).decode(encoding)
getsockopt_unicode = getsockopt_string = get_string
- def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=100):
+ def bind_to_random_port(
+ self: T,
+ addr: str,
+ min_port: int = 49152,
+ max_port: int = 65536,
+ max_tries: int = 100,
+ ) -> int:
"""Bind this socket to a random port in a range.
If the port range is unspecified, the system will choose the port.
@@ -384,7 +410,7 @@ def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=10
# if LAST_ENDPOINT is supported, and min_port / max_port weren't specified,
# we can bind to port 0 and let the OS do the work
self.bind("%s:*" % addr)
- url = self.last_endpoint.decode('ascii', 'replace')
+ url = cast(bytes, self.last_endpoint).decode('ascii', 'replace')
_, port_s = url.rsplit(':', 1)
return int(port_s)
@@ -404,7 +430,7 @@ def bind_to_random_port(self, addr, min_port=49152, max_port=65536, max_tries=10
return port
raise ZMQBindError("Could not bind socket to random port.")
- def get_hwm(self):
+ def get_hwm(self) -> int:
"""Get the High Water Mark.
On libzmq ≥ 3, this gets SNDHWM if available, otherwise RCVHWM
@@ -413,15 +439,15 @@ def get_hwm(self):
if major >= 3:
# return sndhwm, fallback on rcvhwm
try:
- return self.getsockopt(zmq.SNDHWM)
+ return cast(int, self.get(zmq.SNDHWM))
except zmq.ZMQError:
pass
- return self.getsockopt(zmq.RCVHWM)
+ return cast(int, self.get(zmq.RCVHWM))
else:
- return self.getsockopt(zmq.HWM)
+ return cast(int, self.get(zmq.HWM))
- def set_hwm(self, value):
+ def set_hwm(self, value: int) -> None:
"""Set the High Water Mark.
On libzmq ≥ 3, this sets both SNDHWM and RCVHWM
@@ -447,7 +473,7 @@ def set_hwm(self, value):
if raised:
raise raised
else:
- return self.setsockopt(zmq.HWM, value)
+ self.set(zmq.HWM, value)
hwm = property(
get_hwm,
@@ -464,7 +490,65 @@ def set_hwm(self, value):
# Sending and receiving messages
# -------------------------------------------------------------------------
- def send(self, data, flags=0, copy=True, track=False, routing_id=None, group=None):
+ @overload
+ def send(
+ self,
+ data: Any,
+ flags: int = ...,
+ copy: bool = ...,
+ *,
+ track: Literal[True],
+ routing_id: Optional[int] = ...,
+ group: Optional[str] = ...,
+ ) -> "zmq.MessageTracker":
+ ...
+
+ @overload
+ def send(
+ self,
+ data: Any,
+ flags: int = ...,
+ copy: bool = ...,
+ *,
+ track: Literal[False],
+ routing_id: Optional[int] = ...,
+ group: Optional[str] = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def send(
+ self,
+ data: Any,
+ flags: int = ...,
+ *,
+ copy: bool = ...,
+ routing_id: Optional[int] = ...,
+ group: Optional[str] = ...,
+ ) -> None:
+ ...
+
+ @overload
+ def send(
+ self,
+ data: Any,
+ flags: int = ...,
+ copy: bool = ...,
+ track: bool = ...,
+ routing_id: Optional[int] = ...,
+ group: Optional[str] = ...,
+ ) -> Optional["zmq.MessageTracker"]:
+ ...
+
+ def send(
+ self,
+ data: Any,
+ flags: int = 0,
+ copy: bool = True,
+ track: bool = False,
+ routing_id: Optional[int] = None,
+ group: Optional[str] = None,
+ ) -> Optional["zmq.MessageTracker"]:
"""Send a single zmq message frame on this socket.
This queues the message to be sent by the IO thread at a later time.
@@ -533,7 +617,14 @@ def send(self, data, flags=0, copy=True, track=False, routing_id=None, group=Non
data.group = group
return super().send(data, flags=flags, copy=copy, track=track)
- def send_multipart(self, msg_parts, flags=0, copy=True, track=False, **kwargs):
+ def send_multipart(
+ self,
+ msg_parts: Sequence,
+ flags: int = 0,
+ copy: bool = True,
+ track: bool = False,
+ **kwargs,
+ ):
"""Send a sequence of buffers as a multipart message.
The zmq.SNDMORE flag is added to all msg parts before the last.
@@ -583,7 +674,31 @@ def send_multipart(self, msg_parts, flags=0, copy=True, track=False, **kwargs):
# Send the last part without the extra SNDMORE flag.
return self.send(msg_parts[-1], flags, copy=copy, track=track)
- def recv_multipart(self, flags=0, copy=True, track=False):
+ @overload
+ def recv_multipart(
+ self, flags: int = ..., *, copy: Literal[True], track: bool = ...
+ ) -> List[bytes]:
+ ...
+
+ @overload
+ def recv_multipart(
+ self, flags: int = ..., *, copy: Literal[False], track: bool = ...
+ ) -> List[zmq.Frame]:
+ ...
+
+ @overload
+ def recv_multipart(self, flags: int = ..., *, track: bool = ...) -> List[bytes]:
+ ...
+
+ @overload
+ def recv_multipart(
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Union[List[zmq.Frame], List[bytes]]:
+ ...
+
+ def recv_multipart(
+ self, flags: int = 0, copy: bool = True, track: bool = False
+ ) -> Union[List[zmq.Frame], List[bytes]]:
"""Receive a multipart message as a list of bytes or Frame objects
Parameters
@@ -614,8 +729,9 @@ def recv_multipart(self, flags=0, copy=True, track=False):
while self.getsockopt(zmq.RCVMORE):
part = self.recv(flags, copy=copy, track=track)
parts.append(part)
-
- return parts
+ # cast List[Union] to Union[List]
+ # how do we get mypy to recognize that return type is invariant on `copy`?
+ return cast(Union[List[zmq.Frame], List[bytes]], parts)
def _deserialize(self, recvd, load):
"""Deserialize a received message
@@ -685,7 +801,14 @@ def recv_serialized(self, deserialize, flags=0, copy=True):
frames = self.recv_multipart(flags=flags, copy=copy)
return self._deserialize(frames, deserialize)
- def send_string(self, u, flags=0, copy=True, encoding='utf-8', **kwargs):
+ def send_string(
+ self,
+ u: str,
+ flags: int = 0,
+ copy: bool = True,
+ encoding: str = 'utf-8',
+ **kwargs,
+ ) -> Optional["zmq.Frame"]:
"""Send a Python unicode string as a message with an encoding.
0MQ communicates with raw bytes, so you must encode/decode
@@ -706,7 +829,7 @@ def send_string(self, u, flags=0, copy=True, encoding='utf-8', **kwargs):
send_unicode = send_string
- def recv_string(self, flags=0, encoding='utf-8'):
+ def recv_string(self, flags: int = 0, encoding: str = 'utf-8') -> str:
"""Receive a unicode string, as sent by send_string.
Parameters
@@ -731,7 +854,9 @@ def recv_string(self, flags=0, encoding='utf-8'):
recv_unicode = recv_string
- def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL, **kwargs):
+ def send_pyobj(
+ self, obj: Any, flags: int = 0, protocol: int = DEFAULT_PROTOCOL, **kwargs
+ ) -> Optional[zmq.Frame]:
"""Send a Python object as a message using pickle to serialize.
Parameters
@@ -747,7 +872,7 @@ def send_pyobj(self, obj, flags=0, protocol=DEFAULT_PROTOCOL, **kwargs):
msg = pickle.dumps(obj, protocol)
return self.send(msg, flags=flags, **kwargs)
- def recv_pyobj(self, flags=0):
+ def recv_pyobj(self, flags: int = 0) -> Any:
"""Receive a Python object as a message using pickle to serialize.
Parameters
@@ -768,7 +893,7 @@ def recv_pyobj(self, flags=0):
msg = self.recv(flags)
return self._deserialize(msg, pickle.loads)
- def send_json(self, obj, flags=0, **kwargs):
+ def send_json(self, obj: Any, flags: int = 0, **kwargs) -> None:
"""Send a Python object as a message using json to serialize.
Keyword arguments are passed on to json.dumps
@@ -787,7 +912,7 @@ def send_json(self, obj, flags=0, **kwargs):
msg = jsonapi.dumps(obj, **kwargs)
return self.send(msg, flags=flags, **send_kwargs)
- def recv_json(self, flags=0, **kwargs):
+ def recv_json(self, flags: int = 0, **kwargs) -> Union[List, str, int, float, Dict]:
"""Receive a Python object as a message using json to serialize.
Keyword arguments are passed on to json.loads
@@ -812,7 +937,7 @@ def recv_json(self, flags=0, **kwargs):
_poller_class = Poller
- def poll(self, timeout=None, flags=zmq.POLLIN):
+ def poll(self, timeout=None, flags=zmq.POLLIN) -> int:
"""Poll the socket for events.
See :class:`Poller` to wait for multiple sockets at once.
@@ -840,7 +965,9 @@ def poll(self, timeout=None, flags=zmq.POLLIN):
# return 0 if no events, otherwise return event bitfield
return evts.get(self, 0)
- def get_monitor_socket(self, events=None, addr=None):
+ def get_monitor_socket(
+ self: T, events: Optional[int] = None, addr: Optional[str] = None
+ ) -> T:
"""Return a connected PAIR socket ready to receive the event notifications.
.. versionadded:: libzmq-4.0
@@ -873,7 +1000,7 @@ def get_monitor_socket(self, events=None, addr=None):
if addr is None:
# create endpoint name from internal fd
- addr = "inproc://monitor.s-%d" % self.FD
+ addr = f"inproc://monitor.s-{self.FD}"
if events is None:
# use all events
events = zmq.EVENT_ALL
@@ -884,7 +1011,7 @@ def get_monitor_socket(self, events=None, addr=None):
self._monitor_socket.connect(addr)
return self._monitor_socket
- def disable_monitor(self):
+ def disable_monitor(self) -> None:
"""Shutdown the PAIR socket (created using get_monitor_socket)
that is serving socket events.
diff --git a/zmq/sugar/version.py b/zmq/sugar/version.py
index 050c94134..87192c360 100644
--- a/zmq/sugar/version.py
+++ b/zmq/sugar/version.py
@@ -11,7 +11,7 @@
VERSION_MINOR = 3
VERSION_PATCH = 0
VERSION_EXTRA = ""
-__version__ = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)
+__version__: str = '%i.%i.%i' % (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH)
version_info: Union[Tuple[int, int, int], Tuple[int, int, int, float]] = (
VERSION_MAJOR,
@@ -28,7 +28,7 @@
float('inf'),
)
-__revision__ = ''
+__revision__: str = ''
def pyzmq_version() -> str:
diff --git a/zmq/tests/test_mypy.py b/zmq/tests/test_mypy.py
index 85f59b632..026463404 100644
--- a/zmq/tests/test_mypy.py
+++ b/zmq/tests/test_mypy.py
@@ -34,12 +34,14 @@ def resolve_repo_dir(path):
mypy_dir = resolve_repo_dir("mypy_tests")
-def run_mypy(path):
+def run_mypy(*mypy_args):
"""Run mypy for a path
Captures output and reports it on errors
"""
- p = Popen([sys.executable, "-m", "mypy", path], stdout=PIPE, stderr=STDOUT)
+ p = Popen(
+ [sys.executable, "-m", "mypy"] + list(mypy_args), stdout=PIPE, stderr=STDOUT
+ )
o, _ = p.communicate()
out = o.decode("utf8", "replace")
print(out)
@@ -60,7 +62,7 @@ def run_mypy(path):
@pytest.mark.parametrize("example", examples)
def test_mypy_example(example):
example_dir = os.path.join(examples_dir, example)
- run_mypy(example_dir)
+ run_mypy("--disallow-untyped-calls", example_dir)
if os.path.exists(mypy_dir):
@@ -69,4 +71,4 @@ def test_mypy_example(example):
@pytest.mark.parametrize("filename", mypy_tests)
def test_mypy(filename):
- run_mypy(os.path.join(mypy_dir, filename))
+ run_mypy("--disallow-untyped-calls", os.path.join(mypy_dir, filename))
diff --git a/zmq/tests/test_socket.py b/zmq/tests/test_socket.py
index 5544477a5..c9d795f9a 100644
--- a/zmq/tests/test_socket.py
+++ b/zmq/tests/test_socket.py
@@ -211,6 +211,7 @@ def test_int_sockopts(self):
continue
if opt.name.startswith(
(
+ 'HWM',
'ROUTER',
'XPUB',
'TCP',
diff --git a/zmq/utils/monitor.py b/zmq/utils/monitor.py
index fb12125b5..e901a6c1d 100644
--- a/zmq/utils/monitor.py
+++ b/zmq/utils/monitor.py
@@ -4,12 +4,20 @@
# Distributed under the terms of the Modified BSD License.
import struct
+from typing import Any, List, Tuple, cast
import zmq
+from zmq._typing import TypedDict
from zmq.error import _check_version
-def parse_monitor_message(msg):
+class _MonitorMessage(TypedDict):
+ event: int
+ value: int
+ endpoint: bytes
+
+
+def parse_monitor_message(msg: List[bytes]) -> _MonitorMessage:
"""decode zmq_monitor event messages.
Parameters
@@ -30,18 +38,18 @@ def parse_monitor_message(msg):
event : dict
event description as dict with the keys `event`, `value`, and `endpoint`.
"""
-
if len(msg) != 2 or len(msg[0]) != 6:
raise RuntimeError("Invalid event message format: %s" % msg)
- event = {
- 'event': struct.unpack("=hi", msg[0])[0],
- 'value': struct.unpack("=hi", msg[0])[1],
+ event_id, value = struct.unpack("=hi", msg[0])
+ event: _MonitorMessage = {
+ 'event': event_id,
+ 'value': value,
'endpoint': msg[1],
}
return event
-def recv_monitor_message(socket, flags=0):
+def recv_monitor_message(socket: zmq.Socket, flags: int = 0) -> _MonitorMessage:
"""Receive and decode the given raw message from the monitoring socket and return a dict.
Requires libzmq ≥ 4.0