diff --git a/configloader.py b/configloader.py index cf9d61961..1794641a8 100644 --- a/configloader.py +++ b/configloader.py @@ -1,5 +1,5 @@ #!/usr/bin/python -# -*- coding: UTF-8 -*- +# -*- coding: utf-8 -*- import importloader g_config = None diff --git a/db_transfer.py b/db_transfer.py index 67bda6083..e1a404c6a 100644 --- a/db_transfer.py +++ b/db_transfer.py @@ -1,5 +1,5 @@ #!/usr/bin/python -# -*- coding: UTF-8 -*- +# -*- coding: utf-8 -*- import logging import time @@ -9,6 +9,7 @@ from shadowsocks import common, shell, lru_cache, obfs from configloader import load_config, get_config import importloader +import copy switchrule = None db_instance = None @@ -80,8 +81,10 @@ def push_db_all_user(self): def del_server_out_of_bound_safe(self, last_rows, rows): #停止超流量的服务 #启动没超流量的服务 + keymap = {} try: switchrule = importloader.load('switchrule') + keymap = switchrule.getRowMap() except Exception as e: logging.error('load switchrule.py fail') cur_servers = {} @@ -106,7 +109,10 @@ def del_server_out_of_bound_safe(self, last_rows, rows): read_config_keys = ['method', 'obfs', 'obfs_param', 'protocol', 'protocol_param', 'forbidden_ip', 'forbidden_port', 'speed_limit_per_con', 'speed_limit_per_user'] for name in read_config_keys: if name in row and row[name]: - cfg[name] = row[name] + if name in keymap: + cfg[keymap[name]] = row[name] + else: + cfg[name] = row[name] merge_config_keys = ['password'] + read_config_keys for name in cfg.keys(): @@ -392,11 +398,17 @@ def pull_db_all_user(self): return rows def pull_db_users(self, conn): + keys = copy.copy(self.key_list) try: switchrule = importloader.load('switchrule') - keys = switchrule.getKeys(self.key_list) + keymap = switchrule.getRowMap() + for key in keymap: + if keymap[key] in keys: + keys.remove(keymap[key]) + keys.append(key) + keys = switchrule.getKeys(keys) except Exception as e: - keys = self.key_list + logging.error('load switchrule.py fail') cur = conn.cursor() cur.execute("SELECT " + ','.join(keys) + " FROM user") @@ -520,11 +532,17 @@ def update_all_user(self, dt_transfer): return update_transfer def pull_db_users(self, conn): + keys = copy.copy(self.key_list) try: switchrule = importloader.load('switchrule') - keys = switchrule.getKeys(self.key_list) + keymap = switchrule.getRowMap() + for key in keymap: + if keymap[key] in keys: + keys.remove(keymap[key]) + keys.append(key) + keys = switchrule.getKeys(keys) except Exception as e: - keys = self.key_list + logging.error('load switchrule.py fail') cur = conn.cursor() diff --git a/importloader.py b/importloader.py index c917cb7d9..cedc526a7 100644 --- a/importloader.py +++ b/importloader.py @@ -1,5 +1,5 @@ #!/usr/bin/python -# -*- coding: UTF-8 -*- +# -*- coding: utf-8 -*- def load(name): try: diff --git a/server_pool.py b/server_pool.py index d159817a3..fcf6caf06 100644 --- a/server_pool.py +++ b/server_pool.py @@ -117,14 +117,14 @@ def new_server(self, port, user_config): else: a_config = self.config.copy() a_config.update(user_config) - if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": + if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][-1] == b"]": a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] - a_config['server'] = a_config['server_ipv6'] + a_config['server'] = common.to_str(a_config['server_ipv6']) a_config['server_port'] = port a_config['max_connect'] = 128 a_config['method'] = common.to_str(a_config['method']) try: - logging.info("starting server at [%s]:%d" % (common.to_str(a_config['server']), port)) + logging.info("starting server at [%s]:%d" % (a_config['server'], port)) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False, stat_counter=self.stat_counter) tcp_server.add_to_loop(self.loop) @@ -134,7 +134,7 @@ def new_server(self, port, user_config): udp_server.add_to_loop(self.loop) self.udp_ipv6_servers_pool.update({port: udp_server}) - if common.to_str(a_config['server_ipv6']) == "::": + if a_config['server_ipv6'] == "::": ipv6_ok = True except Exception as e: logging.warn("IPV6 %s " % (e,)) @@ -150,7 +150,7 @@ def new_server(self, port, user_config): a_config['max_connect'] = 128 a_config['method'] = common.to_str(a_config['method']) try: - logging.info("starting server at %s:%d" % (common.to_str(a_config['server']), port)) + logging.info("starting server at %s:%d" % (a_config['server'], port)) tcp_server = tcprelay.TCPRelay(a_config, self.dns_resolver, False) tcp_server.add_to_loop(self.loop) diff --git a/shadowsocks/asyncdns.py b/shadowsocks/asyncdns.py index 797704e35..868ea6140 100644 --- a/shadowsocks/asyncdns.py +++ b/shadowsocks/asyncdns.py @@ -27,12 +27,12 @@ if __name__ == '__main__': import sys import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) sys.path.insert(0, os.path.join(file_path, '../')) from shadowsocks import common, lru_cache, eventloop, shell - CACHE_SWEEP_INTERVAL = 30 VALID_HOSTNAME = re.compile(br"(?!-)[A-Z\d_-]{1,63}(?> %s', hostname, self._cache[hostname]) ip = self._cache[hostname] callback((hostname, ip), None) + elif any(hostname.endswith(t) for t in self._black_hostname_list): + callback(None, Exception('hostname <%s> is block by the black hostname list' % hostname)) + return else: if not is_valid_hostname(hostname): callback(None, Exception('invalid hostname: %s' % hostname)) return if False: addrs = socket.getaddrinfo(hostname, 0, 0, - socket.SOCK_DGRAM, socket.SOL_UDP) + socket.SOCK_DGRAM, socket.SOL_UDP) if addrs: af, socktype, proto, canonname, sa = addrs[0] - logging.debug('DNS resolve %s %s' % (hostname, sa[0]) ) + logging.debug('DNS resolve %s %s' % (hostname, sa[0])) self._cache[hostname] = sa[0] callback((hostname, sa[0]), None) return @@ -506,7 +520,11 @@ def close(self): def test(): - dns_resolver = DNSResolver() + black_hostname_list = [ + 'baidu.com', + 'yahoo.com', + ] + dns_resolver = DNSResolver(black_hostname_list=black_hostname_list) loop = eventloop.EventLoop() dns_resolver.add_to_loop(loop) @@ -521,16 +539,20 @@ def callback(result, error): # TODO: what can we assert? print(result, error) counter += 1 - if counter == 9: + if counter == 12: dns_resolver.close() loop.stop() + a_callback = callback return a_callback - assert(make_callback() != make_callback()) + assert (make_callback() != make_callback()) dns_resolver.resolve(b'google.com', make_callback()) dns_resolver.resolve('google.com', make_callback()) + dns_resolver.resolve('baidu.com', make_callback()) + dns_resolver.resolve('map.baidu.com', make_callback()) + dns_resolver.resolve('yahoo.com', make_callback()) dns_resolver.resolve('example.com', make_callback()) dns_resolver.resolve('ipv6.google.com', make_callback()) dns_resolver.resolve('www.facebook.com', make_callback()) @@ -546,10 +568,25 @@ def callback(result, error): 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'ooooooooooooooooooooooooooooooooooooooooooooooooooo' 'long.hostname', make_callback()) - loop.run() + # test black_hostname_list + dns_resolver = DNSResolver(black_hostname_list=[]) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=123) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver(black_hostname_list=None) + assert type(dns_resolver._black_hostname_list) == list + assert len(dns_resolver._black_hostname_list) == 0 + dns_resolver.close() + dns_resolver = DNSResolver() + assert type(dns_resolver._black_hostname_list) == list + assert dns_resolver._black_hostname_list.__len__() == 0 + dns_resolver.close() if __name__ == '__main__': test() - diff --git a/shadowsocks/common.py b/shadowsocks/common.py index c4484c046..5c1bb7601 100644 --- a/shadowsocks/common.py +++ b/shadowsocks/common.py @@ -121,7 +121,19 @@ def is_ip(address): return False +def sync_str_bytes(obj, target_example): + """sync (obj)'s type to (target_example)'s type""" + if type(obj) != type(target_example): + if type(target_example) == str: + obj = to_str(obj) + if type(target_example) == bytes: + obj = to_bytes(obj) + return obj + + def match_regex(regex, text): + # avoid 'cannot use a string pattern on a bytes-like object' + regex = sync_str_bytes(regex, text) regex = re.compile(regex) for item in regex.findall(text): return True @@ -253,7 +265,7 @@ def __init__(self, addrs): list(map(self.add_network, addrs)) def add_network(self, addr): - if addr is "": + if addr == "": return block = addr.split('/') addr_family = is_ip(block[0]) @@ -265,9 +277,9 @@ def add_network(self, addr): ip = (hi << 64) | lo else: raise Exception("Not a valid CIDR notation: %s" % addr) - if len(block) is 1: + if len(block) == 1: prefix_size = 0 - while (ip & 1) == 0 and ip is not 0: + while (ip & 1) == 0 and ip != 0: ip >>= 1 prefix_size += 1 logging.warn("You did't specify CIDR routing prefix size for %s, " @@ -381,12 +393,12 @@ def test_inet_conv(): def test_parse_header(): assert parse_header(b'\x03\x0ewww.google.com\x00\x50') == \ - (0, b'www.google.com', 80, 18) + (0, ADDRTYPE_HOST, b'www.google.com', 80, 18) assert parse_header(b'\x01\x08\x08\x08\x08\x00\x35') == \ - (0, b'8.8.8.8', 53, 7) + (0, ADDRTYPE_IPV4, b'8.8.8.8', 53, 7) assert parse_header((b'\x04$\x04h\x00@\x05\x08\x05\x00\x00\x00\x00\x00' b'\x00\x10\x11\x00\x50')) == \ - (0, b'2404:6800:4005:805::1011', 80, 19) + (0, ADDRTYPE_IPV6, b'2404:6800:4005:805::1011', 80, 19) def test_pack_header(): @@ -411,7 +423,25 @@ def test_ip_network(): assert 'www.google.com' not in ip_network +def test_sync_str_bytes(): + assert sync_str_bytes(b'a\.b', b'a\.b') == b'a\.b' + assert sync_str_bytes('a\.b', b'a\.b') == b'a\.b' + assert sync_str_bytes(b'a\.b', 'a\.b') == 'a\.b' + assert sync_str_bytes('a\.b', 'a\.b') == 'a\.b' + pass + + +def test_match_regex(): + assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(br'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'a\.b', b'abc,aaa,aaa,b,aaa.b,a.b') + assert match_regex(r'\bgoogle\.com\b', b' google.com ') + pass + if __name__ == '__main__': + test_sync_str_bytes() + test_match_regex() test_inet_conv() test_parse_header() test_pack_header() diff --git a/shadowsocks/crypto/openssl.py b/shadowsocks/crypto/openssl.py index 0a8ca53fb..e4980a402 100644 --- a/shadowsocks/crypto/openssl.py +++ b/shadowsocks/crypto/openssl.py @@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function, \ with_statement -from ctypes import c_char_p, c_int, c_long, byref,\ +from ctypes import c_char_p, c_int, c_long, byref, \ create_string_buffer, c_void_p from shadowsocks import common @@ -30,9 +30,11 @@ buf_size = 2048 +ctx_cleanup = None + def load_openssl(): - global loaded, libcrypto, buf + global loaded, libcrypto, buf, ctx_cleanup libcrypto = util.find_library(('crypto', 'eay32'), 'EVP_get_cipherbyname', @@ -51,8 +53,10 @@ def load_openssl(): if hasattr(libcrypto, "EVP_CIPHER_CTX_cleanup"): libcrypto.EVP_CIPHER_CTX_cleanup.argtypes = (c_void_p,) + ctx_cleanup = libcrypto.EVP_CIPHER_CTX_cleanup else: libcrypto.EVP_CIPHER_CTX_reset.argtypes = (c_void_p,) + ctx_cleanup = libcrypto.EVP_CIPHER_CTX_reset libcrypto.EVP_CIPHER_CTX_free.argtypes = (c_void_p,) libcrypto.RAND_bytes.restype = c_int @@ -73,6 +77,7 @@ def load_cipher(cipher_name): return cipher() return None + def rand_bytes(length): if not loaded: load_openssl() @@ -82,6 +87,7 @@ def rand_bytes(length): raise Exception('RAND_bytes return error') return buf.raw + class OpenSSLCrypto(object): def __init__(self, cipher_name, key, iv, op): self._ctx = None @@ -120,17 +126,20 @@ def __del__(self): def clean(self): if self._ctx: - if hasattr(libcrypto, "EVP_CIPHER_CTX_cleanup"): - libcrypto.EVP_CIPHER_CTX_cleanup(self._ctx) - else: - libcrypto.EVP_CIPHER_CTX_reset(self._ctx) + ctx_cleanup(self._ctx) libcrypto.EVP_CIPHER_CTX_free(self._ctx) + self._ctx = None ciphers = { + # CBC mode need a special use way that different from other. + # CBC mode encrypt message with 16n length, and need 16n+1 length space to decrypt it , otherwise don't decrypt it 'aes-128-cbc': (16, 16, OpenSSLCrypto), 'aes-192-cbc': (24, 16, OpenSSLCrypto), 'aes-256-cbc': (32, 16, OpenSSLCrypto), + 'aes-128-gcm': (16, 16, OpenSSLCrypto), + 'aes-192-gcm': (24, 16, OpenSSLCrypto), + 'aes-256-gcm': (32, 16, OpenSSLCrypto), 'aes-128-cfb': (16, 16, OpenSSLCrypto), 'aes-192-cfb': (24, 16, OpenSSLCrypto), 'aes-256-cfb': (32, 16, OpenSSLCrypto), @@ -160,7 +169,6 @@ def clean(self): def run_method(method): - cipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 1) decipher = OpenSSLCrypto(method, b'k' * 32, b'i' * 16, 0) @@ -195,5 +203,20 @@ def test_rc4(): run_method('rc4') +def test_all(): + for k, v in ciphers.items(): + print(k) + try: + run_method(k) + except AssertionError as e: + eprint("AssertionError===========" + k) + eprint(e) + + +def eprint(*args, **kwargs): + import sys + print(*args, file=sys.stderr, **kwargs) + + if __name__ == '__main__': - test_aes_128_cfb() + test_all() diff --git a/shadowsocks/crypto/sodium.py b/shadowsocks/crypto/sodium.py index 51d476bed..390941c39 100644 --- a/shadowsocks/crypto/sodium.py +++ b/shadowsocks/crypto/sodium.py @@ -20,6 +20,8 @@ from ctypes import c_char_p, c_int, c_ulong, c_ulonglong, byref, \ create_string_buffer, c_void_p +import logging + from shadowsocks.crypto import util __all__ = ['ciphers'] @@ -55,10 +57,31 @@ def load_libsodium(): try: libsodium.crypto_stream_chacha20_ietf_xor_ic.restype = c_int libsodium.crypto_stream_chacha20_ietf_xor_ic.argtypes = (c_void_p, c_char_p, - c_ulonglong, - c_char_p, c_ulong, - c_char_p) + c_ulonglong, + c_char_p, c_ulong, + c_char_p) + except: + logging.info("ChaCha20 IETF not support.") + pass + + try: + libsodium.crypto_stream_xsalsa20_xor_ic.restype = c_int + libsodium.crypto_stream_xsalsa20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) + except: + logging.info("XSalsa20 not support.") + pass + + try: + libsodium.crypto_stream_xchacha20_xor_ic.restype = c_int + libsodium.crypto_stream_xchacha20_xor_ic.argtypes = (c_void_p, c_char_p, + c_ulonglong, + c_char_p, c_ulonglong, + c_char_p) except: + logging.info("XChaCha20 not support. XChaCha20 only support since libsodium v1.0.12") pass buf = create_string_buffer(buf_size) @@ -79,6 +102,10 @@ def __init__(self, cipher_name, key, iv, op): self.cipher = libsodium.crypto_stream_chacha20_xor_ic elif cipher_name == 'chacha20-ietf': self.cipher = libsodium.crypto_stream_chacha20_ietf_xor_ic + elif cipher_name == 'xchacha20': + self.cipher = libsodium.crypto_stream_xchacha20_xor_ic + elif cipher_name == 'xsalsa20': + self.cipher = libsodium.crypto_stream_xsalsa20_xor_ic else: raise Exception('Unknown cipher') # byte counter, not block counter @@ -104,11 +131,16 @@ def update(self, data): # strip off the padding return buf.raw[padding:padding + l] + def clean(self): + pass + ciphers = { 'salsa20': (32, 8, SodiumCrypto), 'chacha20': (32, 8, SodiumCrypto), 'chacha20-ietf': (32, 12, SodiumCrypto), + 'xchacha20': (32, 24, SodiumCrypto), + 'xsalsa20': (32, 24, SodiumCrypto), } @@ -120,7 +152,6 @@ def test_salsa20(): def test_chacha20(): - cipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 1) decipher = SodiumCrypto('chacha20', b'k' * 32, b'i' * 16, 0) @@ -128,13 +159,29 @@ def test_chacha20(): def test_chacha20_ietf(): - cipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 1) decipher = SodiumCrypto('chacha20-ietf', b'k' * 32, b'i' * 16, 0) util.run_cipher(cipher, decipher) + +def test_xchacha20(): + cipher = SodiumCrypto('xchacha20', b'k' * 32, b'i' * 24, 1) + decipher = SodiumCrypto('xchacha20', b'k' * 32, b'i' * 24, 0) + + util.run_cipher(cipher, decipher) + + +def test_xsalsa20(): + cipher = SodiumCrypto('xsalsa20', b'k' * 32, b'i' * 24, 1) + decipher = SodiumCrypto('xsalsa20', b'k' * 32, b'i' * 24, 0) + + util.run_cipher(cipher, decipher) + + if __name__ == '__main__': test_chacha20_ietf() test_chacha20() test_salsa20() + test_xchacha20() + test_xsalsa20() diff --git a/shadowsocks/crypto/table.py b/shadowsocks/crypto/table.py index 60c2f2451..c7a6e8437 100644 --- a/shadowsocks/crypto/table.py +++ b/shadowsocks/crypto/table.py @@ -65,6 +65,9 @@ def update(self, data): else: return translate(data, self._decrypt_table) + def clean(self): + pass + class NoneCipher(object): def __init__(self, cipher_name, key, iv, op): pass @@ -72,6 +75,9 @@ def __init__(self, cipher_name, key, iv, op): def update(self, data): return data + def clean(self): + pass + ciphers = { 'none': (16, 0, NoneCipher), 'table': (16, 0, TableCipher) diff --git a/shadowsocks/crypto/util.py b/shadowsocks/crypto/util.py index 212df8604..12d9b7eae 100644 --- a/shadowsocks/crypto/util.py +++ b/shadowsocks/crypto/util.py @@ -22,6 +22,13 @@ def find_library_nt(name): + # type: (str) -> list + """ + find lib in windows in all the directory in path env + + :param name: can end with `.dll` or not + :return: lib results list + """ # modified from ctypes.util # ctypes.util.find_library just returns first result he found # but we want to try them all @@ -61,7 +68,9 @@ def find_library(possible_lib_names, search_symbol, library_name): if path: paths.append(path) - if not paths: + # always find lib on extend path that to avoid ```CDLL()``` failed on some strange linux environment + # in that case ```ctypes.util.find_library()``` have different find path from ```CDLL()``` + if True: # We may get here when find_library fails because, for example, # the user does not have sufficient privileges to access those # tools underlying find_library on linux. diff --git a/shadowsocks/encrypt.py b/shadowsocks/encrypt.py index 44f905250..a9fd02aac 100644 --- a/shadowsocks/encrypt.py +++ b/shadowsocks/encrypt.py @@ -22,7 +22,7 @@ import hashlib import logging -from shadowsocks import common +from shadowsocks import common, lru_cache from shadowsocks.crypto import rc4_md5, openssl, sodium, table @@ -39,18 +39,16 @@ def random_string(length): except NotImplementedError as e: return openssl.rand_bytes(length) -cached_keys = {} +cached_keys = lru_cache.LRUCache(timeout=180) def try_cipher(key, method=None): Encryptor(key, method) -def EVP_BytesToKey(password, key_len, iv_len): +def EVP_BytesToKey(password, key_len, iv_len, cache): # equivalent to OpenSSL's EVP_BytesToKey() with count 1 # so that we make the same key and iv as nodejs version - if hasattr(password, 'encode'): - password = password.encode('utf-8') cached_key = '%s-%d-%d' % (password, key_len, iv_len) r = cached_keys.get(cached_key, None) if r: @@ -68,12 +66,14 @@ def EVP_BytesToKey(password, key_len, iv_len): ms = b''.join(m) key = ms[:key_len] iv = ms[key_len:key_len + iv_len] - cached_keys[cached_key] = (key, iv) + if cache: + cached_keys[cached_key] = (key, iv) + cached_keys.sweep() return key, iv class Encryptor(object): - def __init__(self, key, method, iv = None): + def __init__(self, key, method, iv = None, cache = False): self.key = key self.method = method self.iv = None @@ -82,6 +82,7 @@ def __init__(self, key, method, iv = None): self.iv_buf = b'' self.cipher_key = b'' self.decipher = None + self.cache = cache method = method.lower() self._method_info = self.get_method_info(method) if self._method_info: @@ -106,7 +107,7 @@ def get_cipher(self, password, method, op, iv): password = common.to_bytes(password) m = self._method_info if m[0] > 0: - key, iv_ = EVP_BytesToKey(password, m[0], m[1]) + key, iv_ = EVP_BytesToKey(password, m[0], m[1], self.cache) else: # key_length == 0 indicates we should use the key directly key, iv = password, b'' @@ -120,6 +121,9 @@ def get_cipher(self, password, method, op, iv): def encrypt(self, buf): if len(buf) == 0: + if not self.iv_sent: + self.iv_sent = True + return self.cipher_iv return buf if self.iv_sent: return self.cipher.update(buf) @@ -146,12 +150,17 @@ def decrypt(self, buf): else: return b'' + def dispose(self): + if self.decipher is not None: + self.decipher.clean() + self.decipher = None + def encrypt_all(password, method, op, data): result = [] method = method.lower() (key_len, iv_len, m) = method_supported[method] if key_len > 0: - key, _ = EVP_BytesToKey(password, key_len, iv_len) + key, _ = EVP_BytesToKey(password, key_len, iv_len, True) else: key = password if op: @@ -168,7 +177,7 @@ def encrypt_key(password, method): method = method.lower() (key_len, iv_len, m) = method_supported[method] if key_len > 0: - key, _ = EVP_BytesToKey(password, key_len, iv_len) + key, _ = EVP_BytesToKey(password, key_len, iv_len, True) else: key = password return key diff --git a/shadowsocks/lru_cache.py b/shadowsocks/lru_cache.py index ab0d21086..b9e1fefca 100644 --- a/shadowsocks/lru_cache.py +++ b/shadowsocks/lru_cache.py @@ -41,6 +41,12 @@ SWEEP_MAX_ITEMS = 1024 +# https://stackoverflow.com/questions/70943244/attributeerror-module-collections-has-no-attribute-mutablemapping +import sys +if sys.version_info.major == 3 and sys.version_info.minor >= 10: + import collections + setattr(collections, "MutableMapping", collections.abc.MutableMapping) + class LRUCache(collections.MutableMapping): """This class is not thread safe""" diff --git a/shadowsocks/obfs.py b/shadowsocks/obfs.py index 3dfdb141f..f8ee2d397 100644 --- a/shadowsocks/obfs.py +++ b/shadowsocks/obfs.py @@ -23,7 +23,7 @@ import logging from shadowsocks import common -from shadowsocks.obfsplugin import plain, http_simple, obfs_tls, verify, auth, auth_chain +from shadowsocks.obfsplugin import plain, http_simple, obfs_tls, verify, auth, auth_chain, auth_akarin method_supported = {} @@ -33,9 +33,12 @@ method_supported.update(verify.obfs_map) method_supported.update(auth.obfs_map) method_supported.update(auth_chain.obfs_map) +method_supported.update(auth_akarin.obfs_map) def mu_protocol(): - return ["auth_aes128_md5", "auth_aes128_sha1", "auth_chain_a"] + return {"auth_aes128_md5", "auth_aes128_sha1", + "auth_chain_a", "auth_chain_b", "auth_chain_c", "auth_chain_d", "auth_chain_e", "auth_chain_f", + "auth_akarin_rand", "auth_akarin_spec_a"} class server_info(object): def __init__(self, data): diff --git a/shadowsocks/obfsplugin/auth_akarin.py b/shadowsocks/obfsplugin/auth_akarin.py new file mode 100644 index 000000000..58f84a30c --- /dev/null +++ b/shadowsocks/obfsplugin/auth_akarin.py @@ -0,0 +1,801 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2018-2018 Akkariin +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +from __future__ import absolute_import, division, print_function, \ + with_statement + +import hashlib +import logging +import binascii +import base64 +import time +import datetime +import random +import math +import struct +import hmac +import bisect + +import shadowsocks +from shadowsocks import common, lru_cache, encrypt +from shadowsocks.obfsplugin import plain +from shadowsocks.common import to_bytes, to_str, ord, chr +from shadowsocks.crypto import openssl + +rand_bytes = openssl.rand_bytes + +def create_auth_akarin_rand(method): + return auth_akarin_rand(method) + + +def create_auth_akarin_spec_a(method): + return auth_akarin_spec_a(method) + + +obfs_map = { + 'auth_akarin_rand': (create_auth_akarin_rand,), + 'auth_akarin_spec_a': (create_auth_akarin_spec_a,), +} + + +class xorshift128plus(object): + max_int = (1 << 64) - 1 + mov_mask = (1 << (64 - 23)) - 1 + + def __init__(self): + self.v0 = 0 + self.v1 = 0 + + def next(self): + x = self.v0 + y = self.v1 + self.v0 = y + x ^= ((x & xorshift128plus.mov_mask) << 23) + x ^= (y ^ (x >> 17) ^ (y >> 26)) + self.v1 = x + return (x + y) & xorshift128plus.max_int + + def init_from_bin(self, bin): + if len(bin) < 16: + bin += b'\0' * 16 + self.v0 = struct.unpack('= len(str2): + if str1[:len(str2)] == str2: + return True + return False + + +class auth_base(plain.plain): + def __init__(self, method): + super(auth_base, self).__init__(method) + self.method = method + self.no_compatible_method = '' + self.overhead = 4 + + def init_data(self): + return '' + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + + def client_encode(self, buf): + return buf + + def client_decode(self, buf): + return (buf, False) + + def server_encode(self, buf): + return buf + + def server_decode(self, buf): + return (buf, True, False) + + def not_match_return(self, buf): + self.raw_trans = True + self.overhead = 0 + if self.method == self.no_compatible_method: + return (b'E' * 2048, False) + return (buf, False) + + +class client_queue(object): + def __init__(self, begin_id): + self.front = begin_id - 64 + self.back = begin_id + 1 + self.alloc = {} + self.enable = True + self.last_update = time.time() + self.ref = 0 + + def update(self): + self.last_update = time.time() + + def addref(self): + self.ref += 1 + + def delref(self): + if self.ref > 0: + self.ref -= 1 + + def is_active(self): + return (self.ref > 0) and (time.time() - self.last_update < 60 * 10) + + def re_enable(self, connection_id): + self.enable = True + self.front = connection_id - 64 + self.back = connection_id + 1 + self.alloc = {} + + def insert(self, connection_id): + if not self.enable: + logging.warn('obfs auth: not enable') + return False + if not self.is_active(): + self.re_enable(connection_id) + self.update() + if connection_id < self.front: + logging.warn('obfs auth: deprecated id, someone replay attack') + return False + if connection_id > self.front + 0x4000: + logging.warn('obfs auth: wrong id') + return False + if connection_id in self.alloc: + logging.warn('obfs auth: duplicate id, someone replay attack') + return False + if self.back <= connection_id: + self.back = connection_id + 1 + self.alloc[connection_id] = 1 + while (self.front in self.alloc) or self.front + 0x1000 < self.back: + if self.front in self.alloc: + del self.alloc[self.front] + self.front += 1 + self.addref() + return True + + +class obfs_auth_akarin_data(object): + def __init__(self, name): + self.name = name + self.user_id = {} + self.local_client_id = b'' + self.connection_id = 0 + self.set_max_client(64) # max active client count + + def update(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if client_id in local_client_id: + local_client_id[client_id].update() + + def set_max_client(self, max_client): + self.max_client = max_client + self.max_buffer = max(self.max_client * 2, 1024) + + def insert(self, user_id, client_id, connection_id): + if user_id not in self.user_id: + self.user_id[user_id] = lru_cache.LRUCache() + local_client_id = self.user_id[user_id] + + if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: + if local_client_id.first() is None or len(local_client_id) < self.max_client: + if client_id not in local_client_id: + # TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + if not local_client_id[local_client_id.first()].is_active(): + del local_client_id[local_client_id.first()] + if client_id not in local_client_id: + # TODO: check + local_client_id[client_id] = client_queue(connection_id) + else: + local_client_id[client_id].re_enable(connection_id) + return local_client_id[client_id].insert(connection_id) + + logging.warn(self.name + ': no inactive client') + return False + else: + return local_client_id[client_id].insert(connection_id) + + def remove(self, user_id, client_id): + if user_id in self.user_id: + local_client_id = self.user_id[user_id] + if client_id in local_client_id: + local_client_id[client_id].delref() + + +class auth_akarin_rand(auth_base): + def __init__(self, method): + super(auth_akarin_rand, self).__init__(method) + self.hashfunc = hashlib.md5 + self.recv_buf = b'' + self.unit_len = 2800 + self.raw_trans = False + self.has_sent_header = False + self.has_recv_header = False + self.client_id = 0 + self.connection_id = 0 + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.salt = b"auth_akarin_rand" + self.no_compatible_method = 'auth_akarin_rand' + self.pack_id = 1 + self.recv_id = 1 + self.user_id = None + self.user_id_num = 0 + self.user_key = None + self.overhead = 4 + self.client_over_head = self.overhead + self.last_client_hash = b'' + self.last_server_hash = b'' + self.random_client = xorshift128plus() + self.random_server = xorshift128plus() + self.encryptor = None + self.new_send_tcp_mss = 2000 + self.send_tcp_mss = 2000 + self.recv_tcp_mss = 2000 + self.send_back_cmd = [] + + def init_data(self): + return obfs_auth_akarin_data(self.method) + + def get_overhead(self, direction): # direction: true for c->s false for s->c + return self.overhead + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + + def trapezoid_random_float(self, d): + if d == 0: + return random.random() + s = random.random() + a = 1 - d + return (math.sqrt(a * a + 4 * d * s) - a) / (2 * d) + + def trapezoid_random_int(self, max_val, d): + v = self.trapezoid_random_float(d) + return int(v * max_val) + + def send_rnd_data_len(self, buf_size, last_hash, random): + if buf_size + self.server_info.overhead > self.send_tcp_mss: + random.init_from_bin_len(last_hash, buf_size) + return random.next() % 521 + if buf_size >= 1440 or buf_size + self.server_info.overhead == self.send_tcp_mss: + return 0 + random.init_from_bin_len(last_hash, buf_size) + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % (self.send_tcp_mss - buf_size - self.server_info.overhead) + + def recv_rnd_data_len(self, buf_size, last_hash, random): + if buf_size + self.server_info.overhead > self.recv_tcp_mss: + random.init_from_bin_len(last_hash, buf_size) + return random.next() % 521 + if buf_size >= 1440 or buf_size + self.server_info.overhead == self.send_tcp_mss: + return 0 + random.init_from_bin_len(last_hash, buf_size) + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % (self.recv_tcp_mss - buf_size - self.server_info.overhead) + + def udp_rnd_data_len(self, last_hash, random): + random.init_from_bin(last_hash) + return random.next() % 127 + + def rnd_data(self, buf_size, buf, last_hash, random): + rand_len = self.send_rnd_data_len(buf_size, last_hash, random) + + rnd_data_buf = rand_bytes(rand_len) + + if buf_size == 0: + return rnd_data_buf + else: + if rand_len > 0: + return buf + rnd_data_buf + else: + return buf + + def pack_client_data(self, buf): + buf = self.encryptor.encrypt(buf) + if self.send_back_cmd: + cmd_len = 2 + self.send_tcp_mss = self.recv_tcp_mss + data = self.rnd_data(len(buf) + cmd_len, buf, self.last_client_hash, self.random_client) + length = len(buf) ^ struct.unpack('H', rand_bytes(2))[0] % 1024 + 400 + data = data + (struct.pack(' 0xFF000000: + self.server_info.data.local_client_id = b'' + if not self.server_info.data.local_client_id: + self.server_info.data.local_client_id = rand_bytes(4) + logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) + self.server_info.data.connection_id = struct.unpack(' self.unit_len: + ret += self.pack_client_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_client_data(buf) + return ret + + def client_post_decrypt(self, buf): + if self.raw_trans: + return buf + self.recv_buf += buf + out_buf = b'' + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data error') + + if length + 4 > len(self.recv_buf): + break + + server_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() + if server_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' + % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + self.raw_trans = True + self.recv_buf = b'' + raise Exception('client_post_decrypt data uncorrect checksum') + + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) + self.last_server_hash = server_hash + if self.recv_id == 1: + self.server_info.tcp_mss = struct.unpack(' self.unit_len: + ret += self.pack_server_data(buf[:self.unit_len]) + buf = buf[self.unit_len:] + ret += self.pack_server_data(buf) + return ret + + def server_post_decrypt(self, buf): + if self.raw_trans: + return (buf, False) + self.recv_buf += buf + out_buf = b'' + sendback = False + + if not self.has_recv_header: + if len(self.recv_buf) >= 12 or len(self.recv_buf) in [7, 8]: + recv_len = min(len(self.recv_buf), 12) + mac_key = self.server_info.recv_iv + self.server_info.key + md5data = hmac.new(mac_key, self.recv_buf[:4], self.hashfunc).digest() + if md5data[:recv_len - 4] != self.recv_buf[4:recv_len]: + return self.not_match_return(self.recv_buf) + + if len(self.recv_buf) < 12 + 24: + return (b'', False) + + self.last_client_hash = md5data + uid = struct.unpack(' self.max_time_dif: + logging.info('%s: wrong timestamp, time_dif %d, data %s' % ( + self.no_compatible_method, time_dif, binascii.hexlify(head) + )) + return self.not_match_return(self.recv_buf) + elif self.server_info.data.insert(self.user_id, client_id, connection_id): + self.has_recv_header = True + self.client_id = client_id + self.connection_id = connection_id + else: + logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) + return self.not_match_return(self.recv_buf) + + self.on_recv_auth_data(utc_time) + self.encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'chacha20', self.last_server_hash[:8]) + self.encryptor.encrypt(b'') + self.encryptor.decrypt(self.last_client_hash[:8]) + self.recv_buf = self.recv_buf[36:] + self.has_recv_header = True + sendback = True + + while len(self.recv_buf) > 4: + mac_key = self.user_key + struct.pack('= 0xff00: + if data_len == 0xff00: + cmd_len += 2 + self.recv_tcp_mss = self.send_tcp_mss + recv_buf = recv_buf[2:] + data_len = struct.unpack('= 4096: + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 1: + logging.info(self.no_compatible_method + ': over size') + return (b'E' * 2048, False) + else: + raise Exception('server_post_decrype data error') + + if length + 4 > len(recv_buf): + break + + client_hash = hmac.new(mac_key, self.recv_buf[:length + cmd_len + 2], self.hashfunc).digest() + if client_hash[:2] != self.recv_buf[length + cmd_len + 2: length + cmd_len + 4]: + logging.info('%s: checksum error, data %s' % ( + self.no_compatible_method, binascii.hexlify(self.recv_buf[:length + cmd_len]), + )) + self.raw_trans = True + self.recv_buf = b'' + if self.recv_id == 1: + return (b'E' * 2048, False) + else: + raise Exception('server_post_decrype data uncorrect checksum') + + self.recv_id = (self.recv_id + 1) & 0xFFFFFFFF + pos = 2 + if data_len > 0 and rand_len > 0: + pos = 2 + out_buf += self.encryptor.decrypt(recv_buf[pos: data_len + pos]) + self.last_client_hash = client_hash + self.recv_buf = recv_buf[length + 4:] + if data_len == 0: + sendback = True + + if out_buf: + self.server_info.data.update(self.user_id, self.client_id, self.connection_id) + return (out_buf, sendback) + + def client_udp_pre_encrypt(self, buf): + if self.user_key is None: + if b':' in to_bytes(self.server_info.protocol_param): + try: + items = to_bytes(self.server_info.protocol_param).split(':') + self.user_key = self.hashfunc(items[1]).digest() + self.user_id = struct.pack(' self.send_tcp_mss: + random.init_from_bin_len(last_hash, buf_size) + return random.next() % 521 + if buf_size >= 1440 or buf_size + self.server_info.overhead == self.send_tcp_mss: + return 0 + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list)) + if final_pos < len(self.data_size_list): + return self.data_size_list[final_pos] - buf_size - self.server_info.overhead + + pos = bisect.bisect_left(self.data_size_list2, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list2)) + if final_pos < len(self.data_size_list2): + return self.data_size_list2[final_pos] - buf_size - self.server_info.overhead + if final_pos < pos + len(self.data_size_list2) - 1: + return 0 + + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % 1021 + + + def recv_rnd_data_len(self, buf_size, last_hash, random): + if buf_size + self.server_info.overhead > self.recv_tcp_mss: + random.init_from_bin_len(last_hash, buf_size) + return random.next() % 521 + if buf_size >= 1440 or buf_size + self.server_info.overhead == self.send_tcp_mss: + return 0 + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list)) + if final_pos < len(self.data_size_list): + return self.data_size_list[final_pos] - buf_size - self.server_info.overhead + + pos = bisect.bisect_left(self.data_size_list2, buf_size + self.server_info.overhead) + final_pos = pos + random.next() % (len(self.data_size_list2)) + if final_pos < len(self.data_size_list2): + return self.data_size_list2[final_pos] - buf_size - self.server_info.overhead + if final_pos < pos + len(self.data_size_list2) - 1: + return 0 + + if buf_size > 1300: + return random.next() % 31 + if buf_size > 900: + return random.next() % 127 + if buf_size > 400: + return random.next() % 521 + return random.next() % 1021 + + diff --git a/shadowsocks/obfsplugin/auth_chain.py b/shadowsocks/obfsplugin/auth_chain.py index 26097bfb9..9a2ce806a 100644 --- a/shadowsocks/obfsplugin/auth_chain.py +++ b/shadowsocks/obfsplugin/auth_chain.py @@ -1,4 +1,5 @@ #!/usr/bin/env python +# -*- coding: utf-8 -*- # # Copyright 2015-2015 breakwa11 # @@ -17,8 +18,6 @@ from __future__ import absolute_import, division, print_function, \ with_statement -import os -import sys import hashlib import logging import binascii @@ -28,27 +27,51 @@ import random import math import struct -import zlib import hmac -import hashlib import bisect import shadowsocks from shadowsocks import common, lru_cache, encrypt from shadowsocks.obfsplugin import plain from shadowsocks.common import to_bytes, to_str, ord, chr +from shadowsocks.crypto import openssl + +rand_bytes = openssl.rand_bytes def create_auth_chain_a(method): return auth_chain_a(method) + def create_auth_chain_b(method): return auth_chain_b(method) + +def create_auth_chain_c(method): + return auth_chain_c(method) + + +def create_auth_chain_d(method): + return auth_chain_d(method) + + +def create_auth_chain_e(method): + return auth_chain_e(method) + + +def create_auth_chain_f(method): + return auth_chain_f(method) + + obfs_map = { - 'auth_chain_a': (create_auth_chain_a,), - 'auth_chain_b': (create_auth_chain_b,), + 'auth_chain_a': (create_auth_chain_a,), + 'auth_chain_b': (create_auth_chain_b,), + 'auth_chain_c': (create_auth_chain_c,), + 'auth_chain_d': (create_auth_chain_d,), + 'auth_chain_e': (create_auth_chain_e,), + 'auth_chain_f': (create_auth_chain_f,), } + class xorshift128plus(object): max_int = (1 << 64) - 1 mov_mask = (1 << (64 - 23)) - 1 @@ -62,19 +85,20 @@ def next(self): y = self.v1 self.v0 = y x ^= ((x & xorshift128plus.mov_mask) << 23) - x ^= (y ^ (x >> 17) ^ (y >> 26)) & xorshift128plus.max_int + x ^= (y ^ (x >> 17) ^ (y >> 26)) self.v1 = x return (x + y) & xorshift128plus.max_int def init_from_bin(self, bin): - bin += b'\0' * 16 + if len(bin) < 16: + bin += b'\0' * 16 self.v0 = struct.unpack('s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -118,9 +143,10 @@ def not_match_return(self, buf): self.raw_trans = True self.overhead = 0 if self.method == self.no_compatible_method: - return (b'E'*2048, False) + return (b'E' * 2048, False) return (buf, False) + class client_queue(object): def __init__(self, begin_id): self.front = begin_id - 64 @@ -175,13 +201,14 @@ def insert(self, connection_id): self.addref() return True + class obfs_auth_chain_data(object): def __init__(self, name): self.name = name self.user_id = {} self.local_client_id = b'' self.connection_id = 0 - self.set_max_client(64) # max active client count + self.set_max_client(64) # max active client count def update(self, user_id, client_id, connection_id): if user_id not in self.user_id: @@ -203,7 +230,7 @@ def insert(self, user_id, client_id, connection_id): if local_client_id.get(client_id, None) is None or not local_client_id[client_id].enable: if local_client_id.first() is None or len(local_client_id) < self.max_client: if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -212,7 +239,7 @@ def insert(self, user_id, client_id, connection_id): if not local_client_id[local_client_id.first()].is_active(): del local_client_id[local_client_id.first()] if client_id not in local_client_id: - #TODO: check + # TODO: check local_client_id[client_id] = client_queue(connection_id) else: local_client_id[client_id].re_enable(connection_id) @@ -229,6 +256,7 @@ def remove(self, user_id, client_id): if client_id in local_client_id: local_client_id[client_id].delref() + class auth_chain_a(auth_base): def __init__(self, method): super(auth_chain_a, self).__init__(method) @@ -240,7 +268,7 @@ def __init__(self, method): self.has_recv_header = False self.client_id = 0 self.connection_id = 0 - self.max_time_dif = 60 * 60 * 24 # time dif (second) setting + self.max_time_dif = 60 * 60 * 24 # time dif (second) setting self.salt = b"auth_chain_a" self.no_compatible_method = 'auth_chain_a' self.pack_id = 1 @@ -259,7 +287,7 @@ def __init__(self, method): def init_data(self): return obfs_auth_chain_data(self.method) - def get_overhead(self, direction): # direction: true for c->s false for s->c + def get_overhead(self, direction): # direction: true for c->s false for s->c return self.overhead def set_server_info(self, server_info): @@ -305,7 +333,7 @@ def rnd_start_pos(self, rand_len, random): def rnd_data(self, buf_size, buf, last_hash, random): rand_len = self.rnd_data_len(buf_size, last_hash, random) - rnd_data_buf = os.urandom(rand_len) + rnd_data_buf = rand_bytes(rand_len) if buf_size == 0: return rnd_data_buf @@ -319,7 +347,6 @@ def rnd_data(self, buf_size, buf, last_hash, random): def pack_client_data(self, buf): buf = self.encryptor.encrypt(buf) data = self.rnd_data(len(buf), buf, self.last_client_hash, self.random_client) - data_len = len(data) + 8 mac_key = self.user_key + struct.pack(' 0xFF000000: self.server_info.data.local_client_id = b'' if not self.server_info.data.local_client_id: - self.server_info.data.local_client_id = os.urandom(4) + self.server_info.data.local_client_id = rand_bytes(4) logging.debug("local_client_id %s" % (binascii.hexlify(self.server_info.data.local_client_id),)) - self.server_info.data.connection_id = struct.unpack(' 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_server) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_server_hash = server_hash if self.recv_id == 1: self.server_info.tcp_mss = struct.unpack(' self.max_time_dif: - logging.info('%s: wrong timestamp, time_dif %d, data %s' % (self.no_compatible_method, time_dif, binascii.hexlify(head))) + logging.info('%s: wrong timestamp, time_dif %d, data %s' % ( + self.no_compatible_method, time_dif, binascii.hexlify(head) + )) return self.not_match_return(self.recv_buf) elif self.server_info.data.insert(self.user_id, client_id, connection_id): self.has_recv_header = True @@ -513,7 +549,9 @@ def server_post_decrypt(self, buf): logging.info('%s: auth fail, data %s' % (self.no_compatible_method, binascii.hexlify(out_buf))) return self.not_match_return(self.recv_buf) - self.encryptor = encrypt.Encryptor(to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') + self.on_recv_auth_data(utc_time) + self.encryptor = encrypt.Encryptor( + to_bytes(base64.b64encode(self.user_key)) + to_bytes(base64.b64encode(self.last_client_hash)), 'rc4') self.recv_buf = self.recv_buf[36:] self.has_recv_header = True sendback = True @@ -526,9 +564,9 @@ def server_post_decrypt(self, buf): if length >= 4096: self.raw_trans = True self.recv_buf = b'' - if self.recv_id == 0: + if self.recv_id == 1: logging.info(self.no_compatible_method + ': over size') - return (b'E'*2048, False) + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data error') @@ -536,12 +574,14 @@ def server_post_decrypt(self, buf): break client_hash = hmac.new(mac_key, self.recv_buf[:length + 2], self.hashfunc).digest() - if client_hash[:2] != self.recv_buf[length + 2 : length + 4]: - logging.info('%s: checksum error, data %s' % (self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]))) + if client_hash[:2] != self.recv_buf[length + 2: length + 4]: + logging.info('%s: checksum error, data %s' % ( + self.no_compatible_method, binascii.hexlify(self.recv_buf[:length]) + )) self.raw_trans = True self.recv_buf = b'' - if self.recv_id == 0: - return (b'E'*2048, False) + if self.recv_id == 1: + return (b'E' * 2048, False) else: raise Exception('server_post_decrype data uncorrect checksum') @@ -549,7 +589,7 @@ def server_post_decrypt(self, buf): pos = 2 if data_len > 0 and rand_len > 0: pos = 2 + self.rnd_start_pos(rand_len, self.random_client) - out_buf += self.encryptor.decrypt(self.recv_buf[pos : data_len + pos]) + out_buf += self.encryptor.decrypt(self.recv_buf[pos: data_len + pos]) self.last_client_hash = client_hash self.recv_buf = self.recv_buf[length + 4:] if data_len == 0: @@ -569,17 +609,18 @@ def client_udp_pre_encrypt(self, buf): except: pass if self.user_key is None: - self.user_id = os.urandom(4) + self.user_id = rand_bytes(4) self.user_key = self.server_info.key - authdata = os.urandom(3) + authdata = rand_bytes(3) mac_key = self.server_info.key md5data = hmac.new(mac_key, authdata, self.hashfunc).digest() uid = struct.unpack(' 1300: return random.next() % 31 @@ -690,3 +743,175 @@ def rnd_data_len(self, buf_size, last_hash, random): return random.next() % 521 return random.next() % 1021 + +class auth_chain_c(auth_chain_b): + def __init__(self, method): + super(auth_chain_c, self).__init__(method) + self.salt = b"auth_chain_c" + self.no_compatible_method = 'auth_chain_c' + self.data_size_list0 = [] + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # 一定要在random使用前初始化,以保证服务器与客户端同步,保证包大小验证结果正确 + random.init_from_bin_len(last_hash, buf_size) + # final_pos 总是分布在pos~(data_size_list0.len-1)之间 + # 除非data_size_list0中的任何值均过小使其全部都无法容纳buf + if other_data_size >= self.data_size_list0[-1]: + if other_data_size >= 1440: + return 0 + if other_data_size > 1300: + return random.next() % 31 + if other_data_size > 900: + return random.next() % 127 + if other_data_size > 400: + return random.next() % 521 + return random.next() % 1021 + + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size + + +class auth_chain_d(auth_chain_b): + def __init__(self, method): + super(auth_chain_d, self).__init__(method) + self.salt = b"auth_chain_d" + self.no_compatible_method = 'auth_chain_d' + self.data_size_list0 = [] + + def check_and_patch_data_size(self, random): + # append new item + # when the biggest item(first time) or the last append item(other time) are not big enough. + # but set a limit size (64) to avoid stack overflow. + if self.data_size_list0[-1] < 1300 and len(self.data_size_list0) < 64: + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.check_and_patch_data_size(random) + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + random.init_from_bin(key) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append((int)(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + old_len = len(self.data_size_list0) + self.check_and_patch_data_size(random) + # if check_and_patch_data_size are work, re-sort again. + if old_len != len(self.data_size_list0): + self.data_size_list0.sort() + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + self.init_data_size(self.server_info.key) + + def rnd_data_len(self, buf_size, last_hash, random): + other_data_size = buf_size + self.server_info.overhead + # if other_data_size > the bigest item in data_size_list0, not padding any data + if other_data_size >= self.data_size_list0[-1]: + return 0 + + random.init_from_bin_len(last_hash, buf_size) + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + # random select a size in the leftover data_size_list0 + final_pos = pos + random.next() % (len(self.data_size_list0) - pos) + return self.data_size_list0[final_pos] - other_data_size + + +class auth_chain_e(auth_chain_d): + def __init__(self, method): + super(auth_chain_e, self).__init__(method) + self.salt = b"auth_chain_e" + self.no_compatible_method = 'auth_chain_e' + + def rnd_data_len(self, buf_size, last_hash, random): + random.init_from_bin_len(last_hash, buf_size) + other_data_size = buf_size + self.server_info.overhead + # if other_data_size > the bigest item in data_size_list0, not padding any data + if other_data_size >= self.data_size_list0[-1]: + return 0 + + # use the mini size in the data_size_list0 + pos = bisect.bisect_left(self.data_size_list0, other_data_size) + return self.data_size_list0[pos] - other_data_size + + +# auth_chain_f +# when every connect create, generate size_list will different when every day or every custom time interval which set in the config +class auth_chain_f(auth_chain_e): + def __init__(self, method): + super(auth_chain_f, self).__init__(method) + self.salt = b"auth_chain_f" + self.no_compatible_method = 'auth_chain_f' + + def set_server_info(self, server_info): + self.server_info = server_info + try: + max_client = int(server_info.protocol_param.split('#')[0]) + except: + max_client = 64 + self.server_info.data.set_max_client(max_client) + try: + self.key_change_interval = int(server_info.protocol_param.split('#')[1]) # config are in second + except: + self.key_change_interval = 60 * 60 * 24 # a day by second + + def on_recv_auth_data(self, utc_time): + self.key_change_datetime_key = int(utc_time / self.key_change_interval) + self.key_change_datetime_key_bytes = [] # big bit first list + for i in range(7, -1, -1): # big-ending compare to c + self.key_change_datetime_key_bytes.append((self.key_change_datetime_key >> (8 * i)) & 0xFF) + self.init_data_size(self.server_info.key) + + def init_data_size(self, key): + if self.data_size_list0: + self.data_size_list0 = [] + random = xorshift128plus() + # key xor with key_change_datetime_key + new_key = bytearray(key) + new_key_str = '' + for i in range(0, 8): + new_key[i] ^= self.key_change_datetime_key_bytes[i] + new_key_str += chr(new_key[i]) + for i in range(8, len(new_key)): + new_key_str += chr(new_key[i]) + random.init_from_bin(to_bytes(new_key_str)) + # 补全数组长为12~24-1 + list_len = random.next() % (8 + 16) + (4 + 8) + for i in range(0, list_len): + self.data_size_list0.append(int(random.next() % 2340 % 2040 % 1440)) + self.data_size_list0.sort() + old_len = len(self.data_size_list0) + self.check_and_patch_data_size(random) + # if check_and_patch_data_size are work, re-sort again. + if old_len != len(self.data_size_list0): + self.data_size_list0.sort() diff --git a/shadowsocks/obfsplugin/http_simple.py b/shadowsocks/obfsplugin/http_simple.py index 6f1a05e4b..ff3c5fdfb 100644 --- a/shadowsocks/obfsplugin/http_simple.py +++ b/shadowsocks/obfsplugin/http_simple.py @@ -63,6 +63,7 @@ def __init__(self, method): self.host = None self.port = 0 self.recv_buffer = b'' + # TODO user config user_agent self.user_agent = [b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/40.0", b"Mozilla/5.0 (Windows NT 6.3; WOW64; rv:40.0) Gecko/20100101 Firefox/44.0", b"Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36", diff --git a/shadowsocks/server.py b/shadowsocks/server.py index c18ad1cca..dced0a397 100755 --- a/shadowsocks/server.py +++ b/shadowsocks/server.py @@ -25,6 +25,7 @@ if __name__ == '__main__': import inspect + file_path = os.path.dirname(os.path.realpath(inspect.getfile(inspect.currentframe()))) sys.path.insert(0, os.path.join(file_path, '../')) @@ -43,7 +44,8 @@ def main(): try: import resource - logging.info('current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) + logging.info( + 'current process RLIMIT_NOFILE resource: soft %d hard %d' % resource.getrlimit(resource.RLIMIT_NOFILE)) except ImportError: pass @@ -68,7 +70,7 @@ def main(): tcp_servers = [] udp_servers = [] - dns_resolver = asyncdns.DNSResolver() + dns_resolver = asyncdns.DNSResolver(config['black_hostname_list']) if int(config['workers']) > 1: stat_counter_dict = None else: @@ -103,10 +105,11 @@ def main(): a_config = config.copy() ipv6_ok = False logging.info("server start with protocol[%s] password [%s] method [%s] obfs [%s] obfs_param [%s]" % - (protocol, password, method, obfs, obfs_param)) + (protocol, password, method, obfs, obfs_param)) if 'server_ipv6' in a_config: try: - if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == "[" and a_config['server_ipv6'][-1] == "]": + if len(a_config['server_ipv6']) > 2 and a_config['server_ipv6'][0] == b"[" and a_config['server_ipv6'][ + -1] == b"]": a_config['server_ipv6'] = a_config['server_ipv6'][1:-1] a_config['server_port'] = int(port) a_config['password'] = password @@ -117,7 +120,7 @@ def main(): a_config['obfs_param'] = obfs_param a_config['out_bind'] = bind a_config['out_bindv6'] = bindv6 - a_config['server'] = a_config['server_ipv6'] + a_config['server'] = common.to_str(a_config['server_ipv6']) logging.info("starting server at [%s]:%d" % (a_config['server'], int(port))) tcp_servers.append(tcprelay.TCPRelay(a_config, dns_resolver, False, stat_counter=stat_counter_dict)) @@ -151,11 +154,13 @@ def child_handler(signum, _): logging.warn('received SIGQUIT, doing graceful shutting down..') list(map(lambda s: s.close(next_tick=True), tcp_servers + udp_servers)) + signal.signal(getattr(signal, 'SIGQUIT', signal.SIGTERM), child_handler) def int_handler(signum, _): sys.exit(1) + signal.signal(signal.SIGINT, int_handler) try: @@ -191,6 +196,7 @@ def handler(signum, _): except OSError: # child may already exited pass sys.exit() + signal.signal(signal.SIGTERM, handler) signal.signal(signal.SIGQUIT, handler) signal.signal(signal.SIGINT, handler) diff --git a/shadowsocks/shell.py b/shadowsocks/shell.py index 6246d98af..a1547d082 100755 --- a/shadowsocks/shell.py +++ b/shadowsocks/shell.py @@ -26,7 +26,6 @@ from shadowsocks.common import to_bytes, to_str, IPNetwork, PortRange from shadowsocks import encrypt - VERBOSE_LEVEL = 5 verbose = 0 @@ -52,6 +51,7 @@ def print_exception(e): import traceback traceback.print_exc() + def __version(): version_str = '' try: @@ -65,9 +65,11 @@ def __version(): pass return version_str + def print_shadowsocks(): print('ShadowsocksR %s' % __version()) + def log_shadowsocks_version(): logging.info('ShadowsocksR %s' % __version()) @@ -84,6 +86,7 @@ def sub_find(file_name): return sub_find(user_config_path) or sub_find(config_path) + def check_config(config, is_local): if config.get('daemon', None) == 'stop': # no need to specify configuration for daemon stop @@ -110,13 +113,13 @@ def check_config(config, is_local): logging.warning('warning: local set to listen on 0.0.0.0, it\'s not safe') if config.get('server', '') in ['127.0.0.1', 'localhost']: logging.warning('warning: server set to listen on %s:%s, are you sure?' % - (to_str(config['server']), config['server_port'])) + (to_str(config['server']), config['server_port'])) if config.get('timeout', 300) < 100: logging.warning('warning: your timeout %d seems too short' % - int(config.get('timeout'))) + int(config.get('timeout'))) if config.get('timeout', 300) > 600: logging.warning('warning: your timeout %d seems too long' % - int(config.get('timeout'))) + int(config.get('timeout'))) if config.get('password') in [b'mypassword']: logging.error('DON\'T USE DEFAULT PASSWORD! Please change it in your ' 'config.json!') @@ -160,7 +163,6 @@ def get_config(is_local): if config_path is None: config_path = find_config() - if config_path: logging.debug('loading config from %s' % config_path) with open(config_path, 'rb') as f: @@ -170,7 +172,6 @@ def get_config(is_local): logging.error('found an error in config.json: %s', str(e)) sys.exit(1) - v_count = 0 for key, value in optlist: if key == '-p': @@ -260,6 +261,9 @@ def get_config(is_local): config['server'] = to_str(config['server']) else: config['server'] = to_str(config.get('server', '0.0.0.0')) + config['black_hostname_list'] = to_str(config.get('black_hostname_list', '')).split(',') + if len(config['black_hostname_list']) == 1 and config['black_hostname_list'][0] == '': + config['black_hostname_list'] = [] try: config['forbidden_ip'] = \ IPNetwork(config.get('forbidden_ip', '127.0.0.0/8,::1/128')) @@ -398,6 +402,7 @@ def _decode_dict(data): rv[key] = value return rv + class JSFormat: def __init__(self): self.state = 0 @@ -435,6 +440,7 @@ def push(self, ch): return "\n" return "" + def remove_comment(json): fmt = JSFormat() return "".join([fmt.push(c) for c in json]) diff --git a/shadowsocks/tcprelay.py b/shadowsocks/tcprelay.py index 595e2be73..1e11a6be5 100644 --- a/shadowsocks/tcprelay.py +++ b/shadowsocks/tcprelay.py @@ -266,7 +266,7 @@ def _update_tcp_mss(self, local_sock): def _create_encryptor(self, config): try: self._encryptor = encrypt.Encryptor(config['password'], - config['method']) + config['method'], None, True) return True except Exception: self._stage = STAGE_DESTROYED @@ -490,7 +490,6 @@ def _get_redirect_host(self, client_address, ogn_data): return host_port[((hash_code & 0xffffffff) + addr) % len(host_port)] else: - host_port = [] for host in host_list: items_sum = common.to_str(host).rsplit('#', 1) items_match = common.to_str(items_sum[0]).rsplit(':', 1) @@ -1162,13 +1161,20 @@ def destroy(self): if self._protocol: self._protocol.dispose() self._protocol = None - self._encryptor = None + + if self._encryptor: + self._encryptor.dispose() + self._encryptor = None self._dns_resolver.remove_callback(self._handle_dns_resolved) self._server.remove_handler(self) if self._add_ref > 0: self._server.add_connection(-1) self._server.stat_add(self._client_address[0], -1) + #import gc + #gc.collect() + #logging.debug("gc %s" % (gc.garbage,)) + class TCPRelay(object): def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_counter=None): self._config = config diff --git a/shadowsocks/udprelay.py b/shadowsocks/udprelay.py index b9606cd81..0ab00b834 100644 --- a/shadowsocks/udprelay.py +++ b/shadowsocks/udprelay.py @@ -162,7 +162,7 @@ def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_coun self.server_user_transfer_ul = {} self.server_user_transfer_dl = {} - if common.to_bytes(config['protocol']) in obfs.mu_protocol(): + if common.to_str(config['protocol']) in obfs.mu_protocol(): self._update_users(None, None) self.protocol_data = obfs.obfs(config['protocol']).init_data() @@ -213,6 +213,8 @@ def __init__(self, config, dns_resolver, is_local, stat_callback=None, stat_coun server_socket = socket.socket(af, socktype, proto) server_socket.bind((self._listen_addr, self._listen_port)) server_socket.setblocking(False) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024 * 1024) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 1024 * 1024) self._server_socket = server_socket self._stat_callback = stat_callback diff --git a/shadowsocks/version.py b/shadowsocks/version.py index f3e1ef796..7454f3fd5 100644 --- a/shadowsocks/version.py +++ b/shadowsocks/version.py @@ -16,5 +16,5 @@ # under the License. def version(): - return '3.4.0 2017-07-27' + return 'SSRR 3.2.2 2018-05-22' diff --git a/switchrule.py b/switchrule.py index 6687e12cf..56ed995d2 100644 --- a/switchrule.py +++ b/switchrule.py @@ -1,3 +1,6 @@ +def getRowMap(): + return {} # if your db row "encrypt" means "method", write {"encrypt": "method"} + def getKeys(key_list): return key_list #return key_list + ['plan'] # append the column name 'plan'