Skip to content

Commit

Permalink
Redis additions (#105)
Browse files Browse the repository at this point in the history
Add methods to `CacheManager` for getting & setting multiple keys.

Also, add a way to use Redis pipeline to the client, enabling
transactional operations.
  • Loading branch information
lukasz-matter authored Jun 26, 2024
1 parent c1a4ed4 commit b4fd15c
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 5 deletions.
50 changes: 45 additions & 5 deletions matter_persistence/redis/async_redis_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import contextlib
from collections.abc import AsyncIterator, Mapping, Sequence
from datetime import timedelta

from redis import asyncio as aioredis

from matter_persistence.decorators import retry_if_failed
from matter_persistence.redis.exceptions import CacheConnectionNotEstablishedError


class AsyncRedisClient:
Expand Down Expand Up @@ -54,8 +57,8 @@ def __init__(
self, connection: aioredis.Redis | None = None, connection_pool: aioredis.ConnectionPool | None = None
):
if (connection and connection_pool is None) or (connection is None and connection_pool):
self.connection = connection
self._connection_pool = connection_pool
self.connection: aioredis.Redis | None = connection
self._connection_pool: aioredis.ConnectionPool | None = connection_pool
else:
raise ValueError(
"Invalid argument combination. Please provide either: "
Expand All @@ -78,8 +81,18 @@ async def connect(self):
async def close(self):
if self.connection:
await self.connection.aclose()
if self._connection_pool:
await self._connection_pool.aclose()

@contextlib.asynccontextmanager
async def pipeline(
self,
transaction: bool = True,
) -> AsyncIterator[aioredis.client.Pipeline]:
if not isinstance(self.connection, aioredis.Redis):
raise CacheConnectionNotEstablishedError(
"You cannot use the client if the connection isn't established. Use as async context manager."
)
async with self.connection.pipeline(transaction=transaction) as pipe:
yield pipe

@retry_if_failed
async def set_value(self, key: str, value: str, ttl=None):
Expand All @@ -90,9 +103,27 @@ async def set_value(self, key: str, value: str, ttl=None):
return result

@retry_if_failed
async def get_value(self, key: str):
async def set_many_values(self, values: Mapping[str, str], ttl: int | None = None) -> None:
async with self.pipeline() as pipe:
await pipe.mset(values)
if ttl is not None:
for key in values.keys():
await pipe.expire(key, ttl)
await pipe.execute()

@retry_if_failed
async def get_value(self, key: str) -> str:
return await self.connection.get(key) # type: ignore

@retry_if_failed
async def get_many_values(self, keys: Sequence[str]) -> dict[str, str]:
if not isinstance(self.connection, aioredis.Redis):
raise CacheConnectionNotEstablishedError(
"You cannot use the client if the connection isn't established. Use as async context manager."
)
response = await self.connection.mget(keys)
return dict(zip(keys, response, strict=True))

@retry_if_failed
async def set_hash_field(self, hash_key: str, field: str, value: str, ttl: int | timedelta | None = None):
result = await self.connection.hset(hash_key, field, value) # type: ignore
Expand Down Expand Up @@ -120,6 +151,15 @@ async def exists(self, key_or_hash: str, field: str | None = None):
else:
return await self.connection.hexists(key_or_hash, field) # type: ignore

@retry_if_failed
async def exists_many(self, keys: Sequence[str]) -> bool:
if not isinstance(self.connection, aioredis.Redis):
raise CacheConnectionNotEstablishedError(
"You cannot use the client if the connection isn't established. Use as async context manager."
)
number_of_existing_keys: int = await self.connection.exists(*keys)
return number_of_existing_keys == len(keys)

@retry_if_failed
async def is_alive(self):
return await self.connection.ping()
4 changes: 4 additions & 0 deletions matter_persistence/redis/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ class CacheRecordNotSavedError(DetailedException):

class CacheServerError(DetailedException):
TOPIC = "Cache Server Error"


class CacheConnectionNotEstablishedError(DetailedException):
TOPIC = "Cache Connection Not Established"
43 changes: 43 additions & 0 deletions matter_persistence/redis/manager.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any
from uuid import UUID

Expand Down Expand Up @@ -185,6 +186,26 @@ async def save_with_key(

return result

async def save_many_with_keys(
self,
values_to_store: dict[str, Any],
object_class: type[Model] | None = None,
expiration_in_seconds: int | None = None,
) -> None:
object_name = object_class.__name__ if object_class else None

async with self.__get_cache_client() as cache_client:
if object_class is not None:
processed_input = {
CacheHelper.create_basic_hash_key(key, object_name): value.model_dump_json()
for key, value in values_to_store.items()
}
else:
processed_input = {
CacheHelper.create_basic_hash_key(key, object_name): value for key, value in values_to_store.items()
}
await cache_client.set_many_values(processed_input, ttl=expiration_in_seconds)

async def get_with_key(self, key: str, object_class: type[Model] | None = None) -> Any:
object_name = object_class.__name__ if object_class else None
hash_key = CacheHelper.create_basic_hash_key(key, object_name)
Expand All @@ -204,6 +225,28 @@ async def get_with_key(self, key: str, object_class: type[Model] | None = None)

return value

async def get_many_with_keys(
self, keys: Sequence[str], object_class: type[Model] | None = None
) -> dict[str, str | list[str] | Model | list[Model]]:
object_name = object_class.__name__ if object_class else None
return_set: dict[str, str | list[str] | Model | list[Model]] = {}
async with self.__get_cache_client() as cache_client:
processed_input = {
CacheHelper.create_basic_hash_key(original_key, object_name): original_key for original_key in keys
}
response: dict[str, str | list[str]] = await cache_client.get_many_values(processed_input.keys())
if object_class:
for key, value in response.items():
if isinstance(value, list):
return_set[processed_input[key]] = [object_class.model_validate_json(item) for item in value]
elif value is not None:
return_set[processed_input[key]] = object_class.model_validate_json(value)
else:
return_set[processed_input[key]] = value
else:
return_set = {processed_input[key]: value for key, value in response.items()}
return return_set

async def delete_with_key(
self,
key: str,
Expand Down
26 changes: 26 additions & 0 deletions tests/redis/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,32 @@ async def test_cache_manager_save_with_key_and_get_with_key_success(cache_manage
assert await cache_manager.get_with_key("key", TestDTO)


async def test_cache_manager_save_and_get_many_objects_with_keys_success(cache_manager: CacheManager) -> None:
test_dtos = {f"key_{i}": TestDTO(test_field=i) for i in range(10)}
await cache_manager.save_many_with_keys(test_dtos, TestDTO, 100)
response = await cache_manager.get_many_with_keys(test_dtos, TestDTO)
assert response == test_dtos


async def test_cache_manager_save_and_get_many_raw_values_with_keys_success(cache_manager: CacheManager) -> None:
test_input = {f"key_{i}": f"test_value_{i}" for i in range(10)}
await cache_manager.save_many_with_keys(test_input, None, 100)
response = await cache_manager.get_many_with_keys(test_input, None)
for key, value in response.items():
assert test_input[key] == value.decode()


async def test_cache_manager_get_many_returns_none_for_missing_keys(
cache_manager: CacheManager, test_dto: TestDTO
) -> None:
await cache_manager.save_with_key("key", test_dto, TestDTO)
response = await cache_manager.get_many_with_keys(("key", "key2", "key3"), TestDTO)
assert len(response.keys()) == 3
assert response["key"] == test_dto
assert response["key2"] is None
assert response["key3"] is None


async def test_cache_manager_save_with_key_and_get_with_key_expired(cache_manager, test_dto):
await cache_manager.save_with_key("key", test_dto, TestDTO, 1)
await asyncio.sleep(1.05)
Expand Down

0 comments on commit b4fd15c

Please sign in to comment.