Skip to content

Commit a83e775

Browse files
authored
Merge pull request #238 from moi-si/patch
Safely close StreamWriter and cancel stream task
2 parents 1133fd6 + c5ce421 commit a83e775

File tree

1 file changed

+38
-31
lines changed

1 file changed

+38
-31
lines changed

accesser/__init__.py

Lines changed: 38 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import ssl
2626
import asyncio
2727
import traceback
28-
from contextlib import closing
2928
from urllib import request
3029
from urllib.parse import urlsplit
3130
from packaging.version import Version
@@ -46,9 +45,9 @@ async def update_cert(server_name):
4645
if not is_tld(server_name):
4746
res = get_tld(server_name, as_object=True, fix_protocol=True)
4847
if res.subdomain:
49-
server_name = res.subdomain.split('.', 1)[-1] + '.' + res.domain + '.' + res.tld
48+
server_name = f"{res.subdomain.split('.', 1)[-1]}.{res.domain}.{res.tld}"
5049
else:
51-
server_name = res.domain + '.' + res.tld
50+
server_name = f'{res.domain}.{res.tld}'
5251
async with cert_lock:
5352
if not server_name in cert_store:
5453
cm.create_certificate(server_name)
@@ -79,7 +78,7 @@ async def http_redirect(writer: asyncio.StreamWriter, path: str):
7978
if path.startswith(key):
8079
path = setting.config['http_redirect'][key] + path[len(key):]
8180
break
82-
logger.debug('Redirect to '+path)
81+
logger.debug('Redirect to %s', path)
8382
writer.write(f'HTTP/1.1 301 Moved Permanently\r\nLocation: https://{path}\r\n\r\n'.encode('iso-8859-1'))
8483
await writer.drain()
8584
writer.close()
@@ -95,18 +94,19 @@ async def forward_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWri
9594

9695
async def handle(reader, writer):
9796
global context
98-
with closing(writer):
97+
try:
98+
remote_writer = None
9999
raw_request = await reader.readuntil(b'\r\n\r\n')
100100
requestline = raw_request.decode('iso-8859-1').splitlines()[0]
101101
i_addr, i_port, *_ = writer.get_extra_info('peername')
102-
logger.debug(f"{i_addr}:{i_port} say: {requestline}")
102+
logger.debug("%s:%d say: %s", i_addr, i_port, requestline)
103103
words = requestline.split()
104104
command, path = words[:2]
105105
match command:
106106
case 'CONNECT':
107107
host, port = path.split(':')
108108
remote_ip = await DNSquery(host)
109-
logger.debug(f'[{i_port:5}] DNS: {host} -> {remote_ip}')
109+
logger.debug('[%5d] DNS: %s -> %s', i_port, host, remote_ip)
110110
case 'GET':
111111
if path.startswith('/pac/'):
112112
return await send_pac(writer)
@@ -127,13 +127,13 @@ async def handle(reader, writer):
127127
writer._transport = await writer._loop.start_tls(writer.transport, writer._protocol, context, server_side=True)
128128
server_hostname_key = next(filter(lambda h:fnmatch.fnmatchcase(host, h), setting.config['alter_hostname']), None)
129129
server_hostname = '' if server_hostname_key is None else setting.config['alter_hostname'][server_hostname_key]
130-
logger.debug(f'[{i_port:5}] {server_hostname=}')
130+
logger.debug("[%5d] server_hostname: %s", i_port, server_hostname)
131131
remote_context = ssl.create_default_context()
132132
remote_context.check_hostname = False
133133
remote_reader, remote_writer = await asyncio.open_connection(remote_ip, port, ssl=remote_context, server_hostname=server_hostname)
134134
cert = remote_writer.get_extra_info('peercert')
135135
cert_message = f"subjectAltName: {cert.get('subjectAltName', ())}, subject: {cert.get('subject', ())}"
136-
logger.debug(f"[{i_port:5}] {cert_message}.")
136+
logger.debug("[%5d] %s", i_port, cert_message)
137137
cert_verify_key = next(filter(lambda h:fnmatch.fnmatchcase(host, h), setting.config.get('cert_verify', ())), None)
138138
if cert_verify_key is not None:
139139
cert_verify_list = setting.config['cert_verify'][cert_verify_key]
@@ -144,23 +144,29 @@ async def handle(reader, writer):
144144
else:
145145
cert_verify_list = [host]
146146
cert_policy = setting.config['check_hostname']
147-
if cert_policy is not False and not any(match_hostname(cert, h, cert_policy) for h in cert_verify_list):
148-
logger.warning(f"[{i_port:5}] {cert_verify_list} don't match either of {cert_message}.")
147+
if cert_policy is not False and not any(match_hostname(cert, h, cert_policy) for h in cert_verify_list):
148+
logger.warning("[%5d] %s don't match either of %s.", i_port, cert_verify_list, cert_message)
149149
return
150-
await asyncio.gather(
151-
forward_stream(reader, remote_writer),
152-
forward_stream(remote_reader, writer)
150+
tasks = (
151+
asyncio.create_task(forward_stream(reader, remote_writer)),
152+
asyncio.create_task(forward_stream(remote_reader, writer))
153153
)
154-
writer.close()
155-
remote_writer.close()
156-
await remote_writer.wait_closed()
157-
await writer.wait_closed()
154+
done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
155+
for task in pending:
156+
task.cancel()
157+
await asyncio.gather(*pending)
158+
finally:
159+
writer.close()
160+
if remote_writer:
161+
remote_writer.close()
162+
await remote_writer.wait_closed()
163+
await writer.wait_closed()
158164

159165
async def proxy():
160166
server = await asyncio.start_server(handle, setting.config['server']['address'], setting.config['server']['port'])
161167

162168
print(f"Serving on {', '.join(str(sock.getsockname()) for sock in server.sockets)}")
163-
sysproxy.set_pac('http://localhost:'+str(setting.config['server']['port'])+'/pac/?t='+str(random.randrange(2**16)))
169+
sysproxy.set_pac(f"http://localhost:{setting.config['server']['port']}/pac/?t={random.randrange(2**16)}")
164170

165171
try:
166172
async with server:
@@ -171,7 +177,7 @@ async def proxy():
171177
async def DNSquery(domain, hosts_only=False):
172178
global DNSresolver
173179
try:
174-
return next(v for k,v in setting.config['hosts'].items() if k==domain or (k.startswith('.') and domain.endswith(k)))
180+
return next(v for k, v in setting.config['hosts'].items() if k == domain or (k.startswith('.') and domain.endswith(k)))
175181
except StopIteration:
176182
if hosts_only:
177183
return
@@ -185,7 +191,7 @@ async def DNSquery(domain, hosts_only=False):
185191
return ret[0].to_text()
186192

187193
def update_checker():
188-
for pypi_url in ['https://pypi.org/pypi/accesser/json', 'https://mirrors.cloud.tencent.com/pypi/json/accesser']:
194+
for pypi_url in ('https://pypi.org/pypi/accesser/json', 'https://mirrors.cloud.tencent.com/pypi/json/accesser'):
189195
try:
190196
with request.urlopen(pypi_url) as f:
191197
v2 = Version(json.load(f)["info"]["version"])
@@ -197,7 +203,7 @@ def update_checker():
197203
v2 = Version(f.geturl().rsplit('/', maxsplit=1)[-1])
198204
v1 = Version(__version__)
199205
if v2 > v1:
200-
logger.warning("There is a new version, you can update with 'python3 -m pip install -U accesser' or download from GitHub")
206+
logger.warning("There is a new version. You can update with `python3 -m pip install -U accesser` or download from GitHub.")
201207

202208
async def main():
203209
global context, cert_store, cert_lock, DNSresolver
@@ -207,7 +213,7 @@ async def main():
207213
if setting.rules_update_case in ('old', 'missing'):
208214
logger.warning("Updated rules.toml because it is %s.", setting.rules_update_case)
209215
elif setting.rules_update_case == 'modified':
210-
logger.warning("You've already modified rules.toml, so it won't be updated automatically!")
216+
logger.warning("You've already modified rules.toml so it won't be updated automatically!")
211217
else:
212218
logger.debug("rules.toml status: %s", setting.rules_update_case)
213219

@@ -220,14 +226,15 @@ async def main():
220226
if (_url := urlsplit(nameserver)).netloc == '':
221227
_url = urlsplit('//' + nameserver)
222228
address = await DNSquery(_url.hostname, hosts_only=True)
223-
if _url.scheme == '':
224-
DNSresolver.nameservers.append(dns.nameserver.Do53Nameserver(_url.hostname if address is None else address, 53 if _url.port is None else _url.port))
225-
elif _url.scheme == 'https':
226-
DNSresolver.nameservers.append(dns.nameserver.DoHNameserver(nameserver, bootstrap_address=address))
227-
elif _url.scheme == 'tls':
228-
DNSresolver.nameservers.append(dns.nameserver.DoTNameserver(_url.hostname if address is None else address, 853 if _url.port is None else _url.port, hostname=_url.hostname))
229-
elif _url.scheme == 'quic':
230-
DNSresolver.nameservers.append(dns.nameserver.DoQNameserver(_url.hostname if address is None else address, 853 if _url.port is None else _url.port, server_hostname=_url.hostname))
229+
match _url.scheme:
230+
case '':
231+
DNSresolver.nameservers.append(dns.nameserver.Do53Nameserver(_url.hostname if address is None else address, 53 if _url.port is None else _url.port))
232+
case 'https':
233+
DNSresolver.nameservers.append(dns.nameserver.DoHNameserver(nameserver, bootstrap_address=address))
234+
case 'tls':
235+
DNSresolver.nameservers.append(dns.nameserver.DoTNameserver(_url.hostname if address is None else address, 853 if _url.port is None else _url.port, hostname=_url.hostname))
236+
case 'quic':
237+
DNSresolver.nameservers.append(dns.nameserver.DoQNameserver(_url.hostname if address is None else address, 853 if _url.port is None else _url.port, server_hostname=_url.hostname))
231238

232239
importca.import_ca()
233240

0 commit comments

Comments
 (0)