import itertools
from typing import Any, Callable, Optional, TYPE_CHECKING
import redis.asyncio as redis
from redis.exceptions import ResponseError as IncrbyException
from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer
if TYPE_CHECKING: # pragma: no cover
from aiocache.serializers import BaseSerializer
class RedisBackend(BaseCache[str]):
    """
    Low-level cache operations implemented on top of a ``redis.asyncio`` client.

    Every ``_xxx`` coroutine maps one cache operation onto the corresponding
    Redis command(s) issued through ``self.client``. The ``_conn`` parameter
    found on most methods is accepted for interface compatibility and is not
    used by this backend.

    TTL values are interpreted as seconds; ``float`` TTLs are converted to
    integer milliseconds and sent with the millisecond variant of the command.
    """

    # Lua script: delete KEYS[1] only while it still holds ARGV[1].
    RELEASE_SCRIPT = (
        "if redis.call('get',KEYS[1]) == ARGV[1] then"
        " return redis.call('del',KEYS[1])"
        " else"
        " return 0"
        " end"
    )

    # Lua script (compare-and-swap): set KEYS[1] to ARGV[1] only while it still
    # holds ARGV[2]. ARGV[3]/ARGV[4], when present, carry the expiry option
    # ("EX"/"PX") and its value.
    CAS_SCRIPT = (
        "if redis.call('get',KEYS[1]) == ARGV[2] then"
        " if #ARGV == 4 then"
        " return redis.call('set', KEYS[1], ARGV[1], ARGV[3], ARGV[4])"
        " else"
        " return redis.call('set', KEYS[1], ARGV[1])"
        " end"
        " else"
        " return 0"
        " end"
    )

    def __init__(
        self,
        client: redis.Redis,
        **kwargs,
    ):
        """
        :param client: redis.Redis client used for every operation. It must not
            have ``decode_responses`` enabled.
        :param kwargs: forwarded unchanged to the base class constructor.
        :raises ValueError: if the client decodes responses.
        """
        super().__init__(**kwargs)

        # NOTE: decoding can't be controlled on API level after switching to
        # redis, we need to disable decoding on global/connection level
        # (decode_responses=False), because some of the values are saved as
        # bytes directly, like pickle serialized values, which may raise an
        # exception when decoded with 'utf-8'.
        #
        # The option may be missing from connection_kwargs altogether (e.g. a
        # pool built from a URL without it); a missing option means responses
        # are not decoded, so it is treated as False instead of raising KeyError.
        if client.connection_pool.connection_kwargs.get("decode_responses", False):
            raise ValueError("redis client must be constructed with decode_responses set to False")

        self.client = client

    async def _get(self, key, encoding="utf-8", _conn=None):
        """Return the value stored at ``key``, decoded unless ``encoding`` is None.

        A missing key yields ``None``.
        """
        value = await self.client.get(key)
        if encoding is None or value is None:
            return value
        return value.decode(encoding)

    async def _gets(self, key, encoding="utf-8", _conn=None):
        """Return the current value of ``key``; it doubles as the CAS token."""
        return await self._get(key, encoding=encoding, _conn=_conn)

    async def _multi_get(self, keys, encoding="utf-8", _conn=None):
        """Return the values for ``keys`` in order, ``None`` for missing ones.

        An empty ``keys`` collection returns ``[]`` without contacting Redis.
        """
        if not keys:
            # MGET requires at least one key.
            return []
        values = await self.client.mget(*keys)
        if encoding is None:
            return values
        return [v if v is None else v.decode(encoding) for v in values]

    async def _set(self, key, value, ttl=None, _cas_token=None, _conn=None):
        """Store ``value`` at ``key``.

        :param ttl: expiry in seconds; ``None`` stores without expiry, a float
            is sent as milliseconds.
        :param _cas_token: when given, the write is delegated to ``_cas`` and
            only happens if the stored value still equals the token.
        """
        if _cas_token is not None:
            return await self._cas(key, value, _cas_token, ttl=ttl, _conn=_conn)
        if ttl is None:
            return await self.client.set(key, value)
        if isinstance(ttl, float):
            return await self.client.psetex(key, int(ttl * 1000), value)
        return await self.client.setex(key, ttl, value)

    async def _cas(self, key, value, token, ttl=None, _conn=None):
        """Run ``CAS_SCRIPT`` to set ``key`` only if it currently holds ``token``."""
        args = ()
        if ttl is not None:
            args = ("PX", int(ttl * 1000)) if isinstance(ttl, float) else ("EX", ttl)
        return await self._raw("eval", self.CAS_SCRIPT, 1, key, value, token, *args, _conn=_conn)

    async def _multi_set(self, pairs, ttl=None, _conn=None):
        """Store every ``(key, value)`` pair, optionally with a common ``ttl``.

        Always returns ``True``. An empty ``pairs`` iterable is a no-op.
        """
        flattened = list(itertools.chain.from_iterable((key, value) for key, value in pairs))
        if not flattened:
            # MSET requires at least one key/value pair.
            return True

        if ttl:
            await self.__multi_set_ttl(flattened, ttl)
        else:
            await self.client.execute_command("MSET", *flattened)

        return True

    async def __multi_set_ttl(self, flattened, ttl):
        """Queue MSET plus one expire per key in a transactional pipeline."""
        async with self.client.pipeline(transaction=True) as p:
            p.execute_command("MSET", *flattened)
            ttl, exp = (int(ttl * 1000), p.pexpire) if isinstance(ttl, float) else (ttl, p.expire)
            # ``flattened`` is [key, value, key, value, ...]: keys sit at even indexes.
            for key in flattened[::2]:
                exp(key, time=ttl)
            await p.execute()

    async def _add(self, key, value, ttl=None, _conn=None):
        """Store ``value`` at ``key`` only if the key does not exist yet.

        :raises ValueError: if the key already exists.
        """
        kwargs = {"nx": True}
        if isinstance(ttl, float):
            kwargs["px"] = int(ttl * 1000)
        else:
            kwargs["ex"] = ttl
        was_set = await self.client.set(key, value, **kwargs)
        if not was_set:
            raise ValueError("Key {} already exists, use .set to update the value".format(key))
        return was_set

    async def _exists(self, key, _conn=None):
        """Return ``True`` if ``key`` exists."""
        number = await self.client.exists(key)
        return bool(number)

    async def _increment(self, key, delta, _conn=None):
        """Increment the integer stored at ``key`` by ``delta``.

        :raises TypeError: if Redis rejects the increment with a response error.
        """
        try:
            return await self.client.incrby(key, delta)
        except IncrbyException:
            raise TypeError("Value is not an integer") from None

    async def _expire(self, key, ttl, _conn=None):
        """Set the expiry of ``key`` to ``ttl`` seconds; ``0`` removes the expiry.

        A float ``ttl`` is sent as integer milliseconds, matching the other
        write operations of this backend.
        """
        if ttl == 0:
            return await self.client.persist(key)
        if isinstance(ttl, float):
            # EXPIRE only takes whole seconds; PEXPIRE keeps sub-second precision.
            return await self.client.pexpire(key, int(ttl * 1000))
        return await self.client.expire(key, ttl)

    async def _delete(self, key, _conn=None):
        """Delete ``key`` and return the client's reply."""
        return await self.client.delete(key)

    async def _clear(self, namespace=None, _conn=None):
        """Delete every key under ``namespace``, or flush the database if none is given.

        Always returns ``True``.
        """
        if namespace:
            # NOTE(review): KEYS scans the whole keyspace, and the namespace is
            # used verbatim as a glob pattern prefix.
            keys = await self.client.keys("{}:*".format(namespace))
            if keys:
                await self.client.delete(*keys)
        else:
            await self.client.flushdb()
        return True

    async def _raw(self, command, *args, encoding="utf-8", _conn=None, **kwargs):
        """Call the client method named ``command`` with the given arguments.

        Only the replies of ``get``, ``keys`` and ``mget`` are decoded (when
        ``encoding`` is not None); any other reply is returned untouched.
        """
        value = await getattr(self.client, command)(*args, **kwargs)
        if encoding is not None:
            if command == "get" and value is not None:
                value = value.decode(encoding)
            elif command in {"keys", "mget"}:
                value = [v if v is None else v.decode(encoding) for v in value]
        return value

    async def _redlock_release(self, key, value):
        """Run ``RELEASE_SCRIPT`` to delete ``key`` if it still holds ``value``."""
        return await self._raw("eval", self.RELEASE_SCRIPT, 1, key, value)

    def build_key(self, key: str, namespace: Optional[str] = None) -> str:
        """Build the final key by delegating to ``self._str_build_key``."""
        return self._str_build_key(key, namespace)
class RedisCache(RedisBackend):
    """
    Redis cache implementation with the following components as defaults:
        - serializer: :class:`aiocache.serializers.JsonSerializer`
        - plugins: []

    Config options are:

    :param serializer: obj derived from :class:`aiocache.serializers.BaseSerializer`.
    :param plugins: list of :class:`aiocache.plugins.BasePlugin` derived classes.
    :param namespace: string to use as default prefix for the key used in all operations of
        the backend. Default is an empty string, "".
    :param timeout: int or float in seconds specifying maximum timeout for the operations to last.
        By default it's 5.
    :param client: redis.Redis which is an active client for working with redis
    """

    NAME = "redis"

    def __init__(
        self,
        client: redis.Redis,
        serializer: Optional["BaseSerializer"] = None,
        namespace: str = "",
        key_builder: Callable[[str, str], str] = lambda k, ns: f"{ns}:{k}" if ns else k,
        **kwargs: Any,
    ):
        """
        Forward all options to the backend, using a ``JsonSerializer`` when no
        serializer is supplied.

        The default ``key_builder`` yields ``"<namespace>:<key>"`` when a
        namespace is set and the bare key otherwise.
        """
        super().__init__(
            client=client,
            serializer=serializer or JsonSerializer(),
            namespace=namespace,
            key_builder=key_builder,
            **kwargs,
        )

    @classmethod
    def parse_uri_path(cls, path):
        """
        Given a uri path, return the Redis specific configuration
        options in that path string according to iana definition
        http://www.iana.org/assignments/uri-schemes/prov/redis

        :param path: string containing the path. Example: "/0"
        :return: mapping containing the options. Example: {"db": "0"}
        """
        options = {}
        # Drop the leading "/" and keep only the first path segment as the db.
        db, *_ = path[1:].split("/")
        if db:
            options["db"] = db
        return options

    def __repr__(self):  # pragma: no cover
        """Return a short description of the Redis endpoint this cache talks to.

        ``connection_kwargs`` only contains what the pool was built with, so
        "host"/"port" may be absent (Unix socket clients, URLs without a port);
        lookups are tolerant so that ``repr()`` never raises ``KeyError``.
        """
        connection_kwargs = self.client.connection_pool.connection_kwargs
        if "path" in connection_kwargs:
            # Unix domain socket connections carry a "path" instead of host/port.
            return "RedisCache ({})".format(connection_kwargs["path"])
        # Fall back to redis-py's connection defaults when the option was not given.
        host = connection_kwargs.get("host", "localhost")
        port = connection_kwargs.get("port", 6379)
        return "RedisCache ({}:{})".format(host, port)