|
|
@@ -1111,7 +1111,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker |
|
|
|
... |
|
|
|
|
|
|
|
async def claim_e2e_one_time_keys( |
|
|
|
self, query_list: Iterable[Tuple[str, str, str, int]] |
|
|
|
self, query_list: Collection[Tuple[str, str, str, int]] |
|
|
|
) -> Tuple[ |
|
|
|
Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] |
|
|
|
]: |
|
|
@@ -1121,131 +1121,63 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker |
|
|
|
query_list: An iterable of tuples of (user ID, device ID, algorithm). |
|
|
|
|
|
|
|
Returns: |
|
|
|
A tuple pf: |
|
|
|
A tuple (results, missing) of: |
|
|
|
A map of user ID -> a map device ID -> a map of key ID -> JSON. |
|
|
|
|
|
|
|
A copy of the input which has not been fulfilled. |
|
|
|
A copy of the input which has not been fulfilled. The returned counts |
|
|
|
may be less than the input counts. In this case, the returned counts |
|
|
|
are the number of claims that were not fulfilled. |
|
|
|
""" |
|
|
|
|
|
|
|
@trace |
|
|
|
def _claim_e2e_one_time_key_simple( |
|
|
|
txn: LoggingTransaction, |
|
|
|
user_id: str, |
|
|
|
device_id: str, |
|
|
|
algorithm: str, |
|
|
|
count: int, |
|
|
|
) -> List[Tuple[str, str]]: |
|
|
|
"""Claim OTK for device for DBs that don't support RETURNING. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A tuple of key name (algorithm + key ID) and key JSON, if an |
|
|
|
OTK was found. |
|
|
|
""" |
|
|
|
|
|
|
|
sql = """ |
|
|
|
SELECT key_id, key_json FROM e2e_one_time_keys_json |
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
LIMIT ? |
|
|
|
""" |
|
|
|
|
|
|
|
txn.execute(sql, (user_id, device_id, algorithm, count)) |
|
|
|
otk_rows = list(txn) |
|
|
|
if not otk_rows: |
|
|
|
return [] |
|
|
|
|
|
|
|
self.db_pool.simple_delete_many_txn( |
|
|
|
txn, |
|
|
|
table="e2e_one_time_keys_json", |
|
|
|
column="key_id", |
|
|
|
values=[otk_row[0] for otk_row in otk_rows], |
|
|
|
keyvalues={ |
|
|
|
"user_id": user_id, |
|
|
|
"device_id": device_id, |
|
|
|
"algorithm": algorithm, |
|
|
|
}, |
|
|
|
) |
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
) |
|
|
|
|
|
|
|
return [ |
|
|
|
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows |
|
|
|
] |
|
|
|
|
|
|
|
@trace |
|
|
|
def _claim_e2e_one_time_key_returning( |
|
|
|
txn: LoggingTransaction, |
|
|
|
user_id: str, |
|
|
|
device_id: str, |
|
|
|
algorithm: str, |
|
|
|
count: int, |
|
|
|
) -> List[Tuple[str, str]]: |
|
|
|
"""Claim OTK for device for DBs that support RETURNING. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A tuple of key name (algorithm + key ID) and key JSON, if an |
|
|
|
OTK was found. |
|
|
|
""" |
|
|
|
|
|
|
|
# We can use RETURNING to do the fetch and DELETE in once step. |
|
|
|
sql = """ |
|
|
|
DELETE FROM e2e_one_time_keys_json |
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
AND key_id IN ( |
|
|
|
SELECT key_id FROM e2e_one_time_keys_json |
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
LIMIT ? |
|
|
|
) |
|
|
|
RETURNING key_id, key_json |
|
|
|
""" |
|
|
|
|
|
|
|
txn.execute( |
|
|
|
sql, |
|
|
|
(user_id, device_id, algorithm, user_id, device_id, algorithm, count), |
|
|
|
) |
|
|
|
otk_rows = list(txn) |
|
|
|
if not otk_rows: |
|
|
|
return [] |
|
|
|
|
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
) |
|
|
|
|
|
|
|
return [ |
|
|
|
(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows |
|
|
|
] |
|
|
|
|
|
|
|
results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} |
|
|
|
missing: List[Tuple[str, str, str, int]] = [] |
|
|
|
for user_id, device_id, algorithm, count in query_list: |
|
|
|
if self.database_engine.supports_returning: |
|
|
|
# If we support RETURNING clause we can use a single query that |
|
|
|
# allows us to use autocommit mode. |
|
|
|
_claim_e2e_one_time_key = _claim_e2e_one_time_key_returning |
|
|
|
db_autocommit = True |
|
|
|
else: |
|
|
|
_claim_e2e_one_time_key = _claim_e2e_one_time_key_simple |
|
|
|
db_autocommit = False |
|
|
|
if isinstance(self.database_engine, PostgresEngine): |
|
|
|
# If we can use execute_values we can use a single batch query |
|
|
|
# in autocommit mode. |
|
|
|
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {} |
|
|
|
for user_id, device_id, algorithm, count in query_list: |
|
|
|
unfulfilled_claim_counts[user_id, device_id, algorithm] = count |
|
|
|
|
|
|
|
claim_rows = await self.db_pool.runInteraction( |
|
|
|
bulk_claims = await self.db_pool.runInteraction( |
|
|
|
"claim_e2e_one_time_keys", |
|
|
|
_claim_e2e_one_time_key, |
|
|
|
user_id, |
|
|
|
device_id, |
|
|
|
algorithm, |
|
|
|
count, |
|
|
|
db_autocommit=db_autocommit, |
|
|
|
self._claim_e2e_one_time_keys_bulk, |
|
|
|
query_list, |
|
|
|
db_autocommit=True, |
|
|
|
) |
|
|
|
if claim_rows: |
|
|
|
|
|
|
|
for user_id, device_id, algorithm, key_id, key_json in bulk_claims: |
|
|
|
device_results = results.setdefault(user_id, {}).setdefault( |
|
|
|
device_id, {} |
|
|
|
) |
|
|
|
for claim_row in claim_rows: |
|
|
|
device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) |
|
|
|
device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) |
|
|
|
unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1 |
|
|
|
|
|
|
|
# Did we get enough OTKs? |
|
|
|
count -= len(claim_rows) |
|
|
|
if count: |
|
|
|
missing.append((user_id, device_id, algorithm, count)) |
|
|
|
missing = [ |
|
|
|
(user, device, alg, count) |
|
|
|
for (user, device, alg), count in unfulfilled_claim_counts.items() |
|
|
|
if count > 0 |
|
|
|
] |
|
|
|
else: |
|
|
|
for user_id, device_id, algorithm, count in query_list: |
|
|
|
claim_rows = await self.db_pool.runInteraction( |
|
|
|
"claim_e2e_one_time_keys", |
|
|
|
self._claim_e2e_one_time_key_simple, |
|
|
|
user_id, |
|
|
|
device_id, |
|
|
|
algorithm, |
|
|
|
count, |
|
|
|
db_autocommit=False, |
|
|
|
) |
|
|
|
if claim_rows: |
|
|
|
device_results = results.setdefault(user_id, {}).setdefault( |
|
|
|
device_id, {} |
|
|
|
) |
|
|
|
for claim_row in claim_rows: |
|
|
|
device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) |
|
|
|
# Did we get enough OTKs? |
|
|
|
count -= len(claim_rows) |
|
|
|
if count: |
|
|
|
missing.append((user_id, device_id, algorithm, count)) |
|
|
|
|
|
|
|
return results, missing |
|
|
|
|
|
|
@@ -1362,6 +1294,99 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker |
|
|
|
|
|
|
|
return results |
|
|
|
|
|
|
|
@trace |
|
|
|
def _claim_e2e_one_time_key_simple( |
|
|
|
self, |
|
|
|
txn: LoggingTransaction, |
|
|
|
user_id: str, |
|
|
|
device_id: str, |
|
|
|
algorithm: str, |
|
|
|
count: int, |
|
|
|
) -> List[Tuple[str, str]]: |
|
|
|
"""Claim OTK for device for DBs that don't support RETURNING. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A tuple of key name (algorithm + key ID) and key JSON, if an |
|
|
|
OTK was found. |
|
|
|
""" |
|
|
|
|
|
|
|
sql = """ |
|
|
|
SELECT key_id, key_json FROM e2e_one_time_keys_json |
|
|
|
WHERE user_id = ? AND device_id = ? AND algorithm = ? |
|
|
|
LIMIT ? |
|
|
|
""" |
|
|
|
|
|
|
|
txn.execute(sql, (user_id, device_id, algorithm, count)) |
|
|
|
otk_rows = list(txn) |
|
|
|
if not otk_rows: |
|
|
|
return [] |
|
|
|
|
|
|
|
self.db_pool.simple_delete_many_txn( |
|
|
|
txn, |
|
|
|
table="e2e_one_time_keys_json", |
|
|
|
column="key_id", |
|
|
|
values=[otk_row[0] for otk_row in otk_rows], |
|
|
|
keyvalues={ |
|
|
|
"user_id": user_id, |
|
|
|
"device_id": device_id, |
|
|
|
"algorithm": algorithm, |
|
|
|
}, |
|
|
|
) |
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
) |
|
|
|
|
|
|
|
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] |
|
|
|
|
|
|
|
@trace |
|
|
|
def _claim_e2e_one_time_keys_bulk( |
|
|
|
self, |
|
|
|
txn: LoggingTransaction, |
|
|
|
query_list: Iterable[Tuple[str, str, str, int]], |
|
|
|
) -> List[Tuple[str, str, str, str, str]]: |
|
|
|
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING. |
|
|
|
|
|
|
|
Args: |
|
|
|
query_list: Collection of tuples (user_id, device_id, algorithm, count) |
|
|
|
as passed to claim_e2e_one_time_keys. |
|
|
|
|
|
|
|
Returns: |
|
|
|
A list of tuples (user_id, device_id, algorithm, key_id, key_json) |
|
|
|
for each OTK claimed. |
|
|
|
""" |
|
|
|
sql = """ |
|
|
|
WITH claims(user_id, device_id, algorithm, claim_count) AS ( |
|
|
|
VALUES ? |
|
|
|
), ranked_keys AS ( |
|
|
|
SELECT |
|
|
|
user_id, device_id, algorithm, key_id, claim_count, |
|
|
|
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r |
|
|
|
FROM e2e_one_time_keys_json |
|
|
|
JOIN claims USING (user_id, device_id, algorithm) |
|
|
|
) |
|
|
|
DELETE FROM e2e_one_time_keys_json k |
|
|
|
WHERE (user_id, device_id, algorithm, key_id) IN ( |
|
|
|
SELECT user_id, device_id, algorithm, key_id |
|
|
|
FROM ranked_keys |
|
|
|
WHERE r <= claim_count |
|
|
|
) |
|
|
|
RETURNING user_id, device_id, algorithm, key_id, key_json; |
|
|
|
""" |
|
|
|
otk_rows = cast( |
|
|
|
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) |
|
|
|
) |
|
|
|
|
|
|
|
seen_user_device: Set[Tuple[str, str]] = set() |
|
|
|
for user_id, device_id, _, _, _ in otk_rows: |
|
|
|
if (user_id, device_id) in seen_user_device: |
|
|
|
continue |
|
|
|
seen_user_device.add((user_id, device_id)) |
|
|
|
self._invalidate_cache_and_stream( |
|
|
|
txn, self.count_e2e_one_time_keys, (user_id, device_id) |
|
|
|
) |
|
|
|
|
|
|
|
return otk_rows |
|
|
|
|
|
|
|
|
|
|
|
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): |
|
|
|
def __init__( |
|
|
|