@@ -0,0 +1 @@ | |||||
Simplify server key storage. |
@@ -23,12 +23,7 @@ from signedjson.key import ( | |||||
get_verify_key, | get_verify_key, | ||||
is_signing_algorithm_supported, | is_signing_algorithm_supported, | ||||
) | ) | ||||
from signedjson.sign import ( | |||||
SignatureVerifyException, | |||||
encode_canonical_json, | |||||
signature_ids, | |||||
verify_signed_json, | |||||
) | |||||
from signedjson.sign import SignatureVerifyException, signature_ids, verify_signed_json | |||||
from signedjson.types import VerifyKey | from signedjson.types import VerifyKey | ||||
from unpaddedbase64 import decode_base64 | from unpaddedbase64 import decode_base64 | ||||
@@ -596,24 +591,12 @@ class BaseV2KeyFetcher(KeyFetcher): | |||||
verify_key=verify_key, valid_until_ts=key_data["expired_ts"] | verify_key=verify_key, valid_until_ts=key_data["expired_ts"] | ||||
) | ) | ||||
key_json_bytes = encode_canonical_json(response_json) | |||||
await make_deferred_yieldable( | |||||
defer.gatherResults( | |||||
[ | |||||
run_in_background( | |||||
self.store.store_server_keys_json, | |||||
server_name=server_name, | |||||
key_id=key_id, | |||||
from_server=from_server, | |||||
ts_now_ms=time_added_ms, | |||||
ts_expires_ms=ts_valid_until_ms, | |||||
key_json_bytes=key_json_bytes, | |||||
) | |||||
for key_id in verify_keys | |||||
], | |||||
consumeErrors=True, | |||||
).addErrback(unwrapFirstError) | |||||
await self.store.store_server_keys_response( | |||||
server_name=server_name, | |||||
from_server=from_server, | |||||
ts_added_ms=time_added_ms, | |||||
verify_keys=verify_keys, | |||||
response_json=response_json, | |||||
) | ) | ||||
return verify_keys | return verify_keys | ||||
@@ -775,10 +758,6 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): | |||||
keys.setdefault(server_name, {}).update(processed_response) | keys.setdefault(server_name, {}).update(processed_response) | ||||
await self.store.store_server_signature_keys( | |||||
perspective_name, time_now_ms, added_keys | |||||
) | |||||
return keys | return keys | ||||
def _validate_perspectives_response( | def _validate_perspectives_response( | ||||
@@ -16,14 +16,17 @@ | |||||
import itertools | import itertools | ||||
import json | import json | ||||
import logging | import logging | ||||
from typing import Dict, Iterable, Mapping, Optional, Tuple | |||||
from typing import Dict, Iterable, Optional, Tuple | |||||
from canonicaljson import encode_canonical_json | |||||
from signedjson.key import decode_verify_key_bytes | from signedjson.key import decode_verify_key_bytes | ||||
from unpaddedbase64 import decode_base64 | from unpaddedbase64 import decode_base64 | ||||
from synapse.storage.database import LoggingTransaction | |||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | ||||
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote | from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote | ||||
from synapse.storage.types import Cursor | from synapse.storage.types import Cursor | ||||
from synapse.types import JsonDict | |||||
from synapse.util.caches.descriptors import cached, cachedList | from synapse.util.caches.descriptors import cached, cachedList | ||||
from synapse.util.iterutils import batch_iter | from synapse.util.iterutils import batch_iter | ||||
@@ -36,162 +39,84 @@ db_binary_type = memoryview | |||||
class KeyStore(CacheInvalidationWorkerStore): | class KeyStore(CacheInvalidationWorkerStore): | ||||
"""Persistence for signature verification keys""" | """Persistence for signature verification keys""" | ||||
@cached() | |||||
def _get_server_signature_key( | |||||
self, server_name_and_key_id: Tuple[str, str] | |||||
) -> FetchKeyResult: | |||||
raise NotImplementedError() | |||||
@cachedList( | |||||
cached_method_name="_get_server_signature_key", | |||||
list_name="server_name_and_key_ids", | |||||
) | |||||
async def get_server_signature_keys( | |||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]] | |||||
) -> Dict[Tuple[str, str], FetchKeyResult]: | |||||
""" | |||||
Args: | |||||
server_name_and_key_ids: | |||||
iterable of (server_name, key-id) tuples to fetch keys for | |||||
Returns: | |||||
A map from (server_name, key_id) -> FetchKeyResult, or None if the | |||||
key is unknown | |||||
""" | |||||
keys = {} | |||||
def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None: | |||||
"""Processes a batch of keys to fetch, and adds the result to `keys`.""" | |||||
# batch_iter always returns tuples so it's safe to do len(batch) | |||||
sql = """ | |||||
SELECT server_name, key_id, verify_key, ts_valid_until_ms | |||||
FROM server_signature_keys WHERE 1=0 | |||||
""" + " OR (server_name=? AND key_id=?)" * len( | |||||
batch | |||||
) | |||||
txn.execute(sql, tuple(itertools.chain.from_iterable(batch))) | |||||
for row in txn: | |||||
server_name, key_id, key_bytes, ts_valid_until_ms = row | |||||
if ts_valid_until_ms is None: | |||||
# Old keys may be stored with a ts_valid_until_ms of null, | |||||
# in which case we treat this as if it was set to `0`, i.e. | |||||
# it won't match key requests that define a minimum | |||||
# `ts_valid_until_ms`. | |||||
ts_valid_until_ms = 0 | |||||
keys[(server_name, key_id)] = FetchKeyResult( | |||||
verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)), | |||||
valid_until_ts=ts_valid_until_ms, | |||||
) | |||||
def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]: | |||||
for batch in batch_iter(server_name_and_key_ids, 50): | |||||
_get_keys(txn, batch) | |||||
return keys | |||||
return await self.db_pool.runInteraction("get_server_signature_keys", _txn) | |||||
async def store_server_signature_keys( | |||||
async def store_server_keys_response( | |||||
self, | self, | ||||
server_name: str, | |||||
from_server: str, | from_server: str, | ||||
ts_added_ms: int, | ts_added_ms: int, | ||||
verify_keys: Mapping[Tuple[str, str], FetchKeyResult], | |||||
verify_keys: Dict[str, FetchKeyResult], | |||||
response_json: JsonDict, | |||||
) -> None: | ) -> None: | ||||
"""Stores NACL verification keys for remote servers. | |||||
"""Stores the keys for the given server that we got from `from_server`. | |||||
Args: | Args: | ||||
from_server: Where the verification keys were looked up | |||||
ts_added_ms: The time to record that the key was added | |||||
verify_keys: | |||||
keys to be stored. Each entry is a triplet of | |||||
(server_name, key_id, key). | |||||
server_name: The owner of the keys | |||||
from_server: Which server we got the keys from | |||||
ts_added_ms: When we're adding the keys | |||||
verify_keys: The decoded keys | |||||
response_json: The full *signed* response JSON that contains the keys. | |||||
""" | """ | ||||
key_values = [] | |||||
value_values = [] | |||||
invalidations = [] | |||||
for (server_name, key_id), fetch_result in verify_keys.items(): | |||||
key_values.append((server_name, key_id)) | |||||
value_values.append( | |||||
( | |||||
from_server, | |||||
ts_added_ms, | |||||
fetch_result.valid_until_ts, | |||||
db_binary_type(fetch_result.verify_key.encode()), | |||||
) | |||||
) | |||||
# invalidate takes a tuple corresponding to the params of | |||||
# _get_server_signature_key. _get_server_signature_key only takes one | |||||
# param, which is itself the 2-tuple (server_name, key_id). | |||||
invalidations.append((server_name, key_id)) | |||||
await self.db_pool.simple_upsert_many( | |||||
table="server_signature_keys", | |||||
key_names=("server_name", "key_id"), | |||||
key_values=key_values, | |||||
value_names=( | |||||
"from_server", | |||||
"ts_added_ms", | |||||
"ts_valid_until_ms", | |||||
"verify_key", | |||||
), | |||||
value_values=value_values, | |||||
desc="store_server_signature_keys", | |||||
) | |||||
key_json_bytes = encode_canonical_json(response_json) | |||||
def store_server_keys_response_txn(txn: LoggingTransaction) -> None: | |||||
self.db_pool.simple_upsert_many_txn( | |||||
txn, | |||||
table="server_signature_keys", | |||||
key_names=("server_name", "key_id"), | |||||
key_values=[(server_name, key_id) for key_id in verify_keys], | |||||
value_names=( | |||||
"from_server", | |||||
"ts_added_ms", | |||||
"ts_valid_until_ms", | |||||
"verify_key", | |||||
), | |||||
value_values=[ | |||||
( | |||||
from_server, | |||||
ts_added_ms, | |||||
fetch_result.valid_until_ts, | |||||
db_binary_type(fetch_result.verify_key.encode()), | |||||
) | |||||
for fetch_result in verify_keys.values() | |||||
], | |||||
) | |||||
invalidate = self._get_server_signature_key.invalidate | |||||
for i in invalidations: | |||||
invalidate((i,)) | |||||
self.db_pool.simple_upsert_many_txn( | |||||
txn, | |||||
table="server_keys_json", | |||||
key_names=("server_name", "key_id", "from_server"), | |||||
key_values=[ | |||||
(server_name, key_id, from_server) for key_id in verify_keys | |||||
], | |||||
value_names=( | |||||
"ts_added_ms", | |||||
"ts_valid_until_ms", | |||||
"key_json", | |||||
), | |||||
value_values=[ | |||||
( | |||||
ts_added_ms, | |||||
fetch_result.valid_until_ts, | |||||
db_binary_type(key_json_bytes), | |||||
) | |||||
for fetch_result in verify_keys.values() | |||||
], | |||||
) | |||||
async def store_server_keys_json( | |||||
self, | |||||
server_name: str, | |||||
key_id: str, | |||||
from_server: str, | |||||
ts_now_ms: int, | |||||
ts_expires_ms: int, | |||||
key_json_bytes: bytes, | |||||
) -> None: | |||||
"""Stores the JSON bytes for a set of keys from a server | |||||
The JSON should be signed by the originating server, the intermediate | |||||
server, and by this server. Updates the value for the | |||||
(server_name, key_id, from_server) triplet if one already existed. | |||||
Args: | |||||
server_name: The name of the server. | |||||
key_id: The identifier of the key this JSON is for. | |||||
from_server: The server this JSON was fetched from. | |||||
ts_now_ms: The time now in milliseconds. | |||||
ts_valid_until_ms: The time when this json stops being valid. | |||||
key_json_bytes: The encoded JSON. | |||||
""" | |||||
await self.db_pool.simple_upsert( | |||||
table="server_keys_json", | |||||
keyvalues={ | |||||
"server_name": server_name, | |||||
"key_id": key_id, | |||||
"from_server": from_server, | |||||
}, | |||||
values={ | |||||
"server_name": server_name, | |||||
"key_id": key_id, | |||||
"from_server": from_server, | |||||
"ts_added_ms": ts_now_ms, | |||||
"ts_valid_until_ms": ts_expires_ms, | |||||
"key_json": db_binary_type(key_json_bytes), | |||||
}, | |||||
desc="store_server_keys_json", | |||||
) | |||||
# invalidate takes a tuple corresponding to the params of | |||||
# _get_server_keys_json. _get_server_keys_json only takes one | |||||
# param, which is itself the 2-tuple (server_name, key_id). | |||||
for key_id in verify_keys: | |||||
self._invalidate_cache_and_stream( | |||||
txn, self._get_server_keys_json, ((server_name, key_id),) | |||||
) | |||||
self._invalidate_cache_and_stream( | |||||
txn, self.get_server_key_json_for_remote, (server_name, key_id) | |||||
) | |||||
# invalidate takes a tuple corresponding to the params of | |||||
# _get_server_keys_json. _get_server_keys_json only takes one | |||||
# param, which is itself the 2-tuple (server_name, key_id). | |||||
await self.invalidate_cache_and_stream( | |||||
"_get_server_keys_json", ((server_name, key_id),) | |||||
) | |||||
await self.invalidate_cache_and_stream( | |||||
"get_server_key_json_for_remote", (server_name, key_id) | |||||
await self.db_pool.runInteraction( | |||||
"store_server_keys_response", store_server_keys_response_txn | |||||
) | ) | ||||
@cached() | @cached() | ||||
@@ -13,7 +13,7 @@ | |||||
# limitations under the License. | # limitations under the License. | ||||
import time | import time | ||||
from typing import Any, Dict, List, Optional, cast | from typing import Any, Dict, List, Optional, cast | ||||
from unittest.mock import AsyncMock, Mock | |||||
from unittest.mock import Mock | |||||
import attr | import attr | ||||
import canonicaljson | import canonicaljson | ||||
@@ -189,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||||
kr = keyring.Keyring(self.hs) | kr = keyring.Keyring(self.hs) | ||||
key1 = signedjson.key.generate_signing_key("1") | key1 = signedjson.key.generate_signing_key("1") | ||||
r = self.hs.get_datastores().main.store_server_keys_json( | |||||
r = self.hs.get_datastores().main.store_server_keys_response( | |||||
"server9", | "server9", | ||||
get_key_id(key1), | |||||
from_server="test", | from_server="test", | ||||
ts_now_ms=int(time.time() * 1000), | |||||
ts_expires_ms=1000, | |||||
ts_added_ms=int(time.time() * 1000), | |||||
verify_keys={ | |||||
get_key_id(key1): FetchKeyResult( | |||||
verify_key=get_verify_key(key1), valid_until_ts=1000 | |||||
) | |||||
}, | |||||
# The entire response gets signed & stored, just include the bits we | # The entire response gets signed & stored, just include the bits we | ||||
# care about. | # care about. | ||||
key_json_bytes=canonicaljson.encode_canonical_json( | |||||
{ | |||||
"verify_keys": { | |||||
get_key_id(key1): { | |||||
"key": encode_verify_key_base64(get_verify_key(key1)) | |||||
} | |||||
response_json={ | |||||
"verify_keys": { | |||||
get_key_id(key1): { | |||||
"key": encode_verify_key_base64(get_verify_key(key1)) | |||||
} | } | ||||
} | } | ||||
), | |||||
}, | |||||
) | ) | ||||
self.get_success(r) | self.get_success(r) | ||||
@@ -285,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): | |||||
d = kr.verify_json_for_server(self.hs.hostname, json1, 0) | d = kr.verify_json_for_server(self.hs.hostname, json1, 0) | ||||
self.get_success(d) | self.get_success(d) | ||||
def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: | |||||
"""Tests that we correctly handle key requests for keys we've stored | |||||
with a null `ts_valid_until_ms` | |||||
""" | |||||
mock_fetcher = Mock() | |||||
mock_fetcher.get_keys = AsyncMock(return_value={}) | |||||
key1 = signedjson.key.generate_signing_key("1") | |||||
r = self.hs.get_datastores().main.store_server_signature_keys( | |||||
"server9", | |||||
int(time.time() * 1000), | |||||
# None is not a valid value in FetchKeyResult, but we're abusing this | |||||
# API to insert null values into the database. The nulls get converted | |||||
# to 0 when fetched in KeyStore.get_server_signature_keys. | |||||
{("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type] | |||||
) | |||||
self.get_success(r) | |||||
json1: JsonDict = {} | |||||
signedjson.sign.sign_json(json1, "server9", key1) | |||||
# should succeed on a signed object with a 0 minimum_valid_until_ms | |||||
d = self.hs.get_datastores().main.get_server_signature_keys( | |||||
[("server9", get_key_id(key1))] | |||||
) | |||||
result = self.get_success(d) | |||||
self.assertEqual(result[("server9", get_key_id(key1))].valid_until_ts, 0) | |||||
def test_verify_json_dedupes_key_requests(self) -> None: | def test_verify_json_dedupes_key_requests(self) -> None: | ||||
"""Two requests for the same key should be deduped.""" | """Two requests for the same key should be deduped.""" | ||||
key1 = signedjson.key.generate_signing_key("1") | key1 = signedjson.key.generate_signing_key("1") | ||||
@@ -1,137 +0,0 @@ | |||||
# Copyright 2017 Vector Creations Ltd | |||||
# | |||||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||||
# you may not use this file except in compliance with the License. | |||||
# You may obtain a copy of the License at | |||||
# | |||||
# http://www.apache.org/licenses/LICENSE-2.0 | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, software | |||||
# distributed under the License is distributed on an "AS IS" BASIS, | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
# See the License for the specific language governing permissions and | |||||
# limitations under the License. | |||||
import signedjson.key | |||||
import signedjson.types | |||||
import unpaddedbase64 | |||||
from synapse.storage.keys import FetchKeyResult | |||||
import tests.unittest | |||||
def decode_verify_key_base64( | |||||
key_id: str, key_base64: str | |||||
) -> signedjson.types.VerifyKey: | |||||
key_bytes = unpaddedbase64.decode_base64(key_base64) | |||||
return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) | |||||
KEY_1 = decode_verify_key_base64( | |||||
"ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" | |||||
) | |||||
KEY_2 = decode_verify_key_base64( | |||||
"ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" | |||||
) | |||||
class KeyStoreTestCase(tests.unittest.HomeserverTestCase): | |||||
def test_get_server_signature_keys(self) -> None: | |||||
store = self.hs.get_datastores().main | |||||
key_id_1 = "ed25519:key1" | |||||
key_id_2 = "ed25519:KEY_ID_2" | |||||
self.get_success( | |||||
store.store_server_signature_keys( | |||||
"from_server", | |||||
10, | |||||
{ | |||||
("server1", key_id_1): FetchKeyResult(KEY_1, 100), | |||||
("server1", key_id_2): FetchKeyResult(KEY_2, 200), | |||||
}, | |||||
) | |||||
) | |||||
res = self.get_success( | |||||
store.get_server_signature_keys( | |||||
[ | |||||
("server1", key_id_1), | |||||
("server1", key_id_2), | |||||
("server1", "ed25519:key3"), | |||||
] | |||||
) | |||||
) | |||||
self.assertEqual(len(res.keys()), 3) | |||||
res1 = res[("server1", key_id_1)] | |||||
self.assertEqual(res1.verify_key, KEY_1) | |||||
self.assertEqual(res1.verify_key.version, "key1") | |||||
self.assertEqual(res1.valid_until_ts, 100) | |||||
res2 = res[("server1", key_id_2)] | |||||
self.assertEqual(res2.verify_key, KEY_2) | |||||
# version comes from the ID it was stored with | |||||
self.assertEqual(res2.verify_key.version, "KEY_ID_2") | |||||
self.assertEqual(res2.valid_until_ts, 200) | |||||
# non-existent result gives None | |||||
self.assertIsNone(res[("server1", "ed25519:key3")]) | |||||
def test_cache(self) -> None: | |||||
"""Check that updates correctly invalidate the cache.""" | |||||
store = self.hs.get_datastores().main | |||||
key_id_1 = "ed25519:key1" | |||||
key_id_2 = "ed25519:key2" | |||||
self.get_success( | |||||
store.store_server_signature_keys( | |||||
"from_server", | |||||
0, | |||||
{ | |||||
("srv1", key_id_1): FetchKeyResult(KEY_1, 100), | |||||
("srv1", key_id_2): FetchKeyResult(KEY_2, 200), | |||||
}, | |||||
) | |||||
) | |||||
res = self.get_success( | |||||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
) | |||||
self.assertEqual(len(res.keys()), 2) | |||||
res1 = res[("srv1", key_id_1)] | |||||
self.assertEqual(res1.verify_key, KEY_1) | |||||
self.assertEqual(res1.valid_until_ts, 100) | |||||
res2 = res[("srv1", key_id_2)] | |||||
self.assertEqual(res2.verify_key, KEY_2) | |||||
self.assertEqual(res2.valid_until_ts, 200) | |||||
# we should be able to look up the same thing again without a db hit | |||||
res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) | |||||
self.assertEqual(len(res.keys()), 1) | |||||
self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) | |||||
new_key_2 = signedjson.key.get_verify_key( | |||||
signedjson.key.generate_signing_key("key2") | |||||
) | |||||
d = store.store_server_signature_keys( | |||||
"from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)} | |||||
) | |||||
self.get_success(d) | |||||
res = self.get_success( | |||||
store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) | |||||
) | |||||
self.assertEqual(len(res.keys()), 2) | |||||
res1 = res[("srv1", key_id_1)] | |||||
self.assertEqual(res1.verify_key, KEY_1) | |||||
self.assertEqual(res1.valid_until_ts, 100) | |||||
res2 = res[("srv1", key_id_2)] | |||||
self.assertEqual(res2.verify_key, new_key_2) | |||||
self.assertEqual(res2.valid_until_ts, 300) |
@@ -70,6 +70,7 @@ from synapse.logging.context import ( | |||||
) | ) | ||||
from synapse.rest import RegisterServletsFunc | from synapse.rest import RegisterServletsFunc | ||||
from synapse.server import HomeServer | from synapse.server import HomeServer | ||||
from synapse.storage.keys import FetchKeyResult | |||||
from synapse.types import JsonDict, Requester, UserID, create_requester | from synapse.types import JsonDict, Requester, UserID, create_requester | ||||
from synapse.util import Clock | from synapse.util import Clock | ||||
from synapse.util.httpresourcetree import create_resource_tree | from synapse.util.httpresourcetree import create_resource_tree | ||||
@@ -858,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase): | |||||
verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) | verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) | ||||
self.get_success( | self.get_success( | ||||
hs.get_datastores().main.store_server_keys_json( | |||||
hs.get_datastores().main.store_server_keys_response( | |||||
self.OTHER_SERVER_NAME, | self.OTHER_SERVER_NAME, | ||||
verify_key_id, | |||||
from_server=self.OTHER_SERVER_NAME, | from_server=self.OTHER_SERVER_NAME, | ||||
ts_now_ms=clock.time_msec(), | |||||
ts_expires_ms=clock.time_msec() + 10000, | |||||
key_json_bytes=canonicaljson.encode_canonical_json( | |||||
{ | |||||
"verify_keys": { | |||||
verify_key_id: { | |||||
"key": signedjson.key.encode_verify_key_base64( | |||||
verify_key | |||||
) | |||||
} | |||||
ts_added_ms=clock.time_msec(), | |||||
verify_keys={ | |||||
verify_key_id: FetchKeyResult( | |||||
verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000 | |||||
), | |||||
}, | |||||
response_json={ | |||||
"verify_keys": { | |||||
verify_key_id: { | |||||
"key": signedjson.key.encode_verify_key_base64(verify_key) | |||||
} | } | ||||
} | } | ||||
), | |||||
}, | |||||
) | ) | ||||
) | ) | ||||