@@ -0,0 +1 @@
+Add cache to `get_server_keys_json_for_remote`.
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
 from signedjson.sign import sign_json
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
     parse_integer,
     parse_json_object_from_request,
 )
+from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
-        store_queries = []
+        server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
         for server_name, key_ids in query.items():
-            if not key_ids:
-                key_ids = (None,)
-            for key_id in key_ids:
-                store_queries.append((server_name, key_id, None))
+            if key_ids:
+                results: Mapping[
+                    str, Optional[FetchKeyResultForRemote]
+                ] = await self.store.get_server_keys_json_for_remote(
+                    server_name, key_ids
+                )
+            else:
+                results = await self.store.get_all_server_keys_json_for_remote(
+                    server_name
+                )
 
-        cached = await self.store.get_server_keys_json_for_remote(store_queries)
+            server_keys.update(
+                ((server_name, key_id), res) for key_id, res in results.items()
+            )
 
         json_results: Set[bytes] = set()
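Note on the control flow above: a non-empty `key_ids` dict routes through the new batched, cached lookup, while an empty one means "return every key we have for this server" and deliberately bypasses the cache. A minimal sketch of that calling convention, with a hypothetical `DummyStore` standing in for the real `KeyStore` (the sample server names and key IDs are invented):

```python
import asyncio
from typing import Dict, Iterable, Optional


class DummyStore:
    async def get_server_keys_json_for_remote(
        self, server_name: str, key_ids: Iterable[str]
    ) -> Dict[str, Optional[bytes]]:
        # Cached, batched path: one entry per requested key ID (None on a miss).
        return {key_id: None for key_id in key_ids}

    async def get_all_server_keys_json_for_remote(
        self, server_name: str
    ) -> Dict[str, bytes]:
        # Uncached path: everything stored for the server.
        return {}


async def main() -> None:
    store = DummyStore()
    # Per the handler above: a non-empty inner dict lists key IDs, an empty
    # one means "all keys for this server".
    query = {"remote.example.com": {"ed25519:ver1": {}}, "other.example.com": {}}
    for server_name, key_ids in query.items():
        if key_ids:
            results = await store.get_server_keys_json_for_remote(server_name, key_ids)
        else:
            results = await store.get_all_server_keys_json_for_remote(server_name)
        print(server_name, results)


asyncio.run(main())
```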
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
         # Map server_name->key_id->int. Note that the value of the int is unused.
         # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
-        for (server_name, key_id, _), key_results in cached.items():
-            results = [(result["ts_added_ms"], result) for result in key_results]
-
-            if key_id is None:
+        for (server_name, key_id), key_result in server_keys.items():
+            if not query[server_name]:
                 # all keys were requested. Just return what we have without worrying
                 # about validity
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+                if key_result:
+                    json_results.add(key_result.key_json)
                 continue
 
             miss = False
-            if not results:
+            if key_result is None:
                 miss = True
             else:
-                ts_added_ms, most_recent_result = max(results)
-                ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+                ts_added_ms = key_result.added_ts
+                ts_valid_until_ms = key_result.valid_until_ts
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
                 if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(most_recent_result["key_json"]))
+                    json_results.add(key_result.key_json)
 
             if miss and query_remote_on_cache_miss:
                 # only bother attempting to fetch keys from servers on our whitelist
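This hunk keeps the pre-existing freshness rule: a cached entry only counts as a hit if it is still valid at the requester's `minimum_valid_until_ts`; anything else marks a miss and falls through to a remote fetch. A standalone restatement of the rule (the helper name `is_cache_hit` is ours, not the PR's):

```python
from typing import Optional


def is_cache_hit(
    valid_until_ts: Optional[int],
    req_min_valid_until_ts: Optional[int],
) -> bool:
    if valid_until_ts is None:
        return False  # nothing cached at all
    if req_min_valid_until_ts is None:
        return True  # requester did not ask for a minimum validity
    return valid_until_ts >= req_min_valid_until_ts


assert is_cache_hit(2000, 1000) is True
assert is_cache_hit(500, 1000) is False   # stale: triggers a remote re-fetch
assert is_cache_hit(None, 1000) is False  # miss: key not cached at all
```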
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
 import itertools
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
-from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
 db_binary_type = memoryview
 
 
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
     @cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
         # invalidate takes a tuple corresponding to the params of
         # _get_server_keys_json. _get_server_keys_json only takes one
         # param, which is itself the 2-tuple (server_name, key_id).
-        self._get_server_keys_json.invalidate(((server_name, key_id),))
+        await self.invalidate_cache_and_stream(
+            "_get_server_keys_json", ((server_name, key_id),)
+        )
+        await self.invalidate_cache_and_stream(
+            "get_server_key_json_for_remote", (server_name, key_id)
+        )
 
     @cached()
     def _get_server_keys_json(
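`invalidate_cache_and_stream` is what deriving from `CacheInvalidationWorkerStore` buys us: it evicts the named cache locally and also writes the invalidation to the replication stream so workers evict their copies. A toy emulation of those assumed semantics (not Synapse's real implementation, which also handles transaction ordering), including the key-shape difference the comment above describes:

```python
from typing import Any, Dict, List, Tuple


class FakeStore:
    def __init__(self) -> None:
        self.caches: Dict[str, Dict[Tuple[Any, ...], Any]] = {
            "_get_server_keys_json": {},
            "get_server_key_json_for_remote": {},
        }
        self.replication_stream: List[Tuple[str, Tuple[Any, ...]]] = []

    async def invalidate_cache_and_stream(
        self, cache_name: str, keys: Tuple[Any, ...]
    ) -> None:
        self.caches[cache_name].pop(keys, None)  # evict locally
        self.replication_stream.append((cache_name, keys))  # notify workers


# Key shapes differ because _get_server_keys_json takes a single argument that
# is itself a 2-tuple, while the new cached method takes two plain arguments:
#   ("_get_server_keys_json", ((server_name, key_id),))
#   ("get_server_key_json_for_remote", (server_name, key_id))
```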
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
         return await self.db_pool.runInteraction("get_server_keys_json", _txn)
 
-    async def get_server_keys_json_for_remote(
-        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-        """Retrieve the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
-
-        Args:
-            server_keys: List of (server_name, key_id, source) triplets.
-
-        Returns:
-            A mapping from (server_name, key_id, source) triplets to a list of dicts
-        """
+    @cached()
+    def get_server_key_json_for_remote(
+        self,
+        server_name: str,
+        key_id: str,
+    ) -> Optional[FetchKeyResultForRemote]:
+        raise NotImplementedError()
 
-        def _get_server_keys_json_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self.db_pool.simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
+    @cachedList(
+        cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+    )
+    async def get_server_keys_json_for_remote(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+        """Fetch the cached keys for the given server/key IDs.
+
+        If we have multiple entries for a given key ID, returns the most recent.
+        """
+        rows = await self.db_pool.simple_select_many_batch(
+            table="server_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
+        )
+
+        if not rows:
+            return {}
+
+        # We sort the rows so that the most recently added entry is picked up.
+        rows.sort(key=lambda r: r["ts_added_ms"])
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
 
-        return await self.db_pool.runInteraction(
-            "get_server_keys_json", _get_server_keys_json_txn
-        )
+    async def get_all_server_keys_json_for_remote(
+        self,
+        server_name: str,
+    ) -> Dict[str, FetchKeyResultForRemote]:
+        """Fetch the cached keys for the given server.
+
+        If we have multiple entries for a given key ID, returns the most recent.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="server_keys_json",
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
+        )
+
+        if not rows:
+            return {}
+
+        rows.sort(key=lambda r: r["ts_added_ms"])
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
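The `@cached()` method that just raises `NotImplementedError` is never called; it exists only to own the per-(server, key) cache, and `@cachedList` funnels batch lookups through it, invoking the decorated function solely for the IDs that miss and writing the results back. A simplified emulation of that contract (our own sketch, not Synapse's descriptor machinery):

```python
from typing import Awaitable, Callable, Dict, Iterable, Optional, Tuple

Fetcher = Callable[[str, Iterable[str]], Awaitable[Dict[str, Optional[bytes]]]]


class BatchedCache:
    def __init__(self, fetch_many: Fetcher) -> None:
        self._cache: Dict[Tuple[str, str], Optional[bytes]] = {}
        self._fetch_many = fetch_many

    async def get_many(
        self, server_name: str, key_ids: Iterable[str]
    ) -> Dict[str, Optional[bytes]]:
        wanted = list(key_ids)
        # Only the IDs absent from the per-ID cache hit the database.
        missing = [k for k in wanted if (server_name, k) not in self._cache]
        if missing:
            fetched = await self._fetch_many(server_name, missing)
            for key_id in missing:
                # The batch call may omit IDs with no DB row; cache None for those.
                self._cache[(server_name, key_id)] = fetched.get(key_id)
        return {k: self._cache[(server_name, k)] for k in wanted}
```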
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
 class FetchKeyResult:
     verify_key: VerifyKey  # the key itself
     valid_until_ts: int  # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+    key_json: bytes  # the full key JSON
+    valid_until_ts: int  # how long we can use this key for, in milliseconds.
+    added_ts: int  # When we added this key, in milliseconds.
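Because the class is declared with `frozen=True`, attrs makes instances immutable, which guards the values stored in the new caches against accidental mutation by callers. A quick usage sketch with invented values:

```python
import attr


@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
    key_json: bytes
    valid_until_ts: int
    added_ts: int


res = FetchKeyResultForRemote(
    key_json=b'{"server_name": "remote.example.com"}',  # invented sample
    valid_until_ts=2_000_000_000_000,
    added_ts=1_690_000_000_000,
)
assert res.valid_until_ts > res.added_ts
# frozen=True: `res.added_ts = 0` would raise attr.exceptions.FrozenInstanceError.
```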
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], SERVER_NAME)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
         # we expect it to be encoded as canonical json *before* it hits the db
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_get_multiple_keys_from_perspectives(self) -> None:
         """Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_invalid_perspectives_responses(self) -> None:
         """Check that invalid responses from the perspectives server are rejected"""