Skip to content

feat(typing)!: make cache generic over value (#608) #986

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiocache/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logger = logging.getLogger(__name__)

_AIOCACHE_CACHES: list[Type[BaseCache[Any]]] = [SimpleMemoryCache]
_AIOCACHE_CACHES: list[Type[BaseCache[Any, Any]]] = [SimpleMemoryCache]

try:
import redis
Expand Down
4 changes: 2 additions & 2 deletions aiocache/backends/memcached.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import asyncio
from typing import Optional
from typing import Any, Optional

import aiomcache

from aiocache.base import BaseCache
from aiocache.serializers import JsonSerializer


class MemcachedBackend(BaseCache[bytes]):
class MemcachedBackend(BaseCache[bytes, Any]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all these need to be defined as a TypeVar. Then the user would need to annotate their code as: MemcachedCache[Any] and similar.

Though, probably better to get #684 done before this one.

def __init__(self, host="127.0.0.1", port=11211, pool_size=2, **kwargs):
super().__init__(**kwargs)
self.host = host
Expand Down
11 changes: 8 additions & 3 deletions aiocache/backends/memory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import asyncio
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TypeVar

from aiocache.base import BaseCache
from aiocache.serializers import NullSerializer

CacheKeyType = TypeVar('CacheKeyType')
CacheValueType = TypeVar('CacheValueType')

class SimpleMemoryBackend(BaseCache[str]):

class SimpleMemoryBackend(
BaseCache[CacheKeyType, CacheValueType],
):
"""
Wrapper around dict operations to use it as a cache backend
"""
Expand Down Expand Up @@ -110,7 +115,7 @@ def build_key(self, key: str, namespace: Optional[str] = None) -> str:
return self._str_build_key(key, namespace)


class SimpleMemoryCache(SimpleMemoryBackend):
class SimpleMemoryCache(SimpleMemoryBackend[str, Any]):
"""
Memory cache implementation with the following components as defaults:
- serializer: :class:`aiocache.serializers.NullSerializer`
Expand Down
2 changes: 1 addition & 1 deletion aiocache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from aiocache.serializers import BaseSerializer


class RedisBackend(BaseCache[str]):
class RedisBackend(BaseCache[str, Any]):
RELEASE_SCRIPT = (
"if redis.call('get',KEYS[1]) == ARGV[1] then"
" return redis.call('del',KEYS[1])"
Expand Down
93 changes: 67 additions & 26 deletions aiocache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@
from abc import ABC, abstractmethod
from enum import Enum
from types import TracebackType
from typing import Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar
from typing import Any, Callable, Generic, List, Optional, Set, TYPE_CHECKING, Type, TypeVar

from aiocache.serializers import StringSerializer

if TYPE_CHECKING: # pragma: no cover
from aiocache.plugins import BasePlugin
from aiocache.serializers import BaseSerializer

CacheKeyType = TypeVar("CacheKeyType")
CacheValueType = TypeVar("CacheValueType")


logger = logging.getLogger(__name__)

SENTINEL = object()
CacheKeyType = TypeVar("CacheKeyType")


class API:
Expand Down Expand Up @@ -93,7 +95,7 @@ async def _plugins(self, *args, **kwargs):
return _plugins


class BaseCache(Generic[CacheKeyType], ABC):
class BaseCache(Generic[CacheKeyType, CacheValueType], ABC):
"""
Base class that aggregates the common logic for the different caches that may exist. Cache
related available options are:
Expand All @@ -110,6 +112,8 @@ class BaseCache(Generic[CacheKeyType], ABC):
By default its 5. Use 0 or None if you want to disable it.
:param ttl: int the expiration time in seconds to use as a default in all operations of
the backend. It can be overriden in the specific calls.
:typeparam CacheKeyType: The type of the cache key (e.g., str, bytes).
:typeparam CacheValueType: The type of the cache value (e.g., str, int, custom object).
"""

NAME: str
Expand Down Expand Up @@ -152,16 +156,25 @@ def plugins(self, value):
@API.aiocache_enabled(fake_return=True)
@API.timeout
@API.plugins
async def add(self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None):
async def add(
self,
key: CacheKeyType,
value: CacheValueType,
ttl=SENTINEL,
dumps_fn: Optional[Callable[[CacheValueType], Any]] = None,
namespace: Optional[str] = None,
_conn=None,
) -> bool:
"""
Stores the value in the given key with ttl if specified. Raises an error if the
key already exists.

:param key: str
:param value: obj
:param ttl: int the expiration time in seconds. Due to memcached
restrictions if you want compatibility use int. In case you
need miliseconds, redis and memory support float ttls
:param key: CacheKeyType
:param value: CacheValueType
:param ttl: int the expiration time in seconds. Due to memcached restrictions,
if you want compatibility use int.
In case you need milliseconds,
redis and memory support float ttls
:param dumps_fn: callable alternative to use as dumps function
:param namespace: str alternative namespace to use
:param timeout: int or float in seconds specifying maximum timeout
Expand All @@ -188,17 +201,24 @@ async def _add(self, key, value, ttl, _conn=None):
@API.aiocache_enabled()
@API.timeout
@API.plugins
async def get(self, key, default=None, loads_fn=None, namespace=None, _conn=None):
async def get(
self,
key: CacheKeyType,
default: Optional[CacheValueType] = None,
loads_fn: Optional[Callable[[Any], CacheValueType]] = None,
namespace: Optional[str] = None,
_conn=None,
) -> Optional[CacheValueType]:
"""
Get a value from the cache. Returns default if not found.

:param key: str
:param default: obj to return when key is not found
:param key: CacheKeyType
:param default: CacheValueType to return when key is not found
:param loads_fn: callable alternative to use as loads function
:param namespace: str alternative namespace to use
:param timeout: int or float in seconds specifying maximum timeout
for the operations to last
:returns: obj loaded
:returns: CacheValueType loaded
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
Expand All @@ -222,16 +242,22 @@ async def _gets(self, key, encoding="utf-8", _conn=None):
@API.aiocache_enabled(fake_return=[])
@API.timeout
@API.plugins
async def multi_get(self, keys, loads_fn=None, namespace=None, _conn=None):
async def multi_get(
self,
keys: List[CacheKeyType],
loads_fn: Optional[Callable[[Any], CacheValueType]] = None,
namespace: Optional[str] = None,
_conn=None,
) -> List[Optional[CacheValueType]]:
"""
Get multiple values from the cache, values not found are Nones.

:param keys: list of str
:param keys: list of CacheKeyType
:param loads_fn: callable alternative to use as loads function
:param namespace: str alternative namespace to use
:param timeout: int or float in seconds specifying maximum timeout
for the operations to last
:returns: list of objs
:returns: list of CacheValueType
:raises: :class:`asyncio.TimeoutError` if it lasts more than self.timeout
"""
start = time.monotonic()
Expand Down Expand Up @@ -262,13 +288,20 @@ async def _multi_get(self, keys, encoding, _conn=None):
@API.timeout
@API.plugins
async def set(
self, key, value, ttl=SENTINEL, dumps_fn=None, namespace=None, _cas_token=None, _conn=None
):
self,
key: CacheKeyType,
value: CacheValueType,
ttl=SENTINEL,
dumps_fn: Optional[Callable[[CacheValueType], Any]] = None,
namespace: Optional[str] = None,
_cas_token=None,
_conn=None,
) -> bool:
"""
Stores the value in the given key with ttl if specified

:param key: str
:param value: obj
:param key: CacheKeyType
:param value: CacheValueType
:param ttl: int the expiration time in seconds. Due to memcached
restrictions if you want compatibility use int. In case you
need milliseconds, redis and memory support float ttls
Expand Down Expand Up @@ -298,14 +331,22 @@ async def _set(self, key, value, ttl, _cas_token=None, _conn=None):
@API.aiocache_enabled(fake_return=True)
@API.timeout
@API.plugins
async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _conn=None):
async def multi_set(
self,
pairs: List[tuple[CacheKeyType, CacheValueType]],
ttl=SENTINEL,
dumps_fn: Optional[Callable[[CacheValueType], Any]] = None,
namespace: Optional[str] = None,
_conn=None,
) -> bool:
"""
Stores multiple values in the given keys.

:param pairs: list of two element iterables. First is key and second is value
:param ttl: int the expiration time in seconds. Due to memcached
restrictions if you want compatibility use int. In case you
need miliseconds, redis and memory support float ttls
:param pairs: list of two element iterables. First is CacheKeyType
and second is CacheValueType
:param ttl: int the expiration time in seconds. Due to memcached restrictions.
If you want compatibility use int. In case you need milliseconds,
redis and memory support float ttls
:param dumps_fn: callable alternative to use as dumps function
:param namespace: str alternative namespace to use
:param timeout: int or float in seconds specifying maximum timeout
Expand All @@ -326,7 +367,7 @@ async def multi_set(self, pairs, ttl=SENTINEL, dumps_fn=None, namespace=None, _c
"MULTI_SET %s %d (%.4f)s",
[key for key, value in tmp_pairs],
len(tmp_pairs),
time.monotonic() - start,
time.monotonic() - start
)
return True

Expand Down
4 changes: 2 additions & 2 deletions aiocache/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class RedLock(Generic[CacheKeyType]):

_EVENTS: Dict[str, asyncio.Event] = {}

def __init__(self, client: BaseCache[CacheKeyType], key: str, lease: Union[int, float]):
def __init__(self, client: BaseCache[CacheKeyType, Any], key: str, lease: Union[int, float]):
self.client = client
self.key = self.client.build_key(key + "-lock")
self.lease = lease
Expand Down Expand Up @@ -133,7 +133,7 @@ class OptimisticLock(Generic[CacheKeyType]):
If the lock is created with a nonexistent key, there will never be conflicts.
"""

def __init__(self, client: BaseCache[CacheKeyType], key: str):
def __init__(self, client: BaseCache[CacheKeyType, Any], key: str):
self.client = client
self.key = key
self.ns_key = self.client.build_key(key)
Expand Down
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Optional, Union
from typing import Any, Optional, Union

from aiocache.base import BaseCache

Expand All @@ -19,7 +19,7 @@ def ensure_key(key: Union[str, Enum]) -> str:
return key


class AbstractBaseCache(BaseCache[str]):
class AbstractBaseCache(BaseCache[str, Any]):
"""BaseCache that can be mocked for NotImplementedError tests"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down