Co-authored-by: Patrick Cloke <patrickc@matrix.org>
tags/v1.97.0rc1
@@ -0,0 +1 @@ | |||
Improve the performance of claiming encryption keys in multi-worker deployments. |
@@ -1116,7 +1116,7 @@ class DatabasePool: | |||
def simple_insert_many_txn( | |||
txn: LoggingTransaction, | |||
table: str, | |||
keys: Collection[str], | |||
keys: Sequence[str], | |||
values: Collection[Iterable[Any]], | |||
) -> None: | |||
"""Executes an INSERT query on the named table. | |||
@@ -483,6 +483,30 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
txn.call_after(cache_func.invalidate, keys) | |||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys) | |||
def _invalidate_cache_and_stream_bulk(
    self,
    txn: LoggingTransaction,
    cache_func: CachedFunction,
    key_tuples: Collection[Tuple[Any, ...]],
) -> None:
    """A bulk version of _invalidate_cache_and_stream.

    Locally invalidate every key-tuple in `key_tuples`, then emit invalidations
    for each key-tuple over replication.

    This implementation is more efficient than a loop which repeatedly calls the
    non-bulk version.
    """
    # Nothing to invalidate: return early rather than emitting a no-op
    # replication message.
    if not key_tuples:
        return

    # Queue the local invalidations to run once the transaction commits.
    for key_tuple in key_tuples:
        txn.call_after(cache_func.invalidate, key_tuple)

    # A single replication announcement covers every key-tuple.
    self._send_invalidation_to_replication_bulk(
        txn, cache_func.__name__, key_tuples
    )
def _invalidate_all_cache_and_stream( | |||
self, txn: LoggingTransaction, cache_func: CachedFunction | |||
) -> None: | |||
@@ -564,10 +588,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
if isinstance(self.database_engine, PostgresEngine): | |||
assert self._cache_id_gen is not None | |||
# get_next() returns a context manager which is designed to wrap | |||
# the transaction. However, we want to only get an ID when we want | |||
# to use it, here, so we need to call __enter__ manually, and have | |||
# __exit__ called after the transaction finishes. | |||
stream_id = self._cache_id_gen.get_next_txn(txn) | |||
txn.call_after(self.hs.get_notifier().on_new_replication_data) | |||
@@ -586,6 +606,53 @@ class CacheInvalidationWorkerStore(SQLBaseStore): | |||
}, | |||
) | |||
def _send_invalidation_to_replication_bulk(
    self,
    txn: LoggingTransaction,
    cache_name: str,
    key_tuples: Collection[Tuple[Any, ...]],
) -> None:
    """Announce the invalidation of multiple (but not all) cache entries.

    This is more efficient than repeated calls to the non-bulk version. It should
    NOT be used to invalidate the entire cache: use
    `_send_invalidation_to_replication` with keys=None.

    Note that this does *not* invalidate the cache locally.

    Args:
        txn: The database transaction.
        cache_name: Name of the cached function whose entries are being
            invalidated.
        key_tuples: Key-tuples to invalidate. Assumed to be non-empty.
    """
    if isinstance(self.database_engine, PostgresEngine):
        assert self._cache_id_gen is not None

        # Reserve one cache-stream ID per key-tuple in a single allocation.
        stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples))
        ts = self._clock.time_msec()
        txn.call_after(self.hs.get_notifier().on_new_replication_data)
        self.db_pool.simple_insert_many_txn(
            txn,
            table="cache_invalidation_stream_by_instance",
            keys=(
                "stream_id",
                "instance_name",
                "cache_func",
                "keys",
                "invalidation_ts",
            ),
            values=[
                # We convert each key_tuple to a list here because psycopg2
                # serialises lists as Postgres arrays, but serialises tuples as
                # "composite types". (We need an array because the `keys`
                # column has type `text[]`.)
                # See:
                # https://www.psycopg.org/docs/usage.html#adapt-list
                # https://www.psycopg.org/docs/usage.html#adapt-tuple
                (stream_id, self._instance_name, cache_name, list(key_tuple), ts)
                for stream_id, key_tuple in zip(stream_ids, key_tuples)
            ],
        )
def get_cache_stream_token_for_writer(self, instance_name: str) -> int: | |||
if self._cache_id_gen: | |||
return self._cache_id_gen.get_current_token_for_writer(instance_name) | |||
@@ -1237,13 +1237,11 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
for user_id, device_id, algorithm, key_id, key_json in claimed_keys: | |||
device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) | |||
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) | |||
if (user_id, device_id) in seen_user_device: | |||
continue | |||
seen_user_device.add((user_id, device_id)) | |||
self._invalidate_cache_and_stream( | |||
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id) | |||
) | |||
self._invalidate_cache_and_stream_bulk( | |||
txn, self.get_e2e_unused_fallback_key_types, seen_user_device | |||
) | |||
return results | |||
@@ -1376,14 +1374,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) | |||
) | |||
seen_user_device: Set[Tuple[str, str]] = set() | |||
for user_id, device_id, _, _, _ in otk_rows: | |||
if (user_id, device_id) in seen_user_device: | |||
continue | |||
seen_user_device.add((user_id, device_id)) | |||
self._invalidate_cache_and_stream( | |||
txn, self.count_e2e_one_time_keys, (user_id, device_id) | |||
) | |||
seen_user_device = { | |||
(user_id, device_id) for user_id, device_id, _, _, _ in otk_rows | |||
} | |||
self._invalidate_cache_and_stream_bulk( | |||
txn, | |||
self.count_e2e_one_time_keys, | |||
seen_user_device, | |||
) | |||
return otk_rows | |||
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): | |||
next_id = self._load_next_id_txn(txn) | |||
txn.call_after(self._mark_id_as_finished, next_id) | |||
txn.call_on_exception(self._mark_id_as_finished, next_id) | |||
txn.call_after(self._mark_ids_as_finished, [next_id]) | |||
txn.call_on_exception(self._mark_ids_as_finished, [next_id]) | |||
txn.call_after(self._notifier.notify_replication) | |||
# Update the `stream_positions` table with newly updated stream | |||
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): | |||
return self._return_factor * next_id | |||
def _mark_id_as_finished(self, next_id: int) -> None: | |||
"""The ID has finished being processed so we should advance the | |||
def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
    """Allocate `n` new stream IDs as part of the given transaction.

    Usage:

        stream_ids = stream_id_gen.get_next_mult_txn(txn, 5)
        # ... persist events ...

    Args:
        txn: The database transaction in which to allocate the IDs.
        n: The number of stream IDs to allocate.

    Returns:
        The allocated stream IDs (with this stream's direction factor
        applied).

    Raises:
        Exception: if this instance is not permitted to write to the stream.
    """
    # If we have a list of instances that are allowed to write to this
    # stream, make sure we're in it.
    if self._writers and self._instance_name not in self._writers:
        raise Exception("Tried to allocate stream ID on non-writer")

    next_ids = self._load_next_mult_id_txn(txn, n)

    # Mark the IDs finished whether the transaction commits or fails, so the
    # current position can advance either way.
    txn.call_after(self._mark_ids_as_finished, next_ids)
    txn.call_on_exception(self._mark_ids_as_finished, next_ids)
    txn.call_after(self._notifier.notify_replication)

    # Update the `stream_positions` table with newly updated stream
    # ID (unless self._writers is not set in which case we don't
    # bother, as nothing will read it).
    #
    # We only do this on the success path so that the persisted current
    # position points to a persisted row with the correct instance name.
    if self._writers:
        txn.call_after(
            run_as_background_process,
            "MultiWriterIdGenerator._update_table",
            self._db.runInteraction,
            "MultiWriterIdGenerator._update_table",
            self._update_stream_positions_table_txn,
        )

    return [self._return_factor * next_id for next_id in next_ids]
def _mark_ids_as_finished(self, next_ids: List[int]) -> None: | |||
"""These IDs have finished being processed so we should advance the | |||
current position if possible. | |||
""" | |||
with self._lock: | |||
self._unfinished_ids.discard(next_id) | |||
self._finished_ids.add(next_id) | |||
self._unfinished_ids.difference_update(next_ids) | |||
self._finished_ids.update(next_ids) | |||
new_cur: Optional[int] = None | |||
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator): | |||
curr, new_cur, self._max_position_of_local_instance | |||
) | |||
self._add_persisted_position(next_id) | |||
# TODO Can we call this for just the last position or somehow batch | |||
# _add_persisted_position. | |||
for next_id in next_ids: | |||
self._add_persisted_position(next_id) | |||
def get_current_token(self) -> int: | |||
return self.get_persisted_upto_position() | |||
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager: | |||
exc: Optional[BaseException], | |||
tb: Optional[TracebackType], | |||
) -> bool: | |||
for i in self.stream_ids: | |||
self.id_gen._mark_id_as_finished(i) | |||
self.id_gen._mark_ids_as_finished(self.stream_ids) | |||
self.notifier.notify_replication() | |||
@@ -0,0 +1,117 @@ | |||
# Copyright 2023 The Matrix.org Foundation C.I.C. | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from unittest.mock import Mock, call | |||
from synapse.storage.database import LoggingTransaction | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.unittest import HomeserverTestCase | |||
class CacheInvalidationTestCase(HomeserverTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.store = self.hs.get_datastores().main

    def test_bulk_invalidation(self) -> None:
        # Intercept local cache invalidations so we can assert on them below.
        master_invalidate = Mock()
        self.store._get_cached_user_device.invalidate = master_invalidate

        keys_to_invalidate = [("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")]

        def do_invalidation(txn: LoggingTransaction) -> None:
            self.store._invalidate_cache_and_stream_bulk(
                txn,
                # An arbitrarily chosen cached store function; picked because
                # it takes more than one argument. We'll use this later to
                # check that the invalidation was actioned over replication.
                cache_func=self.store._get_cached_user_device,
                key_tuples=keys_to_invalidate,
            )

        self.get_success(
            self.store.db_pool.runInteraction(
                "test_invalidate_cache_and_stream_bulk", do_invalidation
            )
        )

        # Every key-tuple should have been invalidated locally, in any order.
        expected_calls = [call(key_tuple) for key_tuple in keys_to_invalidate]
        master_invalidate.assert_has_calls(expected_calls, any_order=True)
class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
    def setUp(self) -> None:
        super().setUp()
        self.store = self.hs.get_datastores().main

    def test_bulk_invalidation_replicates(self) -> None:
        """Like test_bulk_invalidation, but also checks the invalidations replicate."""
        # Intercept invalidations on both the main process and a worker.
        master_invalidate = Mock()
        worker_invalidate = Mock()

        self.store._get_cached_user_device.invalidate = master_invalidate
        worker = self.make_worker_hs("synapse.app.generic_worker")
        worker_ds = worker.get_datastores().main
        worker_ds._get_cached_user_device.invalidate = worker_invalidate

        keys_to_invalidate = [("a", "b"), ("c", "d"), ("e", "f"), ("g", "h")]

        def do_invalidation(txn: LoggingTransaction) -> None:
            self.store._invalidate_cache_and_stream_bulk(
                txn,
                # An arbitrarily chosen cached store function; picked because
                # it takes more than one argument. We'll use this later to
                # check that the invalidation was actioned over replication.
                cache_func=self.store._get_cached_user_device,
                key_tuples=keys_to_invalidate,
            )

        assert self.store._cache_id_gen is not None
        initial_token = self.store._cache_id_gen.get_current_token()
        self.get_success(
            self.database_pool.runInteraction(
                "test_invalidate_cache_and_stream_bulk", do_invalidation
            )
        )

        # At least one stream ID should have been allocated per key-tuple.
        second_token = self.store._cache_id_gen.get_current_token()
        self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))

        # Wait for the worker to catch up before inspecting its cache.
        self.get_success(
            worker.get_replication_data_handler().wait_for_stream_position(
                "master", "caches", second_token
            )
        )

        # Both processes should have seen every invalidation, in any order.
        expected_calls = [call(key_tuple) for key_tuple in keys_to_invalidate]
        master_invalidate.assert_has_calls(expected_calls, any_order=True)
        worker_invalidate.assert_has_calls(expected_calls, any_order=True)