@@ -0,0 +1 @@ | |||
Delete device messages asynchronously and in staged batches using the task scheduler. |
@@ -43,9 +43,12 @@ from synapse.metrics.background_process_metrics import ( | |||
) | |||
from synapse.types import ( | |||
JsonDict, | |||
JsonMapping, | |||
ScheduledTask, | |||
StrCollection, | |||
StreamKeyType, | |||
StreamToken, | |||
TaskStatus, | |||
UserID, | |||
get_domain_from_id, | |||
get_verify_key_from_cross_signing_key, | |||
@@ -62,6 +65,7 @@ if TYPE_CHECKING: | |||
logger = logging.getLogger(__name__) | |||
DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages" | |||
MAX_DEVICE_DISPLAY_NAME_LEN = 100 | |||
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000 | |||
@@ -78,6 +82,7 @@ class DeviceWorkerHandler: | |||
self._appservice_handler = hs.get_application_service_handler() | |||
self._state_storage = hs.get_storage_controllers().state | |||
self._auth_handler = hs.get_auth_handler() | |||
self._event_sources = hs.get_event_sources() | |||
self.server_name = hs.hostname | |||
self._msc3852_enabled = hs.config.experimental.msc3852_enabled | |||
self._query_appservices_for_keys = ( | |||
@@ -386,6 +391,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||
self._account_data_handler = hs.get_account_data_handler() | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self.db_pool = hs.get_datastores().main.db_pool | |||
self._task_scheduler = hs.get_task_scheduler() | |||
self.device_list_updater = DeviceListUpdater(hs, self) | |||
@@ -419,6 +425,10 @@ class DeviceHandler(DeviceWorkerHandler): | |||
self._delete_stale_devices, | |||
) | |||
self._task_scheduler.register_action( | |||
self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME | |||
) | |||
def _check_device_name_length(self, name: Optional[str]) -> None: | |||
""" | |||
Checks whether a device name is longer than the maximum allowed length. | |||
@@ -530,6 +540,7 @@ class DeviceHandler(DeviceWorkerHandler): | |||
user_id: The user to delete devices from. | |||
device_ids: The list of device IDs to delete | |||
""" | |||
to_device_stream_id = self._event_sources.get_current_token().to_device_key | |||
try: | |||
await self.store.delete_devices(user_id, device_ids) | |||
@@ -559,12 +570,49 @@ class DeviceHandler(DeviceWorkerHandler): | |||
f"org.matrix.msc3890.local_notification_settings.{device_id}", | |||
) | |||
# Delete device messages asynchronously and in batches using the task scheduler | |||
await self._task_scheduler.schedule_task( | |||
DELETE_DEVICE_MSGS_TASK_NAME, | |||
resource_id=device_id, | |||
params={ | |||
"user_id": user_id, | |||
"device_id": device_id, | |||
"up_to_stream_id": to_device_stream_id, | |||
}, | |||
) | |||
# Pushers are deleted after `delete_access_tokens_for_user` is called so that | |||
# modules using `on_logged_out` hook can use them if needed. | |||
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids) | |||
await self.notify_device_update(user_id, device_ids) | |||
DEVICE_MSGS_DELETE_BATCH_LIMIT = 100 | |||
async def _delete_device_messages( | |||
self, | |||
task: ScheduledTask, | |||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | |||
"""Scheduler task to delete device messages in batch of `DEVICE_MSGS_DELETE_BATCH_LIMIT`.""" | |||
assert task.params is not None | |||
user_id = task.params["user_id"] | |||
device_id = task.params["device_id"] | |||
up_to_stream_id = task.params["up_to_stream_id"] | |||
res = await self.store.delete_messages_for_device( | |||
user_id=user_id, | |||
device_id=device_id, | |||
up_to_stream_id=up_to_stream_id, | |||
limit=DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT, | |||
) | |||
if res < DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT: | |||
return TaskStatus.COMPLETE, None, None | |||
else: | |||
# There is probably still device messages to be deleted, let's keep the task active and it will be run | |||
# again in a subsequent scheduler loop run (probably the next one, if not too many tasks are running). | |||
return TaskStatus.ACTIVE, None, None | |||
async def update_device(self, user_id: str, device_id: str, content: dict) -> None: | |||
"""Update the given device | |||
@@ -183,6 +183,7 @@ class BasePresenceHandler(abc.ABC): | |||
writer""" | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
self.clock = hs.get_clock() | |||
self.store = hs.get_datastores().main | |||
self._storage_controllers = hs.get_storage_controllers() | |||
@@ -473,8 +474,6 @@ class _NullContextManager(ContextManager[None]): | |||
class WorkerPresenceHandler(BasePresenceHandler): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.hs = hs | |||
self._presence_writer_instance = hs.config.worker.writers.presence[0] | |||
# Route presence EDUs to the right worker | |||
@@ -738,7 +737,6 @@ class WorkerPresenceHandler(BasePresenceHandler): | |||
class PresenceHandler(BasePresenceHandler): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.hs = hs | |||
self.wheel_timer: WheelTimer[str] = WheelTimer() | |||
self.notifier = hs.get_notifier() | |||
@@ -40,6 +40,7 @@ from synapse.api.filtering import FilterCollection | |||
from synapse.api.presence import UserPresenceState | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | |||
from synapse.events import EventBase | |||
from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME | |||
from synapse.handlers.relations import BundledAggregations | |||
from synapse.logging import issue9533_logger | |||
from synapse.logging.context import current_context | |||
@@ -268,6 +269,7 @@ class SyncHandler: | |||
self._storage_controllers = hs.get_storage_controllers() | |||
self._state_storage_controller = self._storage_controllers.state | |||
self._device_handler = hs.get_device_handler() | |||
self._task_scheduler = hs.get_task_scheduler() | |||
self.should_calculate_push_rules = hs.config.push.enable_push | |||
@@ -360,11 +362,19 @@ class SyncHandler: | |||
# (since we now know that the device has received them) | |||
if since_token is not None: | |||
since_stream_id = since_token.to_device_key | |||
deleted = await self.store.delete_messages_for_device( | |||
sync_config.user.to_string(), sync_config.device_id, since_stream_id | |||
# Delete device messages asynchronously and in batches using the task scheduler | |||
await self._task_scheduler.schedule_task( | |||
DELETE_DEVICE_MSGS_TASK_NAME, | |||
resource_id=sync_config.device_id, | |||
params={ | |||
"user_id": sync_config.user.to_string(), | |||
"device_id": sync_config.device_id, | |||
"up_to_stream_id": since_stream_id, | |||
}, | |||
) | |||
logger.debug( | |||
"Deleted %d to-device messages up to %d", deleted, since_stream_id | |||
"Deletion of to-device messages up to %d scheduled", | |||
since_stream_id, | |||
) | |||
if timeout == 0 or since_token is None or full_state: | |||
@@ -445,13 +445,18 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
@trace | |||
async def delete_messages_for_device( | |||
self, user_id: str, device_id: Optional[str], up_to_stream_id: int | |||
self, | |||
user_id: str, | |||
device_id: Optional[str], | |||
up_to_stream_id: int, | |||
limit: int, | |||
) -> int: | |||
""" | |||
Args: | |||
user_id: The recipient user_id. | |||
device_id: The recipient device_id. | |||
up_to_stream_id: Where to delete messages up to. | |||
limit: maximum number of messages to delete | |||
Returns: | |||
The number of messages deleted. | |||
@@ -472,12 +477,16 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
log_kv({"message": "No changes in cache since last check"}) | |||
return 0 | |||
ROW_ID_NAME = self.database_engine.row_id_name | |||
def delete_messages_for_device_txn(txn: LoggingTransaction) -> int: | |||
sql = ( | |||
"DELETE FROM device_inbox" | |||
" WHERE user_id = ? AND device_id = ?" | |||
" AND stream_id <= ?" | |||
) | |||
sql = f""" | |||
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN ( | |||
SELECT {ROW_ID_NAME} FROM device_inbox | |||
WHERE user_id = ? AND device_id = ? AND stream_id <= ? | |||
LIMIT {limit} | |||
) | |||
""" | |||
txn.execute(sql, (user_id, device_id, up_to_stream_id)) | |||
return txn.rowcount | |||
@@ -487,6 +496,11 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
log_kv({"message": f"deleted {count} messages for device", "count": count}) | |||
# In this case we don't know if we hit the limit or the delete is complete | |||
# so let's not update the cache. | |||
if count == limit: | |||
return count | |||
# Update the cache, ensuring that we only ever increase the value | |||
updated_last_deleted_stream_id = self._last_device_delete_cache.get( | |||
(user_id, device_id), 0 | |||
@@ -1766,14 +1766,6 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
keyvalues={"user_id": user_id, "hidden": False}, | |||
) | |||
self.db_pool.simple_delete_many_txn( | |||
txn, | |||
table="device_inbox", | |||
column="device_id", | |||
values=device_ids, | |||
keyvalues={"user_id": user_id}, | |||
) | |||
self.db_pool.simple_delete_many_txn( | |||
txn, | |||
table="device_auth_providers", | |||
@@ -939,11 +939,7 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore): | |||
receipts.""" | |||
def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None: | |||
if isinstance(self.database_engine, PostgresEngine): | |||
ROW_ID_NAME = "ctid" | |||
else: | |||
ROW_ID_NAME = "rowid" | |||
ROW_ID_NAME = self.database_engine.row_id_name | |||
# Identify any duplicate receipts arising from | |||
# https://github.com/matrix-org/synapse/issues/14406. | |||
# The following query takes less than a minute on matrix.org. | |||
@@ -100,6 +100,12 @@ class BaseDatabaseEngine(Generic[ConnectionType, CursorType], metaclass=abc.ABCM | |||
"""Gets a string giving the server version. For example: '3.22.0'""" | |||
... | |||
@property | |||
@abc.abstractmethod | |||
def row_id_name(self) -> str: | |||
"""Gets the literal name representing a row id for this engine.""" | |||
... | |||
@abc.abstractmethod | |||
def in_transaction(self, conn: ConnectionType) -> bool: | |||
"""Whether the connection is currently in a transaction.""" | |||
@@ -211,6 +211,10 @@ class PostgresEngine( | |||
else: | |||
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100) | |||
@property | |||
def row_id_name(self) -> str: | |||
return "ctid" | |||
def in_transaction(self, conn: psycopg2.extensions.connection) -> bool: | |||
return conn.status != psycopg2.extensions.STATUS_READY | |||
@@ -123,6 +123,10 @@ class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection, sqlite3.Cursor]): | |||
"""Gets a string giving the server version. For example: '3.22.0'.""" | |||
return "%i.%i.%i" % sqlite3.sqlite_version_info | |||
@property | |||
def row_id_name(self) -> str: | |||
return "rowid" | |||
def in_transaction(self, conn: sqlite3.Connection) -> bool: | |||
return conn.in_transaction | |||
@@ -14,7 +14,7 @@ | |||
from synapse.storage.database import LoggingTransaction | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine | |||
from synapse.storage.engines import BaseDatabaseEngine | |||
from synapse.storage.prepare_database import get_statements | |||
FIX_INDEXES = """ | |||
@@ -37,7 +37,7 @@ CREATE INDEX group_rooms_r_idx ON group_rooms(room_id); | |||
def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None: | |||
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid" | |||
rowid = database_engine.row_id_name | |||
# remove duplicates from group_users & group_invites tables | |||
cur.execute( | |||
@@ -77,6 +77,7 @@ class TaskScheduler: | |||
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs | |||
def __init__(self, hs: "HomeServer"): | |||
self._hs = hs | |||
self._store = hs.get_datastores().main | |||
self._clock = hs.get_clock() | |||
self._running_tasks: Set[str] = set() | |||
@@ -97,8 +98,6 @@ class TaskScheduler: | |||
"handle_scheduled_tasks", | |||
self._handle_scheduled_tasks, | |||
) | |||
else: | |||
self.replication_client = hs.get_replication_command_handler() | |||
def register_action( | |||
self, | |||
@@ -133,7 +132,7 @@ class TaskScheduler: | |||
params: Optional[JsonMapping] = None, | |||
) -> str: | |||
"""Schedule a new potentially resumable task. A function matching the specified | |||
`action` should have been previously registered with `register_action`. | |||
`action` should have be registered with `register_action` before the task is run. | |||
Args: | |||
action: the name of a previously registered action | |||
@@ -149,11 +148,6 @@ class TaskScheduler: | |||
Returns: | |||
The id of the scheduled task | |||
""" | |||
if action not in self._actions: | |||
raise Exception( | |||
f"No function associated with action {action} of the scheduled task" | |||
) | |||
status = TaskStatus.SCHEDULED | |||
if timestamp is None or timestamp < self._clock.time_msec(): | |||
timestamp = self._clock.time_msec() | |||
@@ -175,7 +169,7 @@ class TaskScheduler: | |||
if self._run_background_tasks: | |||
await self._launch_task(task) | |||
else: | |||
self.replication_client.send_new_active_task(task.id) | |||
self._hs.get_replication_command_handler().send_new_active_task(task.id) | |||
return task.id | |||
@@ -315,7 +309,10 @@ class TaskScheduler: | |||
""" | |||
assert self._run_background_tasks | |||
assert task.action in self._actions | |||
if task.action not in self._actions: | |||
raise Exception( | |||
f"No function associated with action {task.action} of the scheduled task {task.id}" | |||
) | |||
function = self._actions[task.action] | |||
async def wrapper() -> None: | |||
@@ -30,6 +30,7 @@ from synapse.server import HomeServer | |||
from synapse.storage.databases.main.appservice import _make_exclusive_regex | |||
from synapse.types import JsonDict, create_requester | |||
from synapse.util import Clock | |||
from synapse.util.task_scheduler import TaskScheduler | |||
from tests import unittest | |||
from tests.unittest import override_config | |||
@@ -49,6 +50,7 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
assert isinstance(handler, DeviceHandler) | |||
self.handler = handler | |||
self.store = hs.get_datastores().main | |||
self.device_message_handler = hs.get_device_message_handler() | |||
return hs | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
@@ -211,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertIsNone(res) | |||
def test_delete_device_and_big_device_inbox(self) -> None: | |||
"""Check that deleting a big device inbox is staged and batched asynchronously.""" | |||
DEVICE_ID = "abc" | |||
sender = "@sender:" + self.hs.hostname | |||
receiver = "@receiver:" + self.hs.hostname | |||
self._record_user(sender, DEVICE_ID, DEVICE_ID) | |||
self._record_user(receiver, DEVICE_ID, DEVICE_ID) | |||
# queue a bunch of messages in the inbox | |||
requester = create_requester(sender, device_id=DEVICE_ID) | |||
for i in range(0, DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10): | |||
self.get_success( | |||
self.device_message_handler.send_device_message( | |||
requester, "message_type", {receiver: {"*": {"val": i}}} | |||
) | |||
) | |||
# delete the device | |||
self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID])) | |||
# messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away | |||
res = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="device_inbox", | |||
keyvalues={"user_id": receiver}, | |||
retcols=("user_id", "device_id", "stream_id"), | |||
desc="get_device_id_from_device_inbox", | |||
) | |||
) | |||
self.assertEqual(10, len(res)) | |||
# wait for the task scheduler to do a second delete pass | |||
self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) | |||
# remaining messages should now be deleted | |||
res = self.get_success( | |||
self.store.db_pool.simple_select_list( | |||
table="device_inbox", | |||
keyvalues={"user_id": receiver}, | |||
retcols=("user_id", "device_id", "stream_id"), | |||
desc="get_device_id_from_device_inbox", | |||
) | |||
) | |||
self.assertEqual(0, len(res)) | |||
def test_update_device(self) -> None: | |||
self._record_users() | |||