diff --git a/channels_redis/core.py b/channels_redis/core.py index 7c04ecd..c691628 100644 --- a/channels_redis/core.py +++ b/channels_redis/core.py @@ -15,7 +15,7 @@ from channels.exceptions import ChannelFull from channels.layers import BaseChannelLayer -from .utils import _consistent_hash +from .utils import create_pool, decode_hosts, _consistent_hash logger = logging.getLogger(__name__) @@ -98,7 +98,7 @@ def __init__( self.prefix = prefix assert isinstance(self.prefix, str), "Prefix must be unicode" # Configure the host objects - self.hosts = self.decode_hosts(hosts) + self.hosts = decode_hosts(hosts) self.ring_size = len(self.hosts) # Cached redis connection pools and the event loop they are from self.pools = {} @@ -127,46 +127,7 @@ def __init__( self.receive_clean_locks = ChannelLock() def create_pool(self, index): - host = self.hosts[index] - - if "address" in host: - return aioredis.ConnectionPool.from_url(host["address"]) - elif "master_name" in host: - sentinels = host.pop("sentinels") - master_name = host.pop("master_name") - sentinel_kwargs = host.pop("sentinel_kwargs", None) - return aioredis.sentinel.SentinelConnectionPool( - master_name, - aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), - **host - ) - else: - return aioredis.ConnectionPool(**host) - - def decode_hosts(self, hosts): - """ - Takes the value of the "hosts" argument passed to the class and returns - a list of kwargs to use for the Redis connection constructor. - """ - # If no hosts were provided, return a default value - if not hosts: - return [{"address": "redis://localhost:6379"}] - # If they provided just a string, scold them. - if isinstance(hosts, (str, bytes)): - raise ValueError( - "You must pass a list of Redis hosts, even if there is only one." - ) - - # Decode each hosts entry into a kwargs dict - result = [] - for entry in hosts: - if isinstance(entry, dict): - result.append(entry) - elif isinstance(entry, tuple): - result.append({"host": entry[0], "port": entry[1]}) - else: - result.append({"address": entry}) - return result + return create_pool(self.hosts[index]) def _setup_encryption(self, symmetric_encryption_keys): # See if we can do encryption if they asked diff --git a/channels_redis/pubsub.py b/channels_redis/pubsub.py index ccaef0f..7318ff2 100644 --- a/channels_redis/pubsub.py +++ b/channels_redis/pubsub.py @@ -7,7 +7,7 @@ import msgpack from redis import asyncio as aioredis -from .utils import _consistent_hash +from .utils import create_pool, decode_hosts, _consistent_hash logger = logging.getLogger(__name__) @@ -97,12 +97,6 @@ def __init__( channel_layer=None, **kwargs, ): - if hosts is None: - hosts = ["redis://localhost:6379"] - assert ( - isinstance(hosts, list) and len(hosts) > 0 - ), "`hosts` must be a list with at least one Redis server" - self.prefix = prefix self.on_disconnect = on_disconnect @@ -118,7 +112,7 @@ def __init__( self.groups = {} # For each host, we create a `RedisSingleShardConnection` to manage the connection to that host. - self._shards = [RedisSingleShardConnection(host, self) for host in hosts] + self._shards = [RedisSingleShardConnection(host, self) for host in decode_hosts(hosts)] def _get_shard(self, channel_or_group_name): """ @@ -263,9 +257,7 @@ async def flush(self): class RedisSingleShardConnection: def __init__(self, host, channel_layer): - self.host = host.copy() if type(host) is dict else {"address": host} - self.master_name = self.host.pop("master_name", None) - self.sentinel_kwargs = self.host.pop("sentinel_kwargs", None) + self.host = host self.channel_layer = channel_layer self._subscribed_to = set() self._lock = asyncio.Lock() @@ -347,18 +339,7 @@ def _receive_message(self, message): def _ensure_redis(self): if self._redis is None: - if self.master_name is None: - pool = aioredis.ConnectionPool.from_url(self.host["address"]) - else: - # aioredis default timeout is way too low - pool = aioredis.sentinel.SentinelConnectionPool( - self.master_name, - aioredis.sentinel.Sentinel( - self.host["sentinels"], - socket_timeout=2, - sentinel_kwargs=self.sentinel_kwargs, - ), - ) + pool = create_pool(self.host) self._redis = aioredis.Redis(connection_pool=pool) self._pubsub = self._redis.pubsub() diff --git a/channels_redis/utils.py b/channels_redis/utils.py index 7b30fdc..b0dd0cd 100644 --- a/channels_redis/utils.py +++ b/channels_redis/utils.py @@ -1,4 +1,5 @@ import binascii +from redis import asyncio as aioredis def _consistent_hash(value, ring_size): @@ -15,3 +16,53 @@ def _consistent_hash(value, ring_size): bigval = binascii.crc32(value) & 0xFFF ring_divisor = 4096 / float(ring_size) return int(bigval / ring_divisor) + + +def decode_hosts(hosts): + """ + Takes the value of the "hosts" argument and returns + a list of kwargs to use for the Redis connection constructor. + """ + # If no hosts were provided, return a default value + if not hosts: + return [{"address": "redis://localhost:6379"}] + # If they provided just a string, scold them. + if isinstance(hosts, (str, bytes)): + raise ValueError( + "You must pass a list of Redis hosts, even if there is only one." + ) + + # Decode each hosts entry into a kwargs dict + result = [] + for entry in hosts: + if isinstance(entry, dict): + result.append(entry) + elif isinstance(entry, (tuple, list)): + result.append({"host": entry[0], "port": entry[1]}) + else: + result.append({"address": entry}) + return result + + +def create_pool(host): + """ + Takes the value of the "host" argument and returns a suited connection pool to + the corresponding redis instance. + """ + # avoid side-effects from modifying host + host = host.copy() + if "address" in host: + address = host.pop('address') + return aioredis.ConnectionPool.from_url(address, **host) + + master_name = host.pop('master_name', None) + if master_name is not None: + sentinels = host.pop("sentinels") + sentinel_kwargs = host.pop("sentinel_kwargs", None) + return aioredis.sentinel.SentinelConnectionPool( + master_name, + aioredis.sentinel.Sentinel(sentinels, sentinel_kwargs=sentinel_kwargs), + **host + ) + + return aioredis.ConnectionPool(**host)