@@ -0,0 +1 @@ | |||
Add experimental support for [MSC3391](https://github.com/matrix-org/matrix-spec-proposals/pull/3391) (removing account data). |
@@ -102,6 +102,8 @@ experimental_features: | |||
{% endif %} | |||
# Filtering /messages by relation type. | |||
msc3874_enabled: true | |||
# Enable removing account data support | |||
msc3391_enabled: true | |||
server_notices: | |||
system_mxid_localpart: _server | |||
@@ -190,7 +190,7 @@ fi | |||
extra_test_args=() | |||
test_tags="synapse_blacklist,msc3787,msc3874" | |||
test_tags="synapse_blacklist,msc3787,msc3874,msc3391" | |||
# All environment variables starting with PASS_ will be shared. | |||
# (The prefix is stripped off before reaching the container.) | |||
@@ -136,3 +136,6 @@ class ExperimentalConfig(Config): | |||
# Enable room version (and thus applicable push rules from MSC3931/3932) | |||
version_id = RoomVersions.MSC1767v10.identifier | |||
KNOWN_ROOM_VERSIONS[version_id] = RoomVersions.MSC1767v10 | |||
# MSC3391: Removing account data. | |||
self.msc3391_enabled = experimental.get("msc3391_enabled", False) |
@@ -17,10 +17,12 @@ import random | |||
from typing import TYPE_CHECKING, Awaitable, Callable, Collection, List, Optional, Tuple | |||
from synapse.replication.http.account_data import ( | |||
ReplicationAddRoomAccountDataRestServlet, | |||
ReplicationAddTagRestServlet, | |||
ReplicationAddUserAccountDataRestServlet, | |||
ReplicationRemoveRoomAccountDataRestServlet, | |||
ReplicationRemoveTagRestServlet, | |||
ReplicationRoomAccountDataRestServlet, | |||
ReplicationUserAccountDataRestServlet, | |||
ReplicationRemoveUserAccountDataRestServlet, | |||
) | |||
from synapse.streams import EventSource | |||
from synapse.types import JsonDict, StreamKeyType, UserID | |||
@@ -41,8 +43,18 @@ class AccountDataHandler: | |||
self._instance_name = hs.get_instance_name() | |||
self._notifier = hs.get_notifier() | |||
self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) | |||
self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) | |||
self._add_user_data_client = ( | |||
ReplicationAddUserAccountDataRestServlet.make_client(hs) | |||
) | |||
self._remove_user_data_client = ( | |||
ReplicationRemoveUserAccountDataRestServlet.make_client(hs) | |||
) | |||
self._add_room_data_client = ( | |||
ReplicationAddRoomAccountDataRestServlet.make_client(hs) | |||
) | |||
self._remove_room_data_client = ( | |||
ReplicationRemoveRoomAccountDataRestServlet.make_client(hs) | |||
) | |||
self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) | |||
self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) | |||
self._account_data_writers = hs.config.worker.writers.account_data | |||
@@ -112,7 +124,7 @@ class AccountDataHandler: | |||
return max_stream_id | |||
else: | |||
response = await self._room_data_client( | |||
response = await self._add_room_data_client( | |||
instance_name=random.choice(self._account_data_writers), | |||
user_id=user_id, | |||
room_id=room_id, | |||
@@ -121,15 +133,59 @@ class AccountDataHandler: | |||
) | |||
return response["max_stream_id"] | |||
async def remove_account_data_for_room( | |||
self, user_id: str, room_id: str, account_data_type: str | |||
) -> Optional[int]: | |||
""" | |||
Deletes the room account data for the given user and account data type. | |||
"Deleting" account data merely means setting the content of the account data | |||
to an empty JSON object: {}. | |||
Args: | |||
user_id: The user ID to remove room account data for. | |||
room_id: The room ID to target. | |||
account_data_type: The account data type to remove. | |||
Returns: | |||
The maximum stream ID, or None if the room account data item did not exist. | |||
""" | |||
if self._instance_name in self._account_data_writers: | |||
max_stream_id = await self._store.remove_account_data_for_room( | |||
user_id, room_id, account_data_type | |||
) | |||
if max_stream_id is None: | |||
# The referenced account data did not exist, so no delete occurred. | |||
return None | |||
self._notifier.on_new_event( | |||
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] | |||
) | |||
# Notify Synapse modules that the content of the type has changed to an | |||
# empty dictionary. | |||
await self._notify_modules(user_id, room_id, account_data_type, {}) | |||
return max_stream_id | |||
else: | |||
response = await self._remove_room_data_client( | |||
instance_name=random.choice(self._account_data_writers), | |||
user_id=user_id, | |||
room_id=room_id, | |||
account_data_type=account_data_type, | |||
content={}, | |||
) | |||
return response["max_stream_id"] | |||
async def add_account_data_for_user( | |||
self, user_id: str, account_data_type: str, content: JsonDict | |||
) -> int: | |||
"""Add some global account_data for a user. | |||
Args: | |||
user_id: The user to add a tag for. | |||
user_id: The user to add some account data for. | |||
account_data_type: The type of account_data to add. | |||
content: A json object to associate with the tag. | |||
content: The content json dictionary. | |||
Returns: | |||
The maximum stream ID. | |||
@@ -148,7 +204,7 @@ class AccountDataHandler: | |||
return max_stream_id | |||
else: | |||
response = await self._user_data_client( | |||
response = await self._add_user_data_client( | |||
instance_name=random.choice(self._account_data_writers), | |||
user_id=user_id, | |||
account_data_type=account_data_type, | |||
@@ -156,6 +212,45 @@ class AccountDataHandler: | |||
) | |||
return response["max_stream_id"] | |||
async def remove_account_data_for_user( | |||
self, user_id: str, account_data_type: str | |||
) -> Optional[int]: | |||
"""Removes a piece of global account_data for a user. | |||
Args: | |||
user_id: The user to remove account data for. | |||
account_data_type: The type of account_data to remove. | |||
Returns: | |||
The maximum stream ID, or None if the room account data item did not exist. | |||
""" | |||
if self._instance_name in self._account_data_writers: | |||
max_stream_id = await self._store.remove_account_data_for_user( | |||
user_id, account_data_type | |||
) | |||
if max_stream_id is None: | |||
# The referenced account data did not exist, so no delete occurred. | |||
return None | |||
self._notifier.on_new_event( | |||
StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id] | |||
) | |||
# Notify Synapse modules that the content of the type has changed to an | |||
# empty dictionary. | |||
await self._notify_modules(user_id, None, account_data_type, {}) | |||
return max_stream_id | |||
else: | |||
response = await self._remove_user_data_client( | |||
instance_name=random.choice(self._account_data_writers), | |||
user_id=user_id, | |||
account_data_type=account_data_type, | |||
content={}, | |||
) | |||
return response["max_stream_id"] | |||
async def add_tag_to_room( | |||
self, user_id: str, room_id: str, tag: str, content: JsonDict | |||
) -> int: | |||
@@ -28,7 +28,7 @@ if TYPE_CHECKING: | |||
logger = logging.getLogger(__name__) | |||
class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | |||
class ReplicationAddUserAccountDataRestServlet(ReplicationEndpoint): | |||
"""Add user account data on the appropriate account data worker. | |||
Request format: | |||
@@ -49,7 +49,6 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
self.clock = hs.get_clock() | |||
@staticmethod | |||
async def _serialize_payload( # type: ignore[override] | |||
@@ -73,7 +72,45 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | |||
return 200, {"max_stream_id": max_stream_id} | |||
class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | |||
class ReplicationRemoveUserAccountDataRestServlet(ReplicationEndpoint): | |||
"""Remove user account data on the appropriate account data worker. | |||
Request format: | |||
POST /_synapse/replication/remove_user_account_data/:user_id/:type | |||
{ | |||
"content": { ... }, | |||
} | |||
""" | |||
NAME = "remove_user_account_data" | |||
PATH_ARGS = ("user_id", "account_data_type") | |||
CACHE = False | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@staticmethod | |||
async def _serialize_payload( # type: ignore[override] | |||
user_id: str, account_data_type: str | |||
) -> JsonDict: | |||
return {} | |||
async def _handle_request( # type: ignore[override] | |||
self, request: Request, user_id: str, account_data_type: str | |||
) -> Tuple[int, JsonDict]: | |||
max_stream_id = await self.handler.remove_account_data_for_user( | |||
user_id, account_data_type | |||
) | |||
return 200, {"max_stream_id": max_stream_id} | |||
class ReplicationAddRoomAccountDataRestServlet(ReplicationEndpoint): | |||
"""Add room account data on the appropriate account data worker. | |||
Request format: | |||
@@ -94,7 +131,6 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
self.clock = hs.get_clock() | |||
@staticmethod | |||
async def _serialize_payload( # type: ignore[override] | |||
@@ -118,6 +154,44 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | |||
return 200, {"max_stream_id": max_stream_id} | |||
class ReplicationRemoveRoomAccountDataRestServlet(ReplicationEndpoint): | |||
"""Remove room account data on the appropriate account data worker. | |||
Request format: | |||
POST /_synapse/replication/remove_room_account_data/:user_id/:room_id/:account_data_type | |||
{ | |||
"content": { ... }, | |||
} | |||
""" | |||
NAME = "remove_room_account_data" | |||
PATH_ARGS = ("user_id", "room_id", "account_data_type") | |||
CACHE = False | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@staticmethod | |||
async def _serialize_payload( # type: ignore[override] | |||
user_id: str, room_id: str, account_data_type: str, content: JsonDict | |||
) -> JsonDict: | |||
return {} | |||
async def _handle_request( # type: ignore[override] | |||
self, request: Request, user_id: str, room_id: str, account_data_type: str | |||
) -> Tuple[int, JsonDict]: | |||
max_stream_id = await self.handler.remove_account_data_for_room( | |||
user_id, room_id, account_data_type | |||
) | |||
return 200, {"max_stream_id": max_stream_id} | |||
class ReplicationAddTagRestServlet(ReplicationEndpoint): | |||
"""Add tag on the appropriate account data worker. | |||
@@ -139,7 +213,6 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
self.clock = hs.get_clock() | |||
@staticmethod | |||
async def _serialize_payload( # type: ignore[override] | |||
@@ -186,7 +259,6 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
self.clock = hs.get_clock() | |||
@staticmethod | |||
async def _serialize_payload(user_id: str, room_id: str, tag: str) -> JsonDict: # type: ignore[override] | |||
@@ -206,7 +278,11 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | |||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||
ReplicationUserAccountDataRestServlet(hs).register(http_server) | |||
ReplicationRoomAccountDataRestServlet(hs).register(http_server) | |||
ReplicationAddUserAccountDataRestServlet(hs).register(http_server) | |||
ReplicationAddRoomAccountDataRestServlet(hs).register(http_server) | |||
ReplicationAddTagRestServlet(hs).register(http_server) | |||
ReplicationRemoveTagRestServlet(hs).register(http_server) | |||
if hs.config.experimental.msc3391_enabled: | |||
ReplicationRemoveUserAccountDataRestServlet(hs).register(http_server) | |||
ReplicationRemoveRoomAccountDataRestServlet(hs).register(http_server) |
@@ -41,6 +41,7 @@ class AccountDataServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__() | |||
self._hs = hs | |||
self.auth = hs.get_auth() | |||
self.store = hs.get_datastores().main | |||
self.handler = hs.get_account_data_handler() | |||
@@ -54,6 +55,16 @@ class AccountDataServlet(RestServlet): | |||
body = parse_json_object_from_request(request) | |||
# If experimental support for MSC3391 is enabled, then providing an empty dict | |||
# as the value for an account data type should be functionally equivalent to | |||
# calling the DELETE method on the same type. | |||
if self._hs.config.experimental.msc3391_enabled: | |||
if body == {}: | |||
await self.handler.remove_account_data_for_user( | |||
user_id, account_data_type | |||
) | |||
return 200, {} | |||
await self.handler.add_account_data_for_user(user_id, account_data_type, body) | |||
return 200, {} | |||
@@ -72,9 +83,48 @@ class AccountDataServlet(RestServlet): | |||
if event is None: | |||
raise NotFoundError("Account data not found") | |||
# If experimental support for MSC3391 is enabled, then this endpoint should | |||
# return a 404 if the content for an account data type is an empty dict. | |||
if self._hs.config.experimental.msc3391_enabled and event == {}: | |||
raise NotFoundError("Account data not found") | |||
return 200, event | |||
class UnstableAccountDataServlet(RestServlet): | |||
""" | |||
Contains an unstable endpoint for removing user account data, as specified by | |||
MSC3391. If that MSC is accepted, this code should have unstable prefixes removed | |||
and become incorporated into AccountDataServlet above. | |||
""" | |||
PATTERNS = client_patterns( | |||
"/org.matrix.msc3391/user/(?P<user_id>[^/]*)" | |||
"/account_data/(?P<account_data_type>[^/]*)", | |||
unstable=True, | |||
releases=(), | |||
) | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
self.handler = hs.get_account_data_handler() | |||
async def on_DELETE( | |||
self, | |||
request: SynapseRequest, | |||
user_id: str, | |||
account_data_type: str, | |||
) -> Tuple[int, JsonDict]: | |||
requester = await self.auth.get_user_by_req(request) | |||
if user_id != requester.user.to_string(): | |||
raise AuthError(403, "Cannot delete account data for other users.") | |||
await self.handler.remove_account_data_for_user(user_id, account_data_type) | |||
return 200, {} | |||
class RoomAccountDataServlet(RestServlet): | |||
""" | |||
PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 | |||
@@ -89,6 +139,7 @@ class RoomAccountDataServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__() | |||
self._hs = hs | |||
self.auth = hs.get_auth() | |||
self.store = hs.get_datastores().main | |||
self.handler = hs.get_account_data_handler() | |||
@@ -121,6 +172,16 @@ class RoomAccountDataServlet(RestServlet): | |||
Codes.BAD_JSON, | |||
) | |||
# If experimental support for MSC3391 is enabled, then providing an empty dict | |||
# as the value for an account data type should be functionally equivalent to | |||
# calling the DELETE method on the same type. | |||
if self._hs.config.experimental.msc3391_enabled: | |||
if body == {}: | |||
await self.handler.remove_account_data_for_room( | |||
user_id, room_id, account_data_type | |||
) | |||
return 200, {} | |||
await self.handler.add_account_data_to_room( | |||
user_id, room_id, account_data_type, body | |||
) | |||
@@ -152,9 +213,63 @@ class RoomAccountDataServlet(RestServlet): | |||
if event is None: | |||
raise NotFoundError("Room account data not found") | |||
# If experimental support for MSC3391 is enabled, then this endpoint should | |||
# return a 404 if the content for an account data type is an empty dict. | |||
if self._hs.config.experimental.msc3391_enabled and event == {}: | |||
raise NotFoundError("Room account data not found") | |||
return 200, event | |||
class UnstableRoomAccountDataServlet(RestServlet): | |||
""" | |||
Contains an unstable endpoint for removing room account data, as specified by | |||
MSC3391. If that MSC is accepted, this code should have unstable prefixes removed | |||
and become incorporated into RoomAccountDataServlet above. | |||
""" | |||
PATTERNS = client_patterns( | |||
"/org.matrix.msc3391/user/(?P<user_id>[^/]*)" | |||
"/rooms/(?P<room_id>[^/]*)" | |||
"/account_data/(?P<account_data_type>[^/]*)", | |||
unstable=True, | |||
releases=(), | |||
) | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
self.handler = hs.get_account_data_handler() | |||
async def on_DELETE( | |||
self, | |||
request: SynapseRequest, | |||
user_id: str, | |||
room_id: str, | |||
account_data_type: str, | |||
) -> Tuple[int, JsonDict]: | |||
requester = await self.auth.get_user_by_req(request) | |||
if user_id != requester.user.to_string(): | |||
raise AuthError(403, "Cannot delete account data for other users.") | |||
if not RoomID.is_valid(room_id): | |||
raise SynapseError( | |||
400, | |||
f"{room_id} is not a valid room ID", | |||
Codes.INVALID_PARAM, | |||
) | |||
await self.handler.remove_account_data_for_room( | |||
user_id, room_id, account_data_type | |||
) | |||
return 200, {} | |||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: | |||
AccountDataServlet(hs).register(http_server) | |||
RoomAccountDataServlet(hs).register(http_server) | |||
if hs.config.experimental.msc3391_enabled: | |||
UnstableAccountDataServlet(hs).register(http_server) | |||
UnstableRoomAccountDataServlet(hs).register(http_server) |
@@ -1762,7 +1762,8 @@ class DatabasePool: | |||
desc: description of the transaction, for logging and metrics | |||
Returns: | |||
A list of dictionaries. | |||
A list of dictionaries, one per result row, each a mapping between the | |||
column names from `retcols` and that column's value for the row. | |||
""" | |||
return await self.runInteraction( | |||
desc, | |||
@@ -1791,6 +1792,10 @@ class DatabasePool: | |||
column names and values to select the rows with, or None to not | |||
apply a WHERE clause. | |||
retcols: the names of the columns to return | |||
Returns: | |||
A list of dictionaries, one per result row, each a mapping between the | |||
column names from `retcols` and that column's value for the row. | |||
""" | |||
if keyvalues: | |||
sql = "SELECT %s FROM %s WHERE %s" % ( | |||
@@ -1898,6 +1903,19 @@ class DatabasePool: | |||
updatevalues: Dict[str, Any], | |||
desc: str, | |||
) -> int: | |||
""" | |||
Update rows in the given database table. | |||
If the given keyvalues don't match anything, nothing will be updated. | |||
Args: | |||
table: The database table to update. | |||
keyvalues: A mapping of column name to value to match rows on. | |||
updatevalues: A mapping of column name to value to replace in any matched rows. | |||
desc: description of the transaction, for logging and metrics. | |||
Returns: | |||
The number of rows that were updated. Will be 0 if no matching rows were found. | |||
""" | |||
return await self.runInteraction( | |||
desc, self.simple_update_txn, table, keyvalues, updatevalues | |||
) | |||
@@ -1909,6 +1927,19 @@ class DatabasePool: | |||
keyvalues: Dict[str, Any], | |||
updatevalues: Dict[str, Any], | |||
) -> int: | |||
""" | |||
Update rows in the given database table. | |||
If the given keyvalues don't match anything, nothing will be updated. | |||
Args: | |||
txn: The database transaction object. | |||
table: The database table to update. | |||
keyvalues: A mapping of column name to value to match rows on. | |||
updatevalues: A mapping of column name to value to replace in any matched rows. | |||
Returns: | |||
The number of rows that were updated. Will be 0 if no matching rows were found. | |||
""" | |||
if keyvalues: | |||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys()) | |||
else: | |||
@@ -123,7 +123,11 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
async def get_account_data_for_user( | |||
self, user_id: str | |||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: | |||
"""Get all the client account_data for a user. | |||
""" | |||
Get all the client account_data for a user. | |||
If experimental MSC3391 support is enabled, any entries with an empty | |||
content body are excluded; as this means they have been deleted. | |||
Args: | |||
user_id: The user to get the account_data for. | |||
@@ -135,27 +139,48 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
def get_account_data_for_user_txn( | |||
txn: LoggingTransaction, | |||
) -> Tuple[Dict[str, JsonDict], Dict[str, Dict[str, JsonDict]]]: | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"account_data", | |||
{"user_id": user_id}, | |||
["account_data_type", "content"], | |||
) | |||
# The 'content != '{}' condition below prevents us from using | |||
# `simple_select_list_txn` here, as it doesn't support conditions | |||
# other than 'equals'. | |||
sql = """ | |||
SELECT account_data_type, content FROM account_data | |||
WHERE user_id = ? | |||
""" | |||
# If experimental MSC3391 support is enabled, then account data entries | |||
# with an empty content are considered "deleted". So skip adding them to | |||
# the results. | |||
if self.hs.config.experimental.msc3391_enabled: | |||
sql += " AND content != '{}'" | |||
txn.execute(sql, (user_id,)) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
global_account_data = { | |||
row["account_data_type"]: db_to_json(row["content"]) for row in rows | |||
} | |||
rows = self.db_pool.simple_select_list_txn( | |||
txn, | |||
"room_account_data", | |||
{"user_id": user_id}, | |||
["room_id", "account_data_type", "content"], | |||
) | |||
# The 'content != '{}' condition below prevents us from using | |||
# `simple_select_list_txn` here, as it doesn't support conditions | |||
# other than 'equals'. | |||
sql = """ | |||
SELECT room_id, account_data_type, content FROM room_account_data | |||
WHERE user_id = ? | |||
""" | |||
# If experimental MSC3391 support is enabled, then account data entries | |||
# with an empty content are considered "deleted". So skip adding them to | |||
# the results. | |||
if self.hs.config.experimental.msc3391_enabled: | |||
sql += " AND content != '{}'" | |||
txn.execute(sql, (user_id,)) | |||
rows = self.db_pool.cursor_to_dict(txn) | |||
by_room: Dict[str, Dict[str, JsonDict]] = {} | |||
for row in rows: | |||
room_data = by_room.setdefault(row["room_id"], {}) | |||
room_data[row["account_data_type"]] = db_to_json(row["content"]) | |||
return global_account_data, by_room | |||
@@ -469,6 +494,72 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
return self._account_data_id_gen.get_current_token() | |||
async def remove_account_data_for_room( | |||
self, user_id: str, room_id: str, account_data_type: str | |||
) -> Optional[int]: | |||
"""Delete the room account data for the user of a given type. | |||
Args: | |||
user_id: The user to remove account_data for. | |||
room_id: The room ID to scope the request to. | |||
account_data_type: The account data type to delete. | |||
Returns: | |||
The maximum stream position, or None if there was no matching room account | |||
data to delete. | |||
""" | |||
assert self._can_write_to_account_data | |||
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) | |||
def _remove_account_data_for_room_txn( | |||
txn: LoggingTransaction, next_id: int | |||
) -> bool: | |||
""" | |||
Args: | |||
txn: The transaction object. | |||
next_id: The stream_id to update any existing rows to. | |||
Returns: | |||
True if an entry in room_account_data had its content set to '{}', | |||
otherwise False. This informs callers of whether there actually was an | |||
existing room account data entry to delete, or if the call was a no-op. | |||
""" | |||
# We can't use `simple_update` as it doesn't have the ability to specify | |||
# where clauses other than '=', which we need for `content != '{}'` below. | |||
sql = """ | |||
UPDATE room_account_data | |||
SET stream_id = ?, content = '{}' | |||
WHERE user_id = ? | |||
AND room_id = ? | |||
AND account_data_type = ? | |||
AND content != '{}' | |||
""" | |||
txn.execute( | |||
sql, | |||
(next_id, user_id, room_id, account_data_type), | |||
) | |||
# Return true if any rows were updated. | |||
return txn.rowcount != 0 | |||
async with self._account_data_id_gen.get_next() as next_id: | |||
row_updated = await self.db_pool.runInteraction( | |||
"remove_account_data_for_room", | |||
_remove_account_data_for_room_txn, | |||
next_id, | |||
) | |||
if not row_updated: | |||
return None | |||
self._account_data_stream_cache.entity_has_changed(user_id, next_id) | |||
self.get_account_data_for_user.invalidate((user_id,)) | |||
self.get_account_data_for_room.invalidate((user_id, room_id)) | |||
self.get_account_data_for_room_and_type.prefill( | |||
(user_id, room_id, account_data_type), {} | |||
) | |||
return self._account_data_id_gen.get_current_token() | |||
async def add_account_data_for_user( | |||
self, user_id: str, account_data_type: str, content: JsonDict | |||
) -> int: | |||
@@ -569,6 +660,108 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore) | |||
self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) | |||
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) | |||
async def remove_account_data_for_user( | |||
self, | |||
user_id: str, | |||
account_data_type: str, | |||
) -> Optional[int]: | |||
""" | |||
Delete a single piece of user account data by type. | |||
A "delete" is performed by updating a potentially existing row in the | |||
"account_data" database table for (user_id, account_data_type) and | |||
setting its content to "{}". | |||
Args: | |||
user_id: The user ID to modify the account data of. | |||
account_data_type: The type to remove. | |||
Returns: | |||
The maximum stream position, or None if there was no matching account data | |||
to delete. | |||
""" | |||
assert self._can_write_to_account_data | |||
assert isinstance(self._account_data_id_gen, AbstractStreamIdGenerator) | |||
def _remove_account_data_for_user_txn( | |||
txn: LoggingTransaction, next_id: int | |||
) -> bool: | |||
""" | |||
Args: | |||
txn: The transaction object. | |||
next_id: The stream_id to update any existing rows to. | |||
Returns: | |||
True if an entry in account_data had its content set to '{}', otherwise | |||
False. This informs callers of whether there actually was an existing | |||
account data entry to delete, or if the call was a no-op. | |||
""" | |||
# We can't use `simple_update` as it doesn't have the ability to specify | |||
# where clauses other than '=', which we need for `content != '{}'` below. | |||
sql = """ | |||
UPDATE account_data | |||
SET stream_id = ?, content = '{}' | |||
WHERE user_id = ? | |||
AND account_data_type = ? | |||
AND content != '{}' | |||
""" | |||
txn.execute(sql, (next_id, user_id, account_data_type)) | |||
if txn.rowcount == 0: | |||
# We didn't update any rows. This means that there was no matching room | |||
# account data entry to delete in the first place. | |||
return False | |||
# Ignored users get denormalized into a separate table as an optimisation. | |||
if account_data_type == AccountDataTypes.IGNORED_USER_LIST: | |||
# If this method was called with the ignored users account data type, we | |||
# simply delete all ignored users. | |||
# First pull all the users that this user ignores. | |||
previously_ignored_users = set( | |||
self.db_pool.simple_select_onecol_txn( | |||
txn, | |||
table="ignored_users", | |||
keyvalues={"ignorer_user_id": user_id}, | |||
retcol="ignored_user_id", | |||
) | |||
) | |||
# Then delete them from the database. | |||
self.db_pool.simple_delete_txn( | |||
txn, | |||
table="ignored_users", | |||
keyvalues={"ignorer_user_id": user_id}, | |||
) | |||
# Invalidate the cache for ignored users which were removed. | |||
for ignored_user_id in previously_ignored_users: | |||
self._invalidate_cache_and_stream( | |||
txn, self.ignored_by, (ignored_user_id,) | |||
) | |||
# Invalidate for this user the cache tracking ignored users. | |||
self._invalidate_cache_and_stream(txn, self.ignored_users, (user_id,)) | |||
return True | |||
async with self._account_data_id_gen.get_next() as next_id: | |||
row_updated = await self.db_pool.runInteraction( | |||
"remove_account_data_for_user", | |||
_remove_account_data_for_user_txn, | |||
next_id, | |||
) | |||
if not row_updated: | |||
return None | |||
self._account_data_stream_cache.entity_has_changed(user_id, next_id) | |||
self.get_account_data_for_user.invalidate((user_id,)) | |||
self.get_global_account_data_by_type_for_user.prefill( | |||
(user_id, account_data_type), {} | |||
) | |||
return self._account_data_id_gen.get_current_token() | |||
async def purge_account_data_for_user(self, user_id: str) -> None: | |||
""" | |||
Removes ALL the account data for a user. | |||