Browse Source

Add cache invalidation across workers to module API (#13667)

Signed-off-by: Mathieu Velten <mathieuv@matrix.org>
tags/v1.69.0rc1
Mathieu Velten 1 year ago
committed by GitHub
parent
commit
6bd8763804
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 153 additions and 21 deletions
  1. +1
    -0
      changelog.d/13667.feature
  2. +2
    -2
      scripts-dev/mypy_synapse_plugin.py
  3. +32
    -1
      synapse/module_api/__init__.py
  4. +18
    -5
      synapse/storage/_base.py
  5. +14
    -6
      synapse/storage/databases/main/cache.py
  6. +7
    -7
      synapse/util/caches/descriptors.py
  7. +79
    -0
      tests/replication/test_module_cache_invalidation.py

+ 1
- 0
changelog.d/13667.feature View File

@@ -0,0 +1 @@
Add cache invalidation across workers to module API.

+ 2
- 2
scripts-dev/mypy_synapse_plugin.py View File

@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
"synapse.util.caches.descriptors._CachedFunction.__call__"
"synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):


def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
"""Fixes the `_CachedFunction.__call__` signature to be correct.
"""Fixes the `CachedFunction.__call__` signature to be correct.

It already has *almost* the correct signature, except:



+ 32
- 1
synapse/module_api/__init__.py View File

@@ -125,7 +125,7 @@ from synapse.types import (
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze

if TYPE_CHECKING:
@@ -836,6 +836,37 @@ class ModuleApi:
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)

def register_cached_function(self, cached_func: CachedFunction) -> None:
    """Register a cached function that should be invalidated across workers.

    Invalidation local to a worker can be done directly using
    `cached_func.invalidate`, however invalidation that needs to go to other
    workers needs to call `invalidate_cache` on the module API instead.

    The function is registered with the store under the name
    `<cached_func.__module__>.<cached_func.__name__>`.

    Args:
        cached_func: The cached function that will be registered to receive
            invalidation locally and from other workers.
    """
    # Qualify the function name with its module to build the registration key.
    cache_name = f"{cached_func.__module__}.{cached_func.__name__}"
    self._store.register_external_cached_function(cache_name, cached_func)

async def invalidate_cache(
    self, cached_func: CachedFunction, keys: Tuple[Any, ...]
) -> None:
    """Invalidate a cache entry of a cached function across workers.

    The entry is first invalidated locally via `cached_func.invalidate`, then
    the invalidation is sent to replication under the name
    `<cached_func.__module__>.<cached_func.__name__>`.

    The cached function needs to be registered on all workers first with
    `register_cached_function`.

    Args:
        cached_func: The cached function that needs an invalidation.
        keys: Keys of the entry to invalidate, usually matching the arguments
            of the cached function.
    """
    # Drop the entry on this process before telling the other workers.
    cached_func.invalidate(keys)

    cache_name = f"{cached_func.__module__}.{cached_func.__name__}"
    await self._store.send_invalidation_to_replication(cache_name, keys)

async def complete_sso_login_async(
self,
registered_user_id: str,


+ 18
- 5
synapse/storage/_base.py View File

@@ -15,12 +15,13 @@
# limitations under the License.
import logging
from abc import ABCMeta
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union

from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction

if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine
self.db_pool = database

self.external_cached_functions: Dict[str, CachedFunction] = {}

def process_replication_rows(
self,
stream_name: str,
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):

def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
) -> None:
) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
try:
cache = getattr(self, cache_name)
except AttributeError:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return
# Check if an externally defined module cache has been registered
cache = self.external_cached_functions.get(cache_name)
if not cache:
# We probably haven't pulled in the cache in this worker,
# which is fine.
return False

if key is None:
cache.invalidate_all()
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))

return True

def register_external_cached_function(
    self, cache_name: str, func: CachedFunction
) -> None:
    """Record `func` in `external_cached_functions` under `cache_name`.

    An existing entry with the same name is replaced.

    Args:
        cache_name: The name to store the cached function under.
        func: The cached function to store.
    """
    self.external_cached_functions.update({cache_name: func})


def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""


+ 14
- 6
synapse/storage/databases/main/cache.py View File

@@ -33,7 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.util.caches.descriptors import _CachedFunction
from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return

cache_func.invalidate(keys)
await self.db_pool.runInteraction(
"invalidate_cache_and_stream",
self._send_invalidation_to_replication,
await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
cache_func: _CachedFunction,
cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)

def _invalidate_all_cache_and_stream(
self, txn: LoggingTransaction, cache_func: _CachedFunction
self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)

async def send_invalidation_to_replication(
    self, cache_name: str, keys: Optional[Collection[Any]]
) -> None:
    """Run `_send_invalidation_to_replication` inside a database interaction.

    Args:
        cache_name: Name of the cache the invalidation applies to.
        keys: Keys of the entry to invalidate.
            NOTE(review): presumably `None` means "invalidate the whole
            cache" - confirm against `_send_invalidation_to_replication`.
    """
    interaction_desc = "send_invalidation_to_replication"
    await self.db_pool.runInteraction(
        interaction_desc,
        self._send_invalidation_to_replication,
        cache_name,
        keys,
    )

def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:


+ 7
- 7
synapse/util/caches/descriptors.py View File

@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])


class _CachedFunction(Generic[F]):
class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):

return ret2

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.name] = wrapped

@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):

return make_deferred_yieldable(ret)

wrapped = cast(_CachedFunction, _wrapped)
wrapped = cast(CachedFunction, _wrapped)

if self.num_args == 1:
assert not self.tree
@@ -572,7 +572,7 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -585,7 +585,7 @@ def cached(
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def cachedList(
@@ -594,7 +594,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
) -> Callable[[F], _CachedFunction[F]]:
) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.

Used to do batch lookups for an already created cache. One of the arguments
@@ -631,7 +631,7 @@ def cachedList(
name=name,
)

return cast(Callable[[F], _CachedFunction[F]], func)
return cast(Callable[[F], CachedFunction[F]], func)


def _get_cache_key_builder(


+ 79
- 0
tests/replication/test_module_cache_invalidation.py View File

@@ -0,0 +1,79 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import synapse
from synapse.module_api import cached

from tests.replication._base import BaseMultiWorkerStreamTestCase

logger = logging.getLogger(__name__)

FIRST_VALUE = "one"
SECOND_VALUE = "two"

KEY = "mykey"


class TestCache:
    """Minimal holder of a `@cached` method, used by the test case below.

    The cached method returns whatever `current_value` holds at call time, so
    changing `current_value` makes a stale cache entry observable.
    """

    # Value returned by `cached_function` whenever it actually runs.
    current_value = FIRST_VALUE

    @cached()
    async def cached_function(self, user_id: str) -> str:
        """Return `current_value`; `user_id` is not used by the body."""
        return self.current_value


class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
    """Checks that a module API cache invalidation reaches another worker."""

    servlets = [
        synapse.rest.admin.register_servlets,
    ]

    def test_module_cache_full_invalidation(self) -> None:
        """Invalidate a cache entry on the main process and check that both
        the main process and a worker stop returning the stale value.
        """
        main_cache = TestCache()
        self.hs.get_module_api().register_cached_function(main_cache.cached_function)

        worker_hs = self.make_worker_hs("synapse.app.generic_worker")

        worker_cache = TestCache()
        worker_hs.get_module_api().register_cached_function(
            worker_cache.cached_function
        )

        # Always exercised in this order: main process first, then the worker.
        caches = (main_cache, worker_cache)

        # Populate both caches with the first value.
        for cache in caches:
            self.assertEqual(
                FIRST_VALUE, self.get_success(cache.cached_function(KEY))
            )

        main_cache.current_value = SECOND_VALUE
        worker_cache.current_value = SECOND_VALUE

        # No invalidation yet, so both the main process and the worker should
        # still return the cached value.
        for cache in caches:
            self.assertEqual(
                FIRST_VALUE, self.get_success(cache.cached_function(KEY))
            )

        # Invalidate the entry on the main process. This should be replicated
        # to the worker, which should return the updated value too.
        self.get_success(
            self.hs.get_module_api().invalidate_cache(
                main_cache.cached_function, (KEY,)
            )
        )

        for cache in caches:
            self.assertEqual(
                SECOND_VALUE, self.get_success(cache.cached_function(KEY))
            )

Loading…
Cancel
Save