Signed-off-by: Mathieu Velten <mathieuv@matrix.org>tags/v1.69.0rc1
@@ -0,0 +1 @@ | |||
Add cache invalidation across workers to module API. |
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin): | |||
self, fullname: str | |||
) -> Optional[Callable[[MethodSigContext], CallableType]]: | |||
if fullname.startswith( | |||
"synapse.util.caches.descriptors._CachedFunction.__call__" | |||
"synapse.util.caches.descriptors.CachedFunction.__call__" | |||
) or fullname.startswith( | |||
"synapse.util.caches.descriptors._LruCachedFunction.__call__" | |||
): | |||
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin): | |||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: | |||
"""Fixes the `_CachedFunction.__call__` signature to be correct. | |||
"""Fixes the `CachedFunction.__call__` signature to be correct. | |||
It already has *almost* the correct signature, except: | |||
@@ -125,7 +125,7 @@ from synapse.types import ( | |||
) | |||
from synapse.util import Clock | |||
from synapse.util.async_helpers import maybe_awaitable | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.caches.descriptors import CachedFunction, cached | |||
from synapse.util.frozenutils import freeze | |||
if TYPE_CHECKING: | |||
@@ -836,6 +836,37 @@ class ModuleApi: | |||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] | |||
) | |||
def register_cached_function(self, cached_func: CachedFunction) -> None: | |||
"""Register a cached function that should be invalidated across workers. | |||
Invalidation local to a worker can be done directly using `cached_func.invalidate`, | |||
however invalidation that needs to go to other workers needs to call `invalidate_cache` | |||
on the module API instead. | |||
Args: | |||
cached_function: The cached function that will be registered to receive invalidation | |||
locally and from other workers. | |||
""" | |||
self._store.register_external_cached_function( | |||
f"{cached_func.__module__}.{cached_func.__name__}", cached_func | |||
) | |||
async def invalidate_cache( | |||
self, cached_func: CachedFunction, keys: Tuple[Any, ...] | |||
) -> None: | |||
"""Invalidate a cache entry of a cached function across workers. The cached function | |||
needs to be registered on all workers first with `register_cached_function`. | |||
Args: | |||
cached_function: The cached function that needs an invalidation | |||
keys: keys of the entry to invalidate, usually matching the arguments of the | |||
cached function. | |||
""" | |||
cached_func.invalidate(keys) | |||
await self._store.send_invalidation_to_replication( | |||
f"{cached_func.__module__}.{cached_func.__name__}", | |||
keys, | |||
) | |||
async def complete_sso_login_async( | |||
self, | |||
registered_user_id: str, | |||
@@ -15,12 +15,13 @@ | |||
# limitations under the License. | |||
import logging | |||
from abc import ABCMeta | |||
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union | |||
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union | |||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 | |||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection | |||
from synapse.types import get_domain_from_id | |||
from synapse.util import json_decoder | |||
from synapse.util.caches.descriptors import CachedFunction | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
self.database_engine = database.engine | |||
self.db_pool = database | |||
self.external_cached_functions: Dict[str, CachedFunction] = {} | |||
def process_replication_rows( | |||
self, | |||
stream_name: str, | |||
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
def _attempt_to_invalidate_cache( | |||
self, cache_name: str, key: Optional[Collection[Any]] | |||
) -> None: | |||
) -> bool: | |||
"""Attempts to invalidate the cache of the given name, ignoring if the | |||
cache doesn't exist. Mainly used for invalidating caches on workers, | |||
where they may not have the cache. | |||
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
try: | |||
cache = getattr(self, cache_name) | |||
except AttributeError: | |||
# We probably haven't pulled in the cache in this worker, | |||
# which is fine. | |||
return | |||
# Check if an externally defined module cache has been registered | |||
cache = self.external_cached_functions.get(cache_name) | |||
if not cache: | |||
# We probably haven't pulled in the cache in this worker, | |||
# which is fine. | |||
return False | |||
if key is None: | |||
cache.invalidate_all() | |||
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta): | |||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) | |||
invalidate_method(tuple(key)) | |||
return True | |||
def register_external_cached_function( | |||
self, cache_name: str, func: CachedFunction | |||
) -> None: | |||
self.external_cached_functions[cache_name] = func | |||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: | |||
""" | |||
@@ -33,7 +33,7 @@ from synapse.storage.database import ( | |||
) | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||
from synapse.util.caches.descriptors import _CachedFunction | |||
from synapse.util.caches.descriptors import CachedFunction | |||
from synapse.util.iterutils import batch_iter | |||
if TYPE_CHECKING: | |||
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
return | |||
cache_func.invalidate(keys) | |||
await self.db_pool.runInteraction( | |||
"invalidate_cache_and_stream", | |||
self._send_invalidation_to_replication, | |||
await self.send_invalidation_to_replication( | |||
cache_func.__name__, | |||
keys, | |||
) | |||
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
def _invalidate_cache_and_stream( | |||
self, | |||
txn: LoggingTransaction, | |||
cache_func: _CachedFunction, | |||
cache_func: CachedFunction, | |||
keys: Tuple[Any, ...], | |||
) -> None: | |||
"""Invalidates the cache and adds it to the cache stream so slaves | |||
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) | |||
def _invalidate_all_cache_and_stream( | |||
self, txn: LoggingTransaction, cache_func: _CachedFunction | |||
self, txn: LoggingTransaction, cache_func: CachedFunction | |||
) -> None: | |||
"""Invalidates the entire cache and adds it to the cache stream so slaves | |||
will know to invalidate their caches. | |||
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
txn, CURRENT_STATE_CACHE_NAME, [room_id] | |||
) | |||
async def send_invalidation_to_replication( | |||
self, cache_name: str, keys: Optional[Collection[Any]] | |||
) -> None: | |||
await self.db_pool.runInteraction( | |||
"send_invalidation_to_replication", | |||
self._send_invalidation_to_replication, | |||
cache_name, | |||
keys, | |||
) | |||
def _send_invalidation_to_replication( | |||
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] | |||
) -> None: | |||
@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any] | |||
F = TypeVar("F", bound=Callable[..., Any]) | |||
class _CachedFunction(Generic[F]): | |||
class CachedFunction(Generic[F]): | |||
invalidate: Any = None | |||
invalidate_all: Any = None | |||
prefill: Any = None | |||
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): | |||
return ret2 | |||
wrapped = cast(_CachedFunction, _wrapped) | |||
wrapped = cast(CachedFunction, _wrapped) | |||
wrapped.cache = cache | |||
obj.__dict__[self.name] = wrapped | |||
@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): | |||
return make_deferred_yieldable(ret) | |||
wrapped = cast(_CachedFunction, _wrapped) | |||
wrapped = cast(CachedFunction, _wrapped) | |||
if self.num_args == 1: | |||
assert not self.tree | |||
@@ -572,7 +572,7 @@ def cached( | |||
iterable: bool = False, | |||
prune_unread_entries: bool = True, | |||
name: Optional[str] = None, | |||
) -> Callable[[F], _CachedFunction[F]]: | |||
) -> Callable[[F], CachedFunction[F]]: | |||
func = lambda orig: DeferredCacheDescriptor( | |||
orig, | |||
max_entries=max_entries, | |||
@@ -585,7 +585,7 @@ def cached( | |||
name=name, | |||
) | |||
return cast(Callable[[F], _CachedFunction[F]], func) | |||
return cast(Callable[[F], CachedFunction[F]], func) | |||
def cachedList( | |||
@@ -594,7 +594,7 @@ def cachedList( | |||
list_name: str, | |||
num_args: Optional[int] = None, | |||
name: Optional[str] = None, | |||
) -> Callable[[F], _CachedFunction[F]]: | |||
) -> Callable[[F], CachedFunction[F]]: | |||
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. | |||
Used to do batch lookups for an already created cache. One of the arguments | |||
@@ -631,7 +631,7 @@ def cachedList( | |||
name=name, | |||
) | |||
return cast(Callable[[F], _CachedFunction[F]], func) | |||
return cast(Callable[[F], CachedFunction[F]], func) | |||
def _get_cache_key_builder( | |||
@@ -0,0 +1,79 @@ | |||
# Copyright 2022 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
import synapse | |||
from synapse.module_api import cached | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
logger = logging.getLogger(__name__) | |||
FIRST_VALUE = "one" | |||
SECOND_VALUE = "two" | |||
KEY = "mykey" | |||
class TestCache: | |||
current_value = FIRST_VALUE | |||
@cached() | |||
async def cached_function(self, user_id: str) -> str: | |||
return self.current_value | |||
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): | |||
servlets = [ | |||
synapse.rest.admin.register_servlets, | |||
] | |||
def test_module_cache_full_invalidation(self): | |||
main_cache = TestCache() | |||
self.hs.get_module_api().register_cached_function(main_cache.cached_function) | |||
worker_hs = self.make_worker_hs("synapse.app.generic_worker") | |||
worker_cache = TestCache() | |||
worker_hs.get_module_api().register_cached_function( | |||
worker_cache.cached_function | |||
) | |||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) | |||
self.assertEqual( | |||
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) | |||
) | |||
main_cache.current_value = SECOND_VALUE | |||
worker_cache.current_value = SECOND_VALUE | |||
# No invalidation yet, should return the cached value on both the main process and the worker | |||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) | |||
self.assertEqual( | |||
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) | |||
) | |||
# Full invalidation on the main process, should be replicated on the worker that | |||
# should returned the updated value too | |||
self.get_success( | |||
self.hs.get_module_api().invalidate_cache( | |||
main_cache.cached_function, (KEY,) | |||
) | |||
) | |||
self.assertEqual( | |||
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) | |||
) | |||
self.assertEqual( | |||
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) | |||
) |