2525import ssl
2626import asyncio
2727import traceback
28- from contextlib import closing
2928from urllib import request
3029from urllib .parse import urlsplit
3130from packaging .version import Version
@@ -46,9 +45,9 @@ async def update_cert(server_name):
4645 if not is_tld (server_name ):
4746 res = get_tld (server_name , as_object = True , fix_protocol = True )
4847 if res .subdomain :
49- server_name = res .subdomain .split ('.' , 1 )[- 1 ] + '.' + res .domain + '.' + res .tld
48+ server_name = f" { res .subdomain .split ('.' , 1 )[- 1 ]} . { res .domain } . { res .tld } "
5049 else :
51- server_name = res .domain + '.' + res .tld
50+ server_name = f' { res .domain } . { res .tld } '
5251 async with cert_lock :
5352 if not server_name in cert_store :
5453 cm .create_certificate (server_name )
@@ -79,7 +78,7 @@ async def http_redirect(writer: asyncio.StreamWriter, path: str):
7978 if path .startswith (key ):
8079 path = setting .config ['http_redirect' ][key ] + path [len (key ):]
8180 break
82- logger .debug ('Redirect to ' + path )
81+ logger .debug ('Redirect to %s' , path )
8382 writer .write (f'HTTP/1.1 301 Moved Permanently\r \n Location: https://{ path } \r \n \r \n ' .encode ('iso-8859-1' ))
8483 await writer .drain ()
8584 writer .close ()
@@ -95,18 +94,19 @@ async def forward_stream(reader: asyncio.StreamReader, writer: asyncio.StreamWri
9594
9695async def handle (reader , writer ):
9796 global context
98- with closing (writer ):
97+ try :
98+ remote_writer = None
9999 raw_request = await reader .readuntil (b'\r \n \r \n ' )
100100 requestline = raw_request .decode ('iso-8859-1' ).splitlines ()[0 ]
101101 i_addr , i_port , * _ = writer .get_extra_info ('peername' )
102- logger .debug (f" { i_addr } : { i_port } say: { requestline } " )
102+ logger .debug ("%s:%d say: %s" , i_addr , i_port , requestline )
103103 words = requestline .split ()
104104 command , path = words [:2 ]
105105 match command :
106106 case 'CONNECT' :
107107 host , port = path .split (':' )
108108 remote_ip = await DNSquery (host )
109- logger .debug (f'[ { i_port :5 } ] DNS: { host } -> { remote_ip } ' )
109+ logger .debug ('[%5d ] DNS: %s -> %s' , i_port , host , remote_ip )
110110 case 'GET' :
111111 if path .startswith ('/pac/' ):
112112 return await send_pac (writer )
@@ -127,13 +127,13 @@ async def handle(reader, writer):
127127 writer ._transport = await writer ._loop .start_tls (writer .transport , writer ._protocol , context , server_side = True )
128128 server_hostname_key = next (filter (lambda h :fnmatch .fnmatchcase (host , h ), setting .config ['alter_hostname' ]), None )
129129 server_hostname = '' if server_hostname_key is None else setting .config ['alter_hostname' ][server_hostname_key ]
130- logger .debug (f'[ { i_port :5 } ] { server_hostname = } ' )
130+ logger .debug ("[%5d] server_hostname: %s" , i_port , server_hostname )
131131 remote_context = ssl .create_default_context ()
132132 remote_context .check_hostname = False
133133 remote_reader , remote_writer = await asyncio .open_connection (remote_ip , port , ssl = remote_context , server_hostname = server_hostname )
134134 cert = remote_writer .get_extra_info ('peercert' )
135135 cert_message = f"subjectAltName: { cert .get ('subjectAltName' , ())} , subject: { cert .get ('subject' , ())} "
136- logger .debug (f"[ { i_port :5 } ] { cert_message } ." )
136+ logger .debug ("[%5d] %s" , i_port , cert_message )
137137 cert_verify_key = next (filter (lambda h :fnmatch .fnmatchcase (host , h ), setting .config .get ('cert_verify' , ())), None )
138138 if cert_verify_key is not None :
139139 cert_verify_list = setting .config ['cert_verify' ][cert_verify_key ]
@@ -144,23 +144,29 @@ async def handle(reader, writer):
144144 else :
145145 cert_verify_list = [host ]
146146 cert_policy = setting .config ['check_hostname' ]
147- if cert_policy is not False and not any (match_hostname (cert , h , cert_policy ) for h in cert_verify_list ):
148- logger .warning (f"[ { i_port :5 } ] { cert_verify_list } don't match either of { cert_message } ." )
147+ if cert_policy is not False and not any (match_hostname (cert , h , cert_policy ) for h in cert_verify_list ):
148+ logger .warning ("[%5d] %s don't match either of %s." , i_port , cert_verify_list , cert_message )
149149 return
150- await asyncio . gather (
151- forward_stream (reader , remote_writer ),
152- forward_stream (remote_reader , writer )
150+ tasks = (
151+ asyncio . create_task ( forward_stream (reader , remote_writer ) ),
152+ asyncio . create_task ( forward_stream (remote_reader , writer ) )
153153 )
154- writer .close ()
155- remote_writer .close ()
156- await remote_writer .wait_closed ()
157- await writer .wait_closed ()
154+ done , pending = await asyncio .wait (tasks , return_when = asyncio .FIRST_COMPLETED )
155+ for task in pending :
156+ task .cancel ()
157+ await asyncio .gather (* pending )
158+ finally :
159+ writer .close ()
160+ if remote_writer :
161+ remote_writer .close ()
162+ await remote_writer .wait_closed ()
163+ await writer .wait_closed ()
158164
159165async def proxy ():
160166 server = await asyncio .start_server (handle , setting .config ['server' ]['address' ], setting .config ['server' ]['port' ])
161167
162168 print (f"Serving on { ', ' .join (str (sock .getsockname ()) for sock in server .sockets )} " )
163- sysproxy .set_pac (' http://localhost:' + str ( setting .config ['server' ]['port' ]) + ' /pac/?t=' + str ( random .randrange (2 ** 16 )) )
169+ sysproxy .set_pac (f" http://localhost:{ setting .config ['server' ]['port' ]} /pac/?t={ random .randrange (2 ** 16 )} " )
164170
165171 try :
166172 async with server :
@@ -171,7 +177,7 @@ async def proxy():
171177async def DNSquery (domain , hosts_only = False ):
172178 global DNSresolver
173179 try :
174- return next (v for k ,v in setting .config ['hosts' ].items () if k == domain or (k .startswith ('.' ) and domain .endswith (k )))
180+ return next (v for k , v in setting .config ['hosts' ].items () if k == domain or (k .startswith ('.' ) and domain .endswith (k )))
175181 except StopIteration :
176182 if hosts_only :
177183 return
@@ -185,7 +191,7 @@ async def DNSquery(domain, hosts_only=False):
185191 return ret [0 ].to_text ()
186192
187193def update_checker ():
188- for pypi_url in [ 'https://pypi.org/pypi/accesser/json' , 'https://mirrors.cloud.tencent.com/pypi/json/accesser' ] :
194+ for pypi_url in ( 'https://pypi.org/pypi/accesser/json' , 'https://mirrors.cloud.tencent.com/pypi/json/accesser' ) :
189195 try :
190196 with request .urlopen (pypi_url ) as f :
191197 v2 = Version (json .load (f )["info" ]["version" ])
@@ -197,7 +203,7 @@ def update_checker():
197203 v2 = Version (f .geturl ().rsplit ('/' , maxsplit = 1 )[- 1 ])
198204 v1 = Version (__version__ )
199205 if v2 > v1 :
200- logger .warning ("There is a new version, you can update with ' python3 -m pip install -U accesser' or download from GitHub" )
206+ logger .warning ("There is a new version. You can update with ` python3 -m pip install -U accesser` or download from GitHub. " )
201207
202208async def main ():
203209 global context , cert_store , cert_lock , DNSresolver
@@ -207,7 +213,7 @@ async def main():
207213 if setting .rules_update_case in ('old' , 'missing' ):
208214 logger .warning ("Updated rules.toml because it is %s." , setting .rules_update_case )
209215 elif setting .rules_update_case == 'modified' :
210- logger .warning ("You've already modified rules.toml, so it won't be updated automatically!" )
216+ logger .warning ("You've already modified rules.toml so it won't be updated automatically!" )
211217 else :
212218 logger .debug ("rules.toml status: %s" , setting .rules_update_case )
213219
@@ -220,14 +226,15 @@ async def main():
220226 if (_url := urlsplit (nameserver )).netloc == '' :
221227 _url = urlsplit ('//' + nameserver )
222228 address = await DNSquery (_url .hostname , hosts_only = True )
223- if _url .scheme == '' :
224- DNSresolver .nameservers .append (dns .nameserver .Do53Nameserver (_url .hostname if address is None else address , 53 if _url .port is None else _url .port ))
225- elif _url .scheme == 'https' :
226- DNSresolver .nameservers .append (dns .nameserver .DoHNameserver (nameserver , bootstrap_address = address ))
227- elif _url .scheme == 'tls' :
228- DNSresolver .nameservers .append (dns .nameserver .DoTNameserver (_url .hostname if address is None else address , 853 if _url .port is None else _url .port , hostname = _url .hostname ))
229- elif _url .scheme == 'quic' :
230- DNSresolver .nameservers .append (dns .nameserver .DoQNameserver (_url .hostname if address is None else address , 853 if _url .port is None else _url .port , server_hostname = _url .hostname ))
229+ match _url .scheme :
230+ case '' :
231+ DNSresolver .nameservers .append (dns .nameserver .Do53Nameserver (_url .hostname if address is None else address , 53 if _url .port is None else _url .port ))
232+ case 'https' :
233+ DNSresolver .nameservers .append (dns .nameserver .DoHNameserver (nameserver , bootstrap_address = address ))
234+ case 'tls' :
235+ DNSresolver .nameservers .append (dns .nameserver .DoTNameserver (_url .hostname if address is None else address , 853 if _url .port is None else _url .port , hostname = _url .hostname ))
236+ case 'quic' :
237+ DNSresolver .nameservers .append (dns .nameserver .DoQNameserver (_url .hostname if address is None else address , 853 if _url .port is None else _url .port , server_hostname = _url .hostname ))
231238
232239 importca .import_ca ()
233240
0 commit comments