Skip to content

Commit 755dcf3

Browse files
authored
Test stop of server task. (#1256)
1 parent 8f6dbec commit 755dcf3

File tree

3 files changed

+371
-100
lines changed

3 files changed

+371
-100
lines changed

pymodbus/server/async_io.py

Lines changed: 86 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
# pylint: disable=missing-type-doc
33
import asyncio
44
import logging
5-
import platform
65
import ssl
76
import traceback
87
from binascii import b2a_hex
@@ -500,7 +499,6 @@ def __init__(
500499
address=None,
501500
handler=None,
502501
allow_reuse_address=False,
503-
allow_reuse_port=False,
504502
defer_start=False,
505503
backlog=20,
506504
**kwargs,
@@ -519,8 +517,6 @@ def __init__(
519517
receives connection create/teardown events
520518
:param allow_reuse_address: Whether the server will allow the
521519
reuse of an address.
522-
:param allow_reuse_port: Whether the server will allow the
523-
reuse of a port.
524520
:param backlog: is the maximum number of queued connections
525521
passed to listen(). Defaults to 20, increase if many
526522
connections are being made and broken to your Modbus slave
@@ -556,7 +552,6 @@ def __init__(
556552
self.server = None
557553
self.factory_parms = {
558554
"reuse_address": allow_reuse_address,
559-
"reuse_port": allow_reuse_port,
560555
"backlog": backlog,
561556
"start_serving": not defer_start,
562557
}
@@ -621,7 +616,6 @@ def __init__( # pylint: disable=too-many-arguments
621616
reqclicert=False,
622617
handler=None,
623618
allow_reuse_address=False,
624-
allow_reuse_port=False,
625619
defer_start=False,
626620
backlog=20,
627621
**kwargs,
@@ -646,8 +640,6 @@ def __init__( # pylint: disable=too-many-arguments
646640
receives connection create/teardown events
647641
:param allow_reuse_address: Whether the server will allow the
648642
reuse of an address.
649-
:param allow_reuse_port: Whether the server will allow the
650-
reuse of a port.
651643
:param backlog: is the maximum number of queued connections
652644
passed to listen(). Defaults to 20, increase if many
653645
connections are being made and broken to your Modbus slave
@@ -665,7 +657,6 @@ def __init__( # pylint: disable=too-many-arguments
665657
address=address,
666658
handler=handler,
667659
allow_reuse_address=allow_reuse_address,
668-
allow_reuse_port=allow_reuse_port,
669660
defer_start=defer_start,
670661
backlog=backlog,
671662
**kwargs,
@@ -689,7 +680,6 @@ def __init__(
689680
identity=None,
690681
address=None,
691682
handler=None,
692-
allow_reuse_port=False,
693683
defer_start=False, # pylint: disable=unused-argument
694684
backlog=20, # pylint: disable=unused-argument
695685
**kwargs,
@@ -731,12 +721,10 @@ def __init__(
731721
self.protocol = None
732722
self.endpoint = None
733723
self.on_connection_terminated = None
734-
self.stop_serving = self.loop.create_future()
735724
# asyncio future that will be done once server has started
736725
self.serving = self.loop.create_future()
737726
self.factory_parms = {
738727
"local_addr": self.address,
739-
"reuse_port": allow_reuse_port,
740728
"allow_broadcast": True,
741729
}
742730

@@ -749,9 +737,12 @@ async def serve_forever(self):
749737
**self.factory_parms,
750738
)
751739
except asyncio.exceptions.CancelledError:
752-
pass
740+
raise
741+
except Exception as exc:
742+
txt = f"Server unexpected exception {exc}"
743+
_logger.error(txt)
744+
raise RuntimeError(exc) from exc
753745
self.serving.set_result(True)
754-
await self.stop_serving
755746
else:
756747
raise RuntimeError(
757748
"Can't call serve_forever on an already running server object"
@@ -765,13 +756,10 @@ async def server_close(self):
765756
"""Close server."""
766757
if self.endpoint:
767758
self.endpoint.running = False
768-
if not self.stop_serving.done():
769-
self.stop_serving.set_result(True)
770759
if self.endpoint is not None and self.endpoint.handler_task is not None:
771760
self.endpoint.handler_task.cancel()
772761
if self.protocol is not None:
773762
self.protocol.close()
774-
# TBD await self.protocol.wait_closed()
775763
self.protocol = None
776764

777765

@@ -812,6 +800,7 @@ def __init__(
812800
:param response_manipulator: Callback method for
813801
manipulating the response
814802
"""
803+
self.loop = kwargs.get("loop") or asyncio.get_event_loop()
815804
self.bytesize = kwargs.get("bytesize", Defaults.Bytesize)
816805
self.parity = kwargs.get("parity", Defaults.Parity)
817806
self.baudrate = kwargs.get("baudrate", Defaults.Baudrate)
@@ -862,7 +851,7 @@ async def _connect(self):
862851
return
863852
try:
864853
self.transport, self.protocol = await create_serial_connection(
865-
asyncio.get_event_loop(),
854+
self.loop,
866855
lambda: self.handler(self),
867856
self.device,
868857
baudrate=self.baudrate,
@@ -887,44 +876,56 @@ def on_connection_lost(self):
887876
self.transport.close()
888877
self.transport = None
889878
self.protocol = None
890-
891-
self._check_reconnect()
879+
if self.server is None:
880+
self._check_reconnect()
892881

893882
async def shutdown(self):
894883
"""Terminate server."""
895884
if self.transport is not None:
896-
self.transport.close()
897-
self.transport = None
898-
self.protocol = None
885+
self.transport.abort()
886+
if self.server is not None:
887+
self.server.close()
888+
await asyncio.wait_for(self.server.wait_closed(), 10)
889+
self.server = None
890+
self.transport = None
891+
self.protocol = None
899892

900893
def _check_reconnect(self):
901894
"""Check reconnect."""
902895
txt = f"checking autoreconnect {self.auto_reconnect} {self.reconnecting_task}"
903896
_logger.debug(txt)
904897
if self.auto_reconnect and (self.reconnecting_task is None):
905898
_logger.debug("Scheduling serial connection reconnect")
906-
loop = asyncio.get_event_loop()
907-
self.reconnecting_task = loop.create_task(self._delayed_connect())
899+
self.reconnecting_task = self.loop.create_task(self._delayed_connect())
908900

909901
async def serve_forever(self):
910902
"""Start endless loop."""
903+
if self.server:
904+
raise RuntimeError(
905+
"Can't call serve_forever on an already running server object"
906+
)
911907
if self.device.startswith("socket:"):
912908
# Socket server means listen so start a socket server
913-
parts = self.device[7:].split(":")
914-
host_port = ("", int(parts[1]))
915-
self.server = await asyncio.get_event_loop().create_server(
909+
parts = self.device[9:].split(":")
910+
host_addr = (parts[0], int(parts[1]))
911+
self.server = await self.loop.create_server(
916912
lambda: self.handler(self),
917-
*host_port,
913+
*host_addr,
918914
reuse_address=True,
919-
reuse_port=True,
920915
start_serving=True,
921916
backlog=20,
922917
)
923-
await self.server.serve_forever()
918+
try:
919+
await self.server.serve_forever()
920+
except asyncio.exceptions.CancelledError:
921+
raise
922+
except Exception as exc: # pylint: disable=broad-except
923+
txt = f"Server unexpected exception {exc}"
924+
_logger.error(txt)
924925
return
925926

926-
while True:
927-
await asyncio.sleep(360)
927+
while self.server or self.transport or self.protocol:
928+
await asyncio.sleep(10)
928929

929930

930931
# --------------------------------------------------------------------------- #
@@ -951,64 +952,50 @@ def __init__(self, server, custom_functions, register):
951952
self.job_stop = asyncio.Event()
952953
self.job_is_stopped = asyncio.Event()
953954
self.task = None
955+
self.loop = asyncio.get_event_loop()
954956

955957
@classmethod
956958
def get_server(cls):
957959
"""Get server at index."""
958-
return cls._servers[-1]
960+
return cls._servers[-1] if cls._servers else None
959961

960962
def _remove(self):
961963
"""Remove server from active list."""
962964
server = self._servers[-1]
963965
self._servers.pop()
964966
del server
965967

968+
async def _run(self):
969+
"""Help starting/stopping server."""
970+
# self.task = asyncio.create_task(self.server.serve_forever())
971+
# await self.job_stop.wait()
972+
# await self.server.shutdown()
973+
# await asyncio.sleep(0.1)
974+
# self.task.cancel()
975+
# await asyncio.sleep(0.1)
976+
# try:
977+
# await asyncio.wait_for(self.task, 10)
978+
# except asyncio.CancelledError:
979+
# pass
980+
# self.job_is_stopped.set()
981+
966982
async def run(self):
967983
"""Help starting/stopping server."""
968984
try:
969-
self.task = asyncio.create_task(self.server.serve_forever())
970-
except Exception as exc: # pylint: disable=broad-except
971-
txt = f"Server caught exception: {exc}"
972-
_logger.error(txt)
973-
await self.job_stop.wait()
974-
await self.server.shutdown()
975-
await asyncio.sleep(0.1)
976-
self.task.cancel()
977-
await asyncio.sleep(0.1)
978-
try:
979-
await asyncio.wait_for(self.task, 10)
985+
# await self._run()
986+
await self.server.serve_forever()
980987
except asyncio.CancelledError:
981988
pass
982-
if platform.system().lower() == "windows":
983-
owntask = asyncio.current_task()
984-
for task in asyncio.all_tasks():
985-
if task != owntask:
986-
task.cancel()
987-
try:
988-
await asyncio.wait_for(task, 10)
989-
except asyncio.CancelledError:
990-
pass
991-
self.job_is_stopped.set()
992-
993-
def request_stop(self):
994-
"""Request server stop."""
995-
self.job_stop.set()
996989

997990
async def async_await_stop(self):
998991
"""Wait for server stop."""
999-
try:
1000-
await self.job_is_stopped.wait()
1001-
except asyncio.exceptions.CancelledError:
1002-
pass
1003-
self._remove()
1004-
1005-
def await_stop(self):
1006-
"""Wait for server stop."""
1007-
for i in range(30): # Loop for 3 seconds
1008-
sleep(0.1) # in steps of 100 milliseconds.
1009-
if self.job_is_stopped.is_set():
1010-
break
1011-
self._remove()
992+
await self.server.shutdown()
993+
# self.job_stop.set()
994+
# try:
995+
# await asyncio.wait_for(self.job_is_stopped.wait(), 60)
996+
# except asyncio.exceptions.CancelledError:
997+
# pass
998+
# self._remove()
1012999

10131000

10141001
async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default-value
@@ -1035,10 +1022,10 @@ async def StartAsyncTcpServer( # pylint: disable=invalid-name,dangerous-default
10351022
server = ModbusTcpServer(
10361023
context, kwargs.pop("framer", ModbusSocketFramer), identity, address, **kwargs
10371024
)
1038-
job = _serverList(server, custom_functions, not defer_start)
1039-
if defer_start:
1040-
return server
1041-
await job.run()
1025+
if not defer_start:
1026+
job = _serverList(server, custom_functions, not defer_start)
1027+
await job.run()
1028+
return server
10421029

10431030

10441031
async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default-value,too-many-arguments
@@ -1051,7 +1038,6 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
10511038
password=None,
10521039
reqclicert=False,
10531040
allow_reuse_address=False,
1054-
allow_reuse_port=False,
10551041
custom_functions=[],
10561042
defer_start=False,
10571043
**kwargs,
@@ -1068,7 +1054,6 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
10681054
:param reqclicert: Force the sever request client's certificate
10691055
:param allow_reuse_address: Whether the server will allow the reuse of an
10701056
address.
1071-
:param allow_reuse_port: Whether the server will allow the reuse of a port.
10721057
:param custom_functions: An optional list of custom function classes
10731058
supported by server instance.
10741059
:param defer_start: if set, the server object will be returned ready to start.
@@ -1088,13 +1073,12 @@ async def StartAsyncTlsServer( # pylint: disable=invalid-name,dangerous-default
10881073
password,
10891074
reqclicert,
10901075
allow_reuse_address=allow_reuse_address,
1091-
allow_reuse_port=allow_reuse_port,
10921076
**kwargs,
10931077
)
1094-
job = _serverList(server, custom_functions, not defer_start)
1095-
if defer_start:
1096-
return server
1097-
await job.run()
1078+
if not defer_start:
1079+
job = _serverList(server, custom_functions, not defer_start)
1080+
await job.run()
1081+
return server
10981082

10991083

11001084
async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default-value
@@ -1120,10 +1104,10 @@ async def StartAsyncUdpServer( # pylint: disable=invalid-name,dangerous-default
11201104
server = ModbusUdpServer(
11211105
context, kwargs.pop("framer", ModbusSocketFramer), identity, address, **kwargs
11221106
)
1123-
job = _serverList(server, custom_functions, not defer_start)
1124-
if defer_start:
1125-
return server
1126-
await job.run()
1107+
if not defer_start:
1108+
job = _serverList(server, custom_functions, not defer_start)
1109+
await job.run()
1110+
return server
11271111

11281112

11291113
async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-default-value
@@ -1147,11 +1131,10 @@ async def StartAsyncSerialServer( # pylint: disable=invalid-name,dangerous-defa
11471131
server = ModbusSerialServer(
11481132
context, kwargs.pop("framer", ModbusAsciiFramer), identity=identity, **kwargs
11491133
)
1150-
job = _serverList(server, custom_functions, not defer_start)
1151-
if defer_start:
1152-
return server
1153-
await server.start()
1154-
await job.run()
1134+
if not defer_start:
1135+
job = _serverList(server, custom_functions, not defer_start)
1136+
await job.run()
1137+
return server
11551138

11561139

11571140
def StartSerialServer(**kwargs): # pylint: disable=invalid-name
@@ -1176,13 +1159,18 @@ def StartUdpServer(**kwargs): # pylint: disable=invalid-name
11761159

11771160
async def ServerAsyncStop(): # pylint: disable=invalid-name
11781161
"""Terminate server."""
1179-
my_job = _serverList.get_server()
1180-
my_job.request_stop()
1181-
await my_job.async_await_stop()
1162+
if my_job := _serverList.get_server():
1163+
await my_job.async_await_stop()
1164+
await asyncio.sleep(0.1)
1165+
else:
1166+
raise RuntimeError("ServerAsyncStop called without server task active.")
11821167

11831168

11841169
def ServerStop(): # pylint: disable=invalid-name
11851170
"""Terminate server."""
1186-
my_job = _serverList.get_server()
1187-
my_job.request_stop()
1188-
my_job.await_stop()
1171+
if my_job := _serverList.get_server():
1172+
if my_job.loop.is_running():
1173+
asyncio.run_coroutine_threadsafe(my_job.async_await_stop(), my_job.loop)
1174+
sleep(0.1)
1175+
else:
1176+
raise RuntimeError("ServerStop called without server task active.")

0 commit comments

Comments
 (0)