@@ -0,0 +1 @@ | |||
Improve the performance of claiming encryption keys. |
@@ -659,6 +659,20 @@ class E2eKeysHandler: | |||
timeout: Optional[int], | |||
always_include_fallback_keys: bool, | |||
) -> JsonDict: | |||
""" | |||
Args: | |||
query: A chain of maps from (user_id, device_id, algorithm) to the requested | |||
number of keys to claim. | |||
user: The user who is claiming these keys. | |||
timeout: How long to wait for any federation key claim requests before | |||
giving up. | |||
always_include_fallback_keys: always include a fallback key for local users' | |||
devices, even if we managed to claim a one-time-key. | |||
Returns: a heterogeneous dict with two keys: | |||
one_time_keys: chain of maps user ID -> device ID -> key ID -> key. | |||
failures: map from remote destination to a JsonDict describing the error. | |||
""" | |||
local_query: List[Tuple[str, str, str, int]] = [] | |||
remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {} | |||
@@ -420,6 +420,16 @@ class LoggingTransaction: | |||
self._do_execute(self.txn.execute, sql, parameters) | |||
def executemany(self, sql: str, *args: Any) -> None: | |||
"""Repeatedly execute the same piece of SQL with different parameters. | |||
See https://peps.python.org/pep-0249/#executemany. Note in particular that | |||
> Use of this method for an operation which produces one or more result sets | |||
> constitutes undefined behavior | |||
so you can't use this for e.g. a SELECT, an UPDATE ... RETURNING, or a | |||
DELETE FROM... RETURNING. | |||
""" | |||
# TODO: we should add a type for *args here. Looking at Cursor.executemany | |||
# and DBAPI2 it ought to be Sequence[_Parameter], but we pass in | |||
# Iterable[Iterable[Any]] in execute_batch and execute_values above, which mypy | |||
@@ -24,6 +24,7 @@ from typing import ( | |||
Mapping, | |||
Optional, | |||
Sequence, | |||
Set, | |||
Tuple, | |||
Union, | |||
cast, | |||
@@ -1260,6 +1261,65 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker | |||
Returns: | |||
A map of user ID -> a map device ID -> a map of key ID -> JSON. | |||
""" | |||
if isinstance(self.database_engine, PostgresEngine): | |||
return await self.db_pool.runInteraction( | |||
"_claim_e2e_fallback_keys_bulk", | |||
self._claim_e2e_fallback_keys_bulk_txn, | |||
query_list, | |||
db_autocommit=True, | |||
) | |||
# Use an UPDATE FROM... RETURNING combined with a VALUES block to do | |||
# everything in one query. Note: this is also supported in SQLite 3.33.0, | |||
# (see https://www.sqlite.org/lang_update.html#update_from), but we do not | |||
# have an equivalent of psycopg2's execute_values to do this in one query. | |||
else: | |||
return await self._claim_e2e_fallback_keys_simple(query_list) | |||
def _claim_e2e_fallback_keys_bulk_txn(
    self,
    txn: LoggingTransaction,
    query_list: Iterable[Tuple[str, str, str, bool]],
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
    """Claim fallback keys for many devices in one go (Postgres implementation).

    Args:
        txn: The transaction in which to run the claim.
        query_list: One (user_id, device_id, algorithm, mark_as_used) tuple per
            requested claim.

    Returns:
        A map of user ID -> device ID -> "<algorithm>:<key ID>" -> decoded key
        JSON, with one entry for each row returned by the UPDATE.

    Safe to autocommit: the claim is a single UPDATE ... RETURNING statement.
    NOTE(review): psycopg2's execute_values sends its arguments in pages, so a
    very long query_list may become more than one statement -- confirm that
    this is still fine under autocommit.
    """
    # Feed every claim in through a VALUES block, flag the matching fallback
    # keys as used where requested, and read the keys back in the same statement.
    sql = """
        WITH claims(user_id, device_id, algorithm, mark_as_used) AS (
            VALUES ?
        )
        UPDATE e2e_fallback_keys_json k
        SET used = used OR mark_as_used
        FROM claims
        WHERE (k.user_id, k.device_id, k.algorithm) = (claims.user_id, claims.device_id, claims.algorithm)
        RETURNING k.user_id, k.device_id, k.algorithm, k.key_id, k.key_json;
    """
    rows = cast(
        List[Tuple[str, str, str, str, str]],
        txn.execute_values(sql, query_list),
    )

    claimed: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
    invalidated_devices: Set[Tuple[str, str]] = set()
    for user_id, device_id, algorithm, key_id, key_json in rows:
        keys_for_device = claimed.setdefault(user_id, {}).setdefault(device_id, {})
        keys_for_device[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)

        # A device may have had several keys claimed; invalidate its cache
        # entry only the first time we come across it.
        device = (user_id, device_id)
        if device not in invalidated_devices:
            invalidated_devices.add(device)
            self._invalidate_cache_and_stream(
                txn, self.get_e2e_unused_fallback_key_types, device
            )

    return claimed
async def _claim_e2e_fallback_keys_simple( | |||
self, | |||
query_list: Iterable[Tuple[str, str, str, bool]], | |||
) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: | |||
"""Naive, inefficient implementation of claim_e2e_fallback_keys for SQLite.""" | |||
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} | |||
for user_id, device_id, algorithm, mark_as_used in query_list: | |||
row = await self.db_pool.simple_select_one( | |||
@@ -322,6 +322,83 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): | |||
{"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, | |||
) | |||
def test_fallback_key_bulk(self) -> None:
    """Bulk variant of test_fallback_key: a single handler call claims keys for several users."""
    hostname = self.hs.hostname
    alice = f"@alice:{hostname}"
    brian = f"@brian:{hostname}"
    chris = f"@chris:{hostname}"

    # Three users with two devices each. Every device holds one fallback key,
    # and the algorithm number matches the trailing digit of the device ID.
    uploaded = {
        alice: {
            "alice_dev_1": {"alg1:k1": "fallback_key1"},
            "alice_dev_2": {"alg2:k2": "fallback_key2"},
        },
        brian: {
            "brian_dev_1": {"alg1:k3": "fallback_key3"},
            "brian_dev_2": {"alg2:k4": "fallback_key4"},
        },
        chris: {
            "chris_dev_1": {"alg1:k5": "fallback_key5"},
            "chris_dev_2": {"alg2:k6": "fallback_key6"},
        },
    }

    def unused_key_types(user_id: str, device_id: str) -> object:
        """Ask the store which fallback key algorithms are still unused for a device."""
        return self.get_success(
            self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
        )

    def algorithm_of(device_id: str) -> str:
        """Derive the uploaded algorithm name from the device ID's last character."""
        return f"alg{device_id[-1]}"

    for user_id, devices in uploaded.items():
        for device_id, keys in devices.items():
            upload = self.handler.upload_keys_for_user(
                user_id,
                device_id,
                {"fallback_keys": keys},
            )
            self.get_success(upload)

    # Before any claim, every device reports its fallback key as unused.
    for user_id, devices in uploaded.items():
        for device_id in devices:
            self.assertEqual(
                unused_key_types(user_id, device_id), [algorithm_of(device_id)]
            )

    # Claim a fallback key for exactly one device of each user.
    claim_request = {
        alice: {"alice_dev_1": {"alg1": 1}},
        brian: {"brian_dev_2": {"alg2": 1}},
        chris: {"chris_dev_2": {"alg2": 1}},
    }
    claim_res = self.get_success(
        self.handler.claim_one_time_keys(
            claim_request,
            self.requester,
            timeout=None,
            always_include_fallback_keys=False,
        )
    )
    expected_claims = {
        alice: {"alice_dev_1": {"alg1:k1": "fallback_key1"}},
        brian: {"brian_dev_2": {"alg2:k4": "fallback_key4"}},
        chris: {"chris_dev_2": {"alg2:k6": "fallback_key6"}},
    }
    self.assertEqual(
        claim_res,
        {"failures": {}, "one_time_keys": expected_claims},
    )

    # Claimed fallback keys should no longer show up as unused, while the
    # keys of untouched devices should still be reported as unused.
    for user_id, devices in uploaded.items():
        for device_id in devices:
            was_claimed = device_id in expected_claims[user_id]
            expected_unused = [] if was_claimed else [algorithm_of(device_id)]
            self.assertEqual(unused_key_types(user_id, device_id), expected_unused)
def test_fallback_key_always_returned(self) -> None: | |||
local_user = "@boris:" + self.hs.hostname | |||
device_id = "xyz" | |||