diff --git a/changelog.d/16261.misc b/changelog.d/16261.misc new file mode 100644 index 0000000000..d3ad59ca4a --- /dev/null +++ b/changelog.d/16261.misc @@ -0,0 +1 @@ +Simplify server key storage. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 260aab3241..fe86f54d80 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -23,12 +23,7 @@ from signedjson.key import ( get_verify_key, is_signing_algorithm_supported, ) -from signedjson.sign import ( - SignatureVerifyException, - encode_canonical_json, - signature_ids, - verify_signed_json, -) +from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json from signedjson.types import VerifyKey from unpaddedbase64 import decode_base64 @@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher): verify_key=verify_key, valid_until_ts=key_data["expired_ts"] ) - key_json_bytes = encode_canonical_json(response_json) - - await make_deferred_yieldable( - defer.gatherResults( - [ - run_in_background( - self.store.store_server_keys_json, - server_name=server_name, - key_id=key_id, - from_server=from_server, - ts_now_ms=time_added_ms, - ts_expires_ms=ts_valid_until_ms, - key_json_bytes=key_json_bytes, - ) - for key_id in verify_keys - ], - consumeErrors=True, - ).addErrback(unwrapFirstError) + await self.store.store_server_keys_response( + server_name=server_name, + from_server=from_server, + ts_added_ms=time_added_ms, + verify_keys=verify_keys, + response_json=response_json, ) return verify_keys @@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): keys.setdefault(server_name, {}).update(processed_response) - await self.store.store_server_signature_keys( - perspective_name, time_now_ms, added_keys - ) - return keys def _validate_perspectives_response( diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py index 57aa4921e1..41563371dc 100644 --- a/synapse/storage/databases/main/keys.py +++ b/synapse/storage/databases/main/keys.py @@ -16,14 +16,17 @@ import itertools import json import logging -from typing import Dict, Iterable, Mapping, Optional, Tuple +from typing import Dict, Iterable, Optional, Tuple +from canonicaljson import encode_canonical_json from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 +from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote from synapse.storage.types import Cursor +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached, cachedList from synapse.util.iterutils import batch_iter @@ -36,162 +39,84 @@ db_binary_type = memoryview class KeyStore(CacheInvalidationWorkerStore): """Persistence for signature verification keys""" - @cached() - def _get_server_signature_key( - self, server_name_and_key_id: Tuple[str, str] - ) -> FetchKeyResult: - raise NotImplementedError() - - @cachedList( - cached_method_name="_get_server_signature_key", - list_name="server_name_and_key_ids", - ) - async def get_server_signature_keys( - self, server_name_and_key_ids: Iterable[Tuple[str, str]] - ) -> Dict[Tuple[str, str], FetchKeyResult]: - """ - Args: - server_name_and_key_ids: - iterable of (server_name, key-id) tuples to fetch keys for - - Returns: - A map from (server_name, key_id) -> FetchKeyResult, or None if the - key is unknown - """ - keys = {} - - def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: - """Processes a batch of keys to fetch, and adds the result to `keys`.""" - - # batch_iter always returns tuples so it's safe to do len(batch) - sql = """ - SELECT server_name, key_id, verify_key, ts_valid_until_ms - FROM server_signature_keys WHERE 1=0 - """ + " OR (server_name=? AND key_id=?)" * len( - batch - ) - - txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) - - for row in txn: - server_name, key_id, key_bytes, ts_valid_until_ms = row - - if ts_valid_until_ms is None: - # Old keys may be stored with a ts_valid_until_ms of null, - # in which case we treat this as if it was set to `0`, i.e. - # it won't match key requests that define a minimum - # `ts_valid_until_ms`. - ts_valid_until_ms = 0 - - keys[(server_name, key_id)] = FetchKeyResult( - verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), - valid_until_ts=ts_valid_until_ms, - ) - - def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: - for batch in batch_iter(server_name_and_key_ids, 50): - _get_keys(txn, batch) - return keys - - return await self.db_pool.runInteraction("get_server_signature_keys", _txn) - - async def store_server_signature_keys( + async def store_server_keys_response( self, + server_name: str, from_server: str, ts_added_ms: int, - verify_keys: Mapping[Tuple[str, str], FetchKeyResult], + verify_keys: Dict[str, FetchKeyResult], + response_json: JsonDict, ) -> None: - """Stores NACL verification keys for remote servers. + """Stores the keys for the given server that we got from `from_server`. + Args: - from_server: Where the verification keys were looked up - ts_added_ms: The time to record that the key was added - verify_keys: - keys to be stored. Each entry is a triplet of - (server_name, key_id, key). + server_name: The owner of the keys + from_server: Which server we got the keys from + ts_added_ms: When we're adding the keys + verify_keys: The decoded keys + response_json: The full *signed* response JSON that contains the keys. """ - key_values = [] - value_values = [] - invalidations = [] - for (server_name, key_id), fetch_result in verify_keys.items(): - key_values.append((server_name, key_id)) - value_values.append( - ( - from_server, - ts_added_ms, - fetch_result.valid_until_ts, - db_binary_type(fetch_result.verify_key.encode()), - ) - ) - # invalidate takes a tuple corresponding to the params of - # _get_server_signature_key. _get_server_signature_key only takes one - # param, which is itself the 2-tuple (server_name, key_id). - invalidations.append((server_name, key_id)) - await self.db_pool.simple_upsert_many( - table="server_signature_keys", - key_names=("server_name", "key_id"), - key_values=key_values, - value_names=( - "from_server", - "ts_added_ms", - "ts_valid_until_ms", - "verify_key", - ), - value_values=value_values, - desc="store_server_signature_keys", - ) + key_json_bytes = encode_canonical_json(response_json) + + def store_server_keys_response_txn(txn: LoggingTransaction) -> None: + self.db_pool.simple_upsert_many_txn( + txn, + table="server_signature_keys", + key_names=("server_name", "key_id"), + key_values=[(server_name, key_id) for key_id in verify_keys], + value_names=( + "from_server", + "ts_added_ms", + "ts_valid_until_ms", + "verify_key", + ), + value_values=[ + ( + from_server, + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(fetch_result.verify_key.encode()), + ) + for fetch_result in verify_keys.values() + ], + ) - invalidate = self._get_server_signature_key.invalidate - for i in invalidations: - invalidate((i,)) + self.db_pool.simple_upsert_many_txn( + txn, + table="server_keys_json", + key_names=("server_name", "key_id", "from_server"), + key_values=[ + (server_name, key_id, from_server) for key_id in verify_keys + ], + value_names=( + "ts_added_ms", + "ts_valid_until_ms", + "key_json", + ), + value_values=[ + ( + ts_added_ms, + fetch_result.valid_until_ts, + db_binary_type(key_json_bytes), + ) + for fetch_result in verify_keys.values() + ], + ) - async def store_server_keys_json( - self, - server_name: str, - key_id: str, - from_server: str, - ts_now_ms: int, - ts_expires_ms: int, - key_json_bytes: bytes, - ) -> None: - """Stores the JSON bytes for a set of keys from a server - The JSON should be signed by the originating server, the intermediate - server, and by this server. Updates the value for the - (server_name, key_id, from_server) triplet if one already existed. - Args: - server_name: The name of the server. - key_id: The identifier of the key this JSON is for. - from_server: The server this JSON was fetched from. - ts_now_ms: The time now in milliseconds. - ts_valid_until_ms: The time when this json stops being valid. - key_json_bytes: The encoded JSON. - """ - await self.db_pool.simple_upsert( - table="server_keys_json", - keyvalues={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - }, - values={ - "server_name": server_name, - "key_id": key_id, - "from_server": from_server, - "ts_added_ms": ts_now_ms, - "ts_valid_until_ms": ts_expires_ms, - "key_json": db_binary_type(key_json_bytes), - }, - desc="store_server_keys_json", - ) + # invalidate takes a tuple corresponding to the params of + # _get_server_keys_json. _get_server_keys_json only takes one + # param, which is itself the 2-tuple (server_name, key_id). + for key_id in verify_keys: + self._invalidate_cache_and_stream( + txn, self._get_server_keys_json, ((server_name, key_id),) + ) + self._invalidate_cache_and_stream( + txn, self.get_server_key_json_for_remote, (server_name, key_id) + ) - # invalidate takes a tuple corresponding to the params of - # _get_server_keys_json. _get_server_keys_json only takes one - # param, which is itself the 2-tuple (server_name, key_id). - await self.invalidate_cache_and_stream( - "_get_server_keys_json", ((server_name, key_id),) - ) - await self.invalidate_cache_and_stream( - "get_server_key_json_for_remote", (server_name, key_id) + await self.db_pool.runInteraction( + "store_server_keys_response", store_server_keys_response_txn ) @cached() diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f93ba5d4cf..c5700771b0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -13,7 +13,7 @@ # limitations under the License. import time from typing import Any, Dict, List, Optional, cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import attr import canonicaljson @@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_keys_json( + r = self.hs.get_datastores().main.store_server_keys_response( "server9", - get_key_id(key1), from_server="test", - ts_now_ms=int(time.time() * 1000), - ts_expires_ms=1000, + ts_added_ms=int(time.time() * 1000), + verify_keys={ + get_key_id(key1): FetchKeyResult( + verify_key=get_verify_key(key1), valid_until_ts=1000 + ) + }, # The entire response gets signed & stored, just include the bits we # care about. - key_json_bytes=canonicaljson.encode_canonical_json( - { - "verify_keys": { - get_key_id(key1): { - "key": encode_verify_key_base64(get_verify_key(key1)) - } + response_json={ + "verify_keys": { + get_key_id(key1): { + "key": encode_verify_key_base64(get_verify_key(key1)) } } - ), + }, ) self.get_success(r) @@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = kr.verify_json_for_server(self.hs.hostname, json1, 0) self.get_success(d) - def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: - """Tests that we correctly handle key requests for keys we've stored - with a null `ts_valid_until_ms` - """ - mock_fetcher = Mock() - mock_fetcher.get_keys = AsyncMock(return_value={}) - - key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_signature_keys( - "server9", - int(time.time() * 1000), - # None is not a valid value in FetchKeyResult, but we're abusing this - # API to insert null values into the database. The nulls get converted - # to 0 when fetched in KeyStore.get_server_signature_keys. - {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type] - ) - self.get_success(r) - - json1: JsonDict = {} - signedjson.sign.sign_json(json1, "server9", key1) - - # should succeed on a signed object with a 0 minimum_valid_until_ms - d = self.hs.get_datastores().main.get_server_signature_keys( - [("server9", get_key_id(key1))] - ) - result = self.get_success(d) - self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0) - def test_verify_json_dedupes_key_requests(self) -> None: """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key("1") diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py deleted file mode 100644 index 5d7c13e6d0..0000000000 --- a/tests/storage/test_keys.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import signedjson.key -import signedjson.types -import unpaddedbase64 - -from synapse.storage.keys import FetchKeyResult - -import tests.unittest - - -def decode_verify_key_base64( - key_id: str, key_base64: str -) -> signedjson.types.VerifyKey: - key_bytes = unpaddedbase64.decode_base64(key_base64) - return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) - - -KEY_1 = decode_verify_key_base64( - "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" -) -KEY_2 = decode_verify_key_base64( - "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" -) - - -class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_signature_keys(self) -> None: - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:KEY_ID_2" - self.get_success( - store.store_server_signature_keys( - "from_server", - 10, - { - ("server1", key_id_1): FetchKeyResult(KEY_1, 100), - ("server1", key_id_2): FetchKeyResult(KEY_2, 200), - }, - ) - ) - - res = self.get_success( - store.get_server_signature_keys( - [ - ("server1", key_id_1), - ("server1", key_id_2), - ("server1", "ed25519:key3"), - ] - ) - ) - - self.assertEqual(len(res.keys()), 3) - res1 = res[("server1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.verify_key.version, "key1") - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("server1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - # version comes from the ID it was stored with - self.assertEqual(res2.verify_key.version, "KEY_ID_2") - self.assertEqual(res2.valid_until_ts, 200) - - # non-existent result gives None - self.assertIsNone(res[("server1", "ed25519:key3")]) - - def test_cache(self) -> None: - """Check that updates correctly invalidate the cache.""" - - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:key2" - - self.get_success( - store.store_server_signature_keys( - "from_server", - 0, - { - ("srv1", key_id_1): FetchKeyResult(KEY_1, 100), - ("srv1", key_id_2): FetchKeyResult(KEY_2, 200), - }, - ) - ) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - self.assertEqual(res2.valid_until_ts, 200) - - # we should be able to look up the same thing again without a db hit - res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) - self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) - - new_key_2 = signedjson.key.get_verify_key( - signedjson.key.generate_signing_key("key2") - ) - d = store.store_server_signature_keys( - "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)} - ) - self.get_success(d) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, new_key_2) - self.assertEqual(res2.valid_until_ts, 300) diff --git a/tests/unittest.py b/tests/unittest.py index 5d3640d8ac..dbaff361b4 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -70,6 +70,7 @@ from synapse.logging.context import ( ) from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree @@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastores().main.store_server_keys_json( + hs.get_datastores().main.store_server_keys_response( self.OTHER_SERVER_NAME, - verify_key_id, from_server=self.OTHER_SERVER_NAME, - ts_now_ms=clock.time_msec(), - ts_expires_ms=clock.time_msec() + 10000, - key_json_bytes=canonicaljson.encode_canonical_json( - { - "verify_keys": { - verify_key_id: { - "key": signedjson.key.encode_verify_key_base64( - verify_key - ) - } + ts_added_ms=clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000 + ), + }, + response_json={ + "verify_keys": { + verify_key_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) } } - ), + }, ) )