@@ -0,0 +1 @@ | |||
Simplify server key storage. |
@@ -23,12 +23,7 @@ from signedjson.key import ( | |||
get_verify_key, | |||
is_signing_algorithm_supported, | |||
) | |||
from signedjson.sign import ( | |||
SignatureVerifyException, | |||
encode_canonical_json, | |||
signature_ids, | |||
verify_signed_json, | |||
) | |||
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json | |||
from signedjson.types import VerifyKey | |||
from unpaddedbase64 import decode_base64 | |||
@@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher): | |||
verify_key=verify_key, valid_until_ts=key_data["expired_ts"] | |||
) | |||
key_json_bytes = encode_canonical_json(response_json) | |||
await make_deferred_yieldable( | |||
defer.gatherResults( | |||
[ | |||
run_in_background( | |||
self.store.store_server_keys_json, | |||
server_name=server_name, | |||
key_id=key_id, | |||
from_server=from_server, | |||
ts_now_ms=time_added_ms, | |||
ts_expires_ms=ts_valid_until_ms, | |||
key_json_bytes=key_json_bytes, | |||
) | |||
for key_id in verify_keys | |||
], | |||
consumeErrors=True, | |||
).addErrback(unwrapFirstError) | |||
await self.store.store_server_keys_response( | |||
server_name=server_name, | |||
from_server=from_server, | |||
ts_added_ms=time_added_ms, | |||
verify_keys=verify_keys, | |||
response_json=response_json, | |||
) | |||
return verify_keys | |||
@@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): | |||
keys.setdefault(server_name, {}).update(processed_response) | |||
await self.store.store_server_signature_keys( | |||
perspective_name, time_now_ms, added_keys | |||
) | |||
return keys | |||
def _validate_perspectives_response( | |||
@@ -16,14 +16,17 @@ | |||
import itertools | |||
import json | |||
import logging | |||
from typing import Dict, Iterable, Mapping, Optional, Tuple | |||
from typing import Dict, Iterable, Optional, Tuple | |||
from canonicaljson import encode_canonical_json | |||
from signedjson.key import decode_verify_key_bytes | |||
from unpaddedbase64 import decode_base64 | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote | |||
from synapse.storage.types import Cursor | |||
from synapse.types import JsonDict | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
from synapse.util.iterutils import batch_iter | |||
@@ -36,162 +39,84 @@ db_binary_type = memoryview | |||
class KeyStore(CacheInvalidationWorkerStore): | |||
"""Persistence for signature verification keys""" | |||
@cached() | |||
def _get_server_signature_key( | |||
self, server_name_and_key_id: Tuple[str, str] | |||
) -> FetchKeyResult: | |||
raise NotImplementedError() | |||
@cachedList( | |||
cached_method_name="_get_server_signature_key", | |||
list_name="server_name_and_key_ids", | |||
) | |||
async def get_server_signature_keys( | |||
self, server_name_and_key_ids: Iterable[Tuple[str, str]] | |||
) -> Dict[Tuple[str, str], FetchKeyResult]: | |||
""" | |||
Args: | |||
server_name_and_key_ids: | |||
iterable of (server_name, key-id) tuples to fetch keys for | |||
Returns: | |||
A map from (server_name, key_id) -> FetchKeyResult, or None if the | |||
key is unknown | |||
""" | |||
keys = {} | |||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: | |||
"""Processes a batch of keys to fetch, and adds the result to `keys`.""" | |||
# batch_iter always returns tuples so it's safe to do len(batch) | |||
sql = """ | |||
SELECT server_name, key_id, verify_key, ts_valid_until_ms | |||
FROM server_signature_keys WHERE 1=0 | |||
""" + " OR (server_name=? AND key_id=?)" * len( | |||
batch | |||
) | |||
txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) | |||
for row in txn: | |||
server_name, key_id, key_bytes, ts_valid_until_ms = row | |||
if ts_valid_until_ms is None: | |||
# Old keys may be stored with a ts_valid_until_ms of null, | |||
# in which case we treat this as if it was set to `0`, i.e. | |||
# it won't match key requests that define a minimum | |||
# `ts_valid_until_ms`. | |||
ts_valid_until_ms = 0 | |||
keys[(server_name, key_id)] = FetchKeyResult( | |||
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), | |||
valid_until_ts=ts_valid_until_ms, | |||
) | |||
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: | |||
for batch in batch_iter(server_name_and_key_ids, 50): | |||
_get_keys(txn, batch) | |||
return keys | |||
return await self.db_pool.runInteraction("get_server_signature_keys", _txn) | |||
async def store_server_signature_keys( | |||
async def store_server_keys_response( | |||
self, | |||
server_name: str, | |||
from_server: str, | |||
ts_added_ms: int, | |||
verify_keys: Mapping[Tuple[str, str], FetchKeyResult], | |||
verify_keys: Dict[str, FetchKeyResult], | |||
response_json: JsonDict, | |||
) -> None: | |||
"""Stores NACL verification keys for remote servers. | |||
"""Stores the keys for the given server that we got from `from_server`. | |||
Args: | |||
from_server: Where the verification keys were looked up | |||
ts_added_ms: The time to record that the key was added | |||
verify_keys: | |||
keys to be stored. Each entry is a triplet of | |||
(server_name, key_id, key). | |||
server_name: The owner of the keys | |||
from_server: Which server we got the keys from | |||
ts_added_ms: When we're adding the keys | |||
verify_keys: The decoded keys | |||
response_json: The full *signed* response JSON that contains the keys. | |||
""" | |||
key_values = [] | |||
value_values = [] | |||
invalidations = [] | |||
for (server_name, key_id), fetch_result in verify_keys.items(): | |||
key_values.append((server_name, key_id)) | |||
value_values.append( | |||
( | |||
from_server, | |||
ts_added_ms, | |||
fetch_result.valid_until_ts, | |||
db_binary_type(fetch_result.verify_key.encode()), | |||
) | |||
) | |||
# invalidate takes a tuple corresponding to the params of | |||
# _get_server_signature_key. _get_server_signature_key only takes one | |||
# param, which is itself the 2-tuple (server_name, key_id). | |||
invalidations.append((server_name, key_id)) | |||
await self.db_pool.simple_upsert_many( | |||
table="server_signature_keys", | |||
key_names=("server_name", "key_id"), | |||
key_values=key_values, | |||
value_names=( | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"verify_key", | |||
), | |||
value_values=value_values, | |||
desc="store_server_signature_keys", | |||
) | |||
key_json_bytes = encode_canonical_json(response_json) | |||
def store_server_keys_response_txn(txn: LoggingTransaction) -> None: | |||
self.db_pool.simple_upsert_many_txn( | |||
txn, | |||
table="server_signature_keys", | |||
key_names=("server_name", "key_id"), | |||
key_values=[(server_name, key_id) for key_id in verify_keys], | |||
value_names=( | |||
"from_server", | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"verify_key", | |||
), | |||
value_values=[ | |||
( | |||
from_server, | |||
ts_added_ms, | |||
fetch_result.valid_until_ts, | |||
db_binary_type(fetch_result.verify_key.encode()), | |||
) | |||
for fetch_result in verify_keys.values() | |||
], | |||
) | |||
invalidate = self._get_server_signature_key.invalidate | |||
for i in invalidations: | |||
invalidate((i,)) | |||
self.db_pool.simple_upsert_many_txn( | |||
txn, | |||
table="server_keys_json", | |||
key_names=("server_name", "key_id", "from_server"), | |||
key_values=[ | |||
(server_name, key_id, from_server) for key_id in verify_keys | |||
], | |||
value_names=( | |||
"ts_added_ms", | |||
"ts_valid_until_ms", | |||
"key_json", | |||
), | |||
value_values=[ | |||
( | |||
ts_added_ms, | |||
fetch_result.valid_until_ts, | |||
db_binary_type(key_json_bytes), | |||
) | |||
for fetch_result in verify_keys.values() | |||
], | |||
) | |||
async def store_server_keys_json( | |||
self, | |||
server_name: str, | |||
key_id: str, | |||
from_server: str, | |||
ts_now_ms: int, | |||
ts_expires_ms: int, | |||
key_json_bytes: bytes, | |||
) -> None: | |||
"""Stores the JSON bytes for a set of keys from a server | |||
The JSON should be signed by the originating server, the intermediate | |||
server, and by this server. Updates the value for the | |||
(server_name, key_id, from_server) triplet if one already existed. | |||
Args: | |||
server_name: The name of the server. | |||
key_id: The identifier of the key this JSON is for. | |||
from_server: The server this JSON was fetched from. | |||
ts_now_ms: The time now in milliseconds. | |||
ts_valid_until_ms: The time when this json stops being valid. | |||
key_json_bytes: The encoded JSON. | |||
""" | |||
await self.db_pool.simple_upsert( | |||
table="server_keys_json", | |||
keyvalues={ | |||
"server_name": server_name, | |||
"key_id": key_id, | |||
"from_server": from_server, | |||
}, | |||
values={ | |||
"server_name": server_name, | |||
"key_id": key_id, | |||
"from_server": from_server, | |||
"ts_added_ms": ts_now_ms, | |||
"ts_valid_until_ms": ts_expires_ms, | |||
"key_json": db_binary_type(key_json_bytes), | |||
}, | |||
desc="store_server_keys_json", | |||
) | |||
# invalidate takes a tuple corresponding to the params of | |||
# _get_server_keys_json. _get_server_keys_json only takes one | |||
# param, which is itself the 2-tuple (server_name, key_id). | |||
for key_id in verify_keys: | |||
self._invalidate_cache_and_stream( | |||
txn, self._get_server_keys_json, ((server_name, key_id),) | |||
) | |||
self._invalidate_cache_and_stream( | |||
txn, self.get_server_key_json_for_remote, (server_name, key_id) | |||
) | |||
# invalidate takes a tuple corresponding to the params of | |||
# _get_server_keys_json. _get_server_keys_json only takes one | |||
# param, which is itself the 2-tuple (server_name, key_id). | |||
await self.invalidate_cache_and_stream( | |||
"_get_server_keys_json", ((server_name, key_id),) | |||
) | |||
await self.invalidate_cache_and_stream( | |||
"get_server_key_json_for_remote", (server_name, key_id) | |||
await self.db_pool.runInteraction( | |||
"store_server_keys_response", store_server_keys_response_txn | |||
) | |||
@cached() | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import time | |||
from typing import Any, Dict, List, Optional, cast | |||
from unittest.mock import AsyncMock, Mock | |||
from unittest.mock import Mock | |||
import attr | |||
import canonicaljson | |||
@@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
kr = keyring.Keyring(self.hs) | |||
key1 = signedjson.key.generate_signing_key("1") | |||
r = self.hs.get_datastores().main.store_server_keys_json( | |||
r = self.hs.get_datastores().main.store_server_keys_response( | |||
"server9", | |||
get_key_id(key1), | |||
from_server="test", | |||
ts_now_ms=int(time.time() * 1000), | |||
ts_expires_ms=1000, | |||
ts_added_ms=int(time.time() * 1000), | |||
verify_keys={ | |||
get_key_id(key1): FetchKeyResult( | |||
verify_key=get_verify_key(key1), valid_until_ts=1000 | |||
) | |||
}, | |||
# The entire response gets signed & stored, just include the bits we | |||
# care about. | |||
key_json_bytes=canonicaljson.encode_canonical_json( | |||
{ | |||
"verify_keys": { | |||
get_key_id(key1): { | |||
"key": encode_verify_key_base64(get_verify_key(key1)) | |||
} | |||
response_json={ | |||
"verify_keys": { | |||
get_key_id(key1): { | |||
"key": encode_verify_key_base64(get_verify_key(key1)) | |||
} | |||
} | |||
), | |||
}, | |||
) | |||
self.get_success(r) | |||
@@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||
d = kr.verify_json_for_server(self.hs.hostname, json1, 0) | |||
self.get_success(d) | |||
def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: | |||
"""Tests that we correctly handle key requests for keys we've stored | |||
with a null `ts_valid_until_ms` | |||
""" | |||
mock_fetcher = Mock() | |||
mock_fetcher.get_keys = AsyncMock(return_value={}) | |||
key1 = signedjson.key.generate_signing_key("1") | |||
r = self.hs.get_datastores().main.store_server_signature_keys( | |||
"server9", | |||
int(time.time() * 1000), | |||
# None is not a valid value in FetchKeyResult, but we're abusing this | |||
# API to insert null values into the database. The nulls get converted | |||
# to 0 when fetched in KeyStore.get_server_signature_keys. | |||
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type] | |||
) | |||
self.get_success(r) | |||
json1: JsonDict = {} | |||
signedjson.sign.sign_json(json1, "server9", key1) | |||
# should succeed on a signed object with a 0 minimum_valid_until_ms | |||
d = self.hs.get_datastores().main.get_server_signature_keys( | |||
[("server9", get_key_id(key1))] | |||
) | |||
result = self.get_success(d) | |||
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0) | |||
def test_verify_json_dedupes_key_requests(self) -> None: | |||
"""Two requests for the same key should be deduped.""" | |||
key1 = signedjson.key.generate_signing_key("1") | |||
@@ -1,137 +0,0 @@ | |||
# Copyright 2017 Vector Creations Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import signedjson.key | |||
import signedjson.types | |||
import unpaddedbase64 | |||
from synapse.storage.keys import FetchKeyResult | |||
import tests.unittest | |||
def decode_verify_key_base64( | |||
key_id: str, key_base64: str | |||
) -> signedjson.types.VerifyKey: | |||
key_bytes = unpaddedbase64.decode_base64(key_base64) | |||
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) | |||
KEY_1 = decode_verify_key_base64( | |||
"ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" | |||
) | |||
KEY_2 = decode_verify_key_base64( | |||
"ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | |||
) | |||
class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||
def test_get_server_signature_keys(self) -> None: | |||
store = self.hs.get_datastores().main | |||
key_id_1 = "ed25519:key1" | |||
key_id_2 = "ed25519:KEY_ID_2" | |||
self.get_success( | |||
store.store_server_signature_keys( | |||
"from_server", | |||
10, | |||
{ | |||
("server1", key_id_1): FetchKeyResult(KEY_1, 100), | |||
("server1", key_id_2): FetchKeyResult(KEY_2, 200), | |||
}, | |||
) | |||
) | |||
res = self.get_success( | |||
store.get_server_signature_keys( | |||
[ | |||
("server1", key_id_1), | |||
("server1", key_id_2), | |||
("server1", "ed25519:key3"), | |||
] | |||
) | |||
) | |||
self.assertEqual(len(res.keys()), 3) | |||
res1 = res[("server1", key_id_1)] | |||
self.assertEqual(res1.verify_key, KEY_1) | |||
self.assertEqual(res1.verify_key.version, "key1") | |||
self.assertEqual(res1.valid_until_ts, 100) | |||
res2 = res[("server1", key_id_2)] | |||
self.assertEqual(res2.verify_key, KEY_2) | |||
# version comes from the ID it was stored with | |||
self.assertEqual(res2.verify_key.version, "KEY_ID_2") | |||
self.assertEqual(res2.valid_until_ts, 200) | |||
# non-existent result gives None | |||
self.assertIsNone(res[("server1", "ed25519:key3")]) | |||
def test_cache(self) -> None: | |||
"""Check that updates correctly invalidate the cache.""" | |||
store = self.hs.get_datastores().main | |||
key_id_1 = "ed25519:key1" | |||
key_id_2 = "ed25519:key2" | |||
self.get_success( | |||
store.store_server_signature_keys( | |||
"from_server", | |||
0, | |||
{ | |||
("srv1", key_id_1): FetchKeyResult(KEY_1, 100), | |||
("srv1", key_id_2): FetchKeyResult(KEY_2, 200), | |||
}, | |||
) | |||
) | |||
res = self.get_success( | |||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||
) | |||
self.assertEqual(len(res.keys()), 2) | |||
res1 = res[("srv1", key_id_1)] | |||
self.assertEqual(res1.verify_key, KEY_1) | |||
self.assertEqual(res1.valid_until_ts, 100) | |||
res2 = res[("srv1", key_id_2)] | |||
self.assertEqual(res2.verify_key, KEY_2) | |||
self.assertEqual(res2.valid_until_ts, 200) | |||
# we should be able to look up the same thing again without a db hit | |||
res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) | |||
self.assertEqual(len(res.keys()), 1) | |||
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) | |||
new_key_2 = signedjson.key.get_verify_key( | |||
signedjson.key.generate_signing_key("key2") | |||
) | |||
d = store.store_server_signature_keys( | |||
"from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)} | |||
) | |||
self.get_success(d) | |||
res = self.get_success( | |||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||
) | |||
self.assertEqual(len(res.keys()), 2) | |||
res1 = res[("srv1", key_id_1)] | |||
self.assertEqual(res1.verify_key, KEY_1) | |||
self.assertEqual(res1.valid_until_ts, 100) | |||
res2 = res[("srv1", key_id_2)] | |||
self.assertEqual(res2.verify_key, new_key_2) | |||
self.assertEqual(res2.valid_until_ts, 300) |
@@ -70,6 +70,7 @@ from synapse.logging.context import ( | |||
) | |||
from synapse.rest import RegisterServletsFunc | |||
from synapse.server import HomeServer | |||
from synapse.storage.keys import FetchKeyResult | |||
from synapse.types import JsonDict, Requester, UserID, create_requester | |||
from synapse.util import Clock | |||
from synapse.util.httpresourcetree import create_resource_tree | |||
@@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase): | |||
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) | |||
self.get_success( | |||
hs.get_datastores().main.store_server_keys_json( | |||
hs.get_datastores().main.store_server_keys_response( | |||
self.OTHER_SERVER_NAME, | |||
verify_key_id, | |||
from_server=self.OTHER_SERVER_NAME, | |||
ts_now_ms=clock.time_msec(), | |||
ts_expires_ms=clock.time_msec() + 10000, | |||
key_json_bytes=canonicaljson.encode_canonical_json( | |||
{ | |||
"verify_keys": { | |||
verify_key_id: { | |||
"key": signedjson.key.encode_verify_key_base64( | |||
verify_key | |||
) | |||
} | |||
ts_added_ms=clock.time_msec(), | |||
verify_keys={ | |||
verify_key_id: FetchKeyResult( | |||
verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000 | |||
), | |||
}, | |||
response_json={ | |||
"verify_keys": { | |||
verify_key_id: { | |||
"key": signedjson.key.encode_verify_key_base64(verify_key) | |||
} | |||
} | |||
), | |||
}, | |||
) | |||
) | |||