Selaa lähdekoodia

Add cache to `get_server_keys_json_for_remote` (#16123)

tags/v1.91.0rc1
Erik Johnston 9 kuukautta sitten
committed by GitHub
vanhempi
commit
0aba4a4eaa
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
5 muutettua tiedostoa, joissa 144 lisäystä ja 101 poistoa
  1. +1
    -0
      changelog.d/16123.misc
  2. +25
    -19
      synapse/rest/key/v2/remote_key_resource.py
  3. +88
    -44
      synapse/storage/databases/main/keys.py
  4. +7
    -0
      synapse/storage/keys.py
  5. +23
    -38
      tests/crypto/test_keyring.py

+ 1
- 0
changelog.d/16123.misc Näytä tiedosto

@@ -0,0 +1 @@
Add cache to `get_server_keys_json_for_remote`.

+ 25
- 19
synapse/rest/key/v2/remote_key_resource.py Näytä tiedosto

@@ -14,7 +14,7 @@

import logging
import re
from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple

from signedjson.sign import sign_json

@@ -27,6 +27,7 @@ from synapse.http.servlet import (
parse_integer,
parse_json_object_from_request,
)
from synapse.storage.keys import FetchKeyResultForRemote
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
) -> JsonDict:
logger.info("Handling query for keys %r", query)

store_queries = []
server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
for server_name, key_ids in query.items():
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
if key_ids:
results: Mapping[
str, Optional[FetchKeyResultForRemote]
] = await self.store.get_server_keys_json_for_remote(
server_name, key_ids
)
else:
results = await self.store.get_all_server_keys_json_for_remote(
server_name
)

cached = await self.store.get_server_keys_json_for_remote(store_queries)
server_keys.update(
((server_name, key_id), res) for key_id, res in results.items()
)

json_results: Set[bytes] = set()

@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
# Map server_name->key_id->int. Note that the value of the int is unused.
# XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]

if key_id is None:
for (server_name, key_id), key_result in server_keys.items():
if not query[server_name]:
# all keys were requested. Just return what we have without worrying
# about validity
for _, result in results:
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
if key_result:
json_results.add(key_result.key_json)
continue

miss = False
if not results:
if key_result is None:
miss = True
else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
ts_added_ms = key_result.added_ts
ts_valid_until_ms = key_result.valid_until_ts
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
ts_valid_until_ms,
time_now_ms,
)
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
json_results.add(key_result.key_json)

if miss and query_remote_on_cache_miss:
# only bother attempting to fetch keys from servers on our whitelist


+ 88
- 44
synapse/storage/databases/main/keys.py Näytä tiedosto

@@ -16,14 +16,13 @@
import itertools
import json
import logging
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from typing import Dict, Iterable, Mapping, Optional, Tuple

from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.keys import FetchKeyResult
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
db_binary_type = memoryview


class KeyStore(SQLBaseStore):
class KeyStore(CacheInvalidationWorkerStore):
"""Persistence for signature verification keys"""

@cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
# invalidate takes a tuple corresponding to the params of
# _get_server_keys_json. _get_server_keys_json only takes one
# param, which is itself the 2-tuple (server_name, key_id).
self._get_server_keys_json.invalidate(((server_name, key_id),))
await self.invalidate_cache_and_stream(
"_get_server_keys_json", ((server_name, key_id),)
)
await self.invalidate_cache_and_stream(
"get_server_key_json_for_remote", (server_name, key_id)
)

@cached()
def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):

return await self.db_pool.runInteraction("get_server_keys_json", _txn)

async def get_server_keys_json_for_remote(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
"""Retrieve the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
The JSON is returned as a byte array so that it can be efficiently
used in an HTTP response.
@cached()
def get_server_key_json_for_remote(
    self,
    server_name: str,
    key_id: str,
) -> Optional[FetchKeyResultForRemote]:
    # Per-(server_name, key_id) cache entry point. This looks like a stub
    # because it is one: the batched `get_server_keys_json_for_remote`
    # below is declared with
    # @cachedList(cached_method_name="get_server_key_json_for_remote", ...)
    # and reads/populates this cache on its behalf, so this body is
    # deliberately unreachable.
    raise NotImplementedError()

Args:
server_keys: List of (server_name, key_id, source) triplets.
@cachedList(
cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
)
async def get_server_keys_json_for_remote(
self, server_name: str, key_ids: Iterable[str]
) -> Dict[str, Optional[FetchKeyResultForRemote]]:
"""Fetch the cached keys for the given server/key IDs.

Returns:
A mapping from (server_name, key_id, source) triplets to a list of dicts
If we have multiple entries for a given key ID, returns the most recent.
"""
rows = await self.db_pool.simple_select_many_batch(
table="server_keys_json",
column="key_id",
iterable=key_ids,
keyvalues={"server_name": server_name},
retcols=(
"key_id",
"from_server",
"ts_added_ms",
"ts_valid_until_ms",
"key_json",
),
desc="get_server_keys_json_for_remote",
)

def _get_server_keys_json_txn(
    txn: LoggingTransaction,
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
    # Runs one SELECT per requested (server_name, key_id, from_server)
    # triplet. A `None` key_id or from_server acts as a wildcard: it is
    # simply omitted from the WHERE clause, matching every value.
    # Each triplet maps to the (possibly empty) list of matching rows.
    results = {}
    for server_name, key_id, from_server in server_keys:
        keyvalues = {"server_name": server_name}
        if key_id is not None:
            keyvalues["key_id"] = key_id
        if from_server is not None:
            keyvalues["from_server"] = from_server
        rows = self.db_pool.simple_select_list_txn(
            txn,
            "server_keys_json",
            keyvalues=keyvalues,
            retcols=(
                "key_id",
                "from_server",
                "ts_added_ms",
                "ts_valid_until_ms",
                "key_json",
            ),
        )
        results[(server_name, key_id, from_server)] = rows
    return results
if not rows:
return {}

# We sort the rows so that the most recently added entry is picked up.
rows.sort(key=lambda r: r["ts_added_ms"])

return {
row["key_id"]: FetchKeyResultForRemote(
# Cast to bytes since postgresql returns a memoryview.
key_json=bytes(row["key_json"]),
valid_until_ts=row["ts_valid_until_ms"],
added_ts=row["ts_added_ms"],
)
for row in rows
}

return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
async def get_all_server_keys_json_for_remote(
    self,
    server_name: str,
) -> Dict[str, FetchKeyResultForRemote]:
    """Fetch all the cached keys for the given server.

    If we have multiple entries for a given key ID, returns the most
    recently added one.

    Args:
        server_name: The server whose cached keys to return.

    Returns:
        A map from key ID to the most recent cached key for that ID.
        Empty if we have no cached keys for the server.
    """
    rows = await self.db_pool.simple_select_list(
        table="server_keys_json",
        keyvalues={"server_name": server_name},
        retcols=(
            "key_id",
            "from_server",
            "ts_added_ms",
            "ts_valid_until_ms",
            "key_json",
        ),
        # Fix: was "get_server_keys_json_for_remote", copy-pasted from the
        # batched variant above, which made the two queries
        # indistinguishable in DB transaction metrics/logging.
        desc="get_all_server_keys_json_for_remote",
    )

    if not rows:
        return {}

    # Sort ascending by insertion time so that, when the dict
    # comprehension below overwrites duplicate key IDs, the most
    # recently added row wins.
    rows.sort(key=lambda r: r["ts_added_ms"])

    return {
        row["key_id"]: FetchKeyResultForRemote(
            # Cast to bytes since postgresql returns a memoryview.
            key_json=bytes(row["key_json"]),
            valid_until_ts=row["ts_valid_until_ms"],
            added_ts=row["ts_added_ms"],
        )
        for row in rows
    }

+ 7
- 0
synapse/storage/keys.py Näytä tiedosto

@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
class FetchKeyResult:
verify_key: VerifyKey # the key itself
valid_until_ts: int # how long we can use this key for


@attr.s(slots=True, frozen=True, auto_attribs=True)
class FetchKeyResultForRemote:
    """A cached server signing key, in the form we re-serve it to requesters.

    Unlike `FetchKeyResult` above, this keeps the raw signed JSON bytes so
    they can be written into an HTTP response without re-serialising.
    """

    key_json: bytes  # the full key JSON
    valid_until_ts: int  # how long we can use this key for, in milliseconds.
    added_ts: int  # When we added this key, in milliseconds.

+ 23
- 38
tests/crypto/test_keyring.py Näytä tiedosto

@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], SERVER_NAME)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

# we expect it to be encoded as canonical json *before* it hits the db
self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")

# check that the perspectives store is correctly updated
lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
[lookup_triplet]
SERVER_NAME, [testverifykey_id]
)
)
res_keys = key_json[lookup_triplet]
self.assertEqual(len(res_keys), 1)
res = res_keys[0]
self.assertEqual(res["key_id"], testverifykey_id)
self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)

self.assertEqual(
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
res = key_json[testverifykey_id]
self.assertIsNotNone(res)
assert res is not None
self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)

self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))

def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""


Ladataan…
Peruuta
Tallenna