@@ -0,0 +1 @@ | |||
Task scheduler: add a replication notification so that a newly scheduled task is launched as soon as possible. |
@@ -452,6 +452,17 @@ class LockReleasedCommand(Command): | |||
return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key]) | |||
class NewActiveTaskCommand(_SimpleCommand):
    """Tells the instance handling background tasks that a new active task
    is available to run.

    Format::

        NEW_ACTIVE_TASK "<task_id>"
    """

    NAME = "NEW_ACTIVE_TASK"
_COMMANDS: Tuple[Type[Command], ...] = ( | |||
ServerCommand, | |||
RdataCommand, | |||
@@ -466,6 +477,7 @@ _COMMANDS: Tuple[Type[Command], ...] = ( | |||
RemoteServerUpCommand, | |||
ClearUserSyncsCommand, | |||
LockReleasedCommand, | |||
NewActiveTaskCommand, | |||
) | |||
# Map of command name to command type. | |||
@@ -40,6 +40,7 @@ from synapse.replication.tcp.commands import ( | |||
Command, | |||
FederationAckCommand, | |||
LockReleasedCommand, | |||
NewActiveTaskCommand, | |||
PositionCommand, | |||
RdataCommand, | |||
RemoteServerUpCommand, | |||
@@ -238,6 +239,10 @@ class ReplicationCommandHandler: | |||
if self._is_master: | |||
self._server_notices_sender = hs.get_server_notices_sender() | |||
self._task_scheduler = None | |||
if hs.config.worker.run_background_tasks: | |||
self._task_scheduler = hs.get_task_scheduler() | |||
if hs.config.redis.redis_enabled: | |||
# If we're using Redis, it's the background worker that should | |||
# receive USER_IP commands and store the relevant client IPs. | |||
@@ -663,6 +668,15 @@ class ReplicationCommandHandler: | |||
cmd.instance_name, cmd.lock_name, cmd.lock_key | |||
) | |||
async def on_NEW_ACTIVE_TASK(
    self, conn: IReplicationConnection, cmd: NewActiveTaskCommand
) -> None:
    """Called when we receive a NEW_ACTIVE_TASK command.

    Looks up the task whose id is carried in `cmd.data` and launches it.
    Nothing happens when this instance has no task scheduler, or when the
    task cannot be found.

    Args:
        conn: the connection the command was received on (unused here)
        cmd: the received command; `cmd.data` is the id of the task
    """
    scheduler = self._task_scheduler
    if not scheduler:
        # No scheduler was set up on this instance, so there is nothing to launch.
        return

    task = await scheduler.get_task(cmd.data)
    if not task:
        return

    await scheduler._launch_task(task)
def new_connection(self, connection: IReplicationConnection) -> None: | |||
"""Called when we have a new connection.""" | |||
self._connections.append(connection) | |||
@@ -776,6 +790,10 @@ class ReplicationCommandHandler: | |||
if instance_name == self._instance_name: | |||
self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key)) | |||
def send_new_active_task(self, task_id: str) -> None:
    """Send a NEW_ACTIVE_TASK command for the given task.

    Called when a new task has been scheduled for immediate launch and is ACTIVE.

    Args:
        task_id: id of the task that has just become ACTIVE
    """
    command = NewActiveTaskCommand(task_id)
    self.send_command(command)
UpdateToken = TypeVar("UpdateToken") | |||
UpdateRow = TypeVar("UpdateRow") | |||
@@ -57,14 +57,13 @@ class TaskScheduler: | |||
the code launching the task. | |||
You can also specify the `result` (and/or an `error`) when returning from the function. | |||
The reconciliation loop runs every 5 mns, so this is not a precise scheduler. When wanting | |||
to launch now, the launch will still not happen before the next loop run. | |||
Tasks will be run on the worker specified with `run_background_tasks_on` config, | |||
or the main one by default. | |||
The reconciliation loop runs every minute, so this is not a precise scheduler. | |||
There is a limit of 10 concurrent tasks, so tasks may be delayed if the pool is already | |||
full. In this regard, please take great care that scheduled tasks can actually finish. | |||
For now there is no mechanism to stop a running task if it is stuck. | |||
Tasks will be run on the worker specified with `run_background_tasks_on` config, | |||
or the main one by default. | |||
""" | |||
# Precision of the scheduler, evaluation of tasks to run will only happen | |||
@@ -85,7 +84,7 @@ class TaskScheduler: | |||
self._actions: Dict[ | |||
str, | |||
Callable[ | |||
[ScheduledTask, bool], | |||
[ScheduledTask], | |||
Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], | |||
], | |||
] = {} | |||
@@ -98,11 +97,13 @@ class TaskScheduler: | |||
"handle_scheduled_tasks", | |||
self._handle_scheduled_tasks, | |||
) | |||
else: | |||
self.replication_client = hs.get_replication_command_handler() | |||
def register_action( | |||
self, | |||
function: Callable[ | |||
[ScheduledTask, bool], | |||
[ScheduledTask], | |||
Awaitable[Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]], | |||
], | |||
action_name: str, | |||
@@ -115,10 +116,9 @@ class TaskScheduler: | |||
calling `schedule_task` but rather in an `__init__` method. | |||
Args: | |||
function: The function to be executed for this action. The parameters | |||
passed to the function when launched are the `ScheduledTask` being run, | |||
and a `first_launch` boolean to signal if it's a resumed task or the first | |||
launch of it. The function should return a tuple of new `status`, `result` | |||
function: The function to be executed for this action. The parameter | |||
passed to the function when launched is the `ScheduledTask` being run. | |||
The function should return a tuple of new `status`, `result` | |||
and `error` as specified in `ScheduledTask`. | |||
action_name: The name of the action to be associated with the function | |||
""" | |||
@@ -171,6 +171,12 @@ class TaskScheduler: | |||
) | |||
await self._store.insert_scheduled_task(task) | |||
if status == TaskStatus.ACTIVE: | |||
if self._run_background_tasks: | |||
await self._launch_task(task) | |||
else: | |||
self.replication_client.send_new_active_task(task.id) | |||
return task.id | |||
async def update_task( | |||
@@ -265,21 +271,13 @@ class TaskScheduler: | |||
Args: | |||
id: id of the task to delete | |||
""" | |||
if self.task_is_running(id): | |||
raise Exception(f"Task {id} is currently running and can't be deleted") | |||
task = await self.get_task(id) | |||
if task is None: | |||
raise Exception(f"Task {id} does not exist") | |||
if task.status == TaskStatus.ACTIVE: | |||
raise Exception(f"Task {id} is currently ACTIVE and can't be deleted") | |||
await self._store.delete_scheduled_task(id) | |||
def task_is_running(self, id: str) -> bool:
    """Tell whether the task with the given id is currently running.

    Can only be called from the worker handling the task scheduling.

    Args:
        id: id of the task to check

    Returns:
        `True` if the id is in the set of running tasks, `False` otherwise.
    """
    # Only the instance that runs background tasks tracks running tasks.
    assert self._run_background_tasks
    running_ids = self._running_tasks
    return id in running_ids
async def _handle_scheduled_tasks(self) -> None: | |||
"""Main loop taking care of launching tasks and cleaning up old ones.""" | |||
await self._launch_scheduled_tasks() | |||
@@ -288,29 +286,11 @@ class TaskScheduler: | |||
async def _launch_scheduled_tasks(self) -> None: | |||
"""Retrieve and launch scheduled tasks that should be running at that time.""" | |||
for task in await self.get_tasks(statuses=[TaskStatus.ACTIVE]): | |||
if not self.task_is_running(task.id): | |||
if ( | |||
len(self._running_tasks) | |||
< TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS | |||
): | |||
await self._launch_task(task, first_launch=False) | |||
else: | |||
if ( | |||
self._clock.time_msec() | |||
> task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS | |||
): | |||
logger.warn( | |||
f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" | |||
) | |||
await self._launch_task(task) | |||
for task in await self.get_tasks( | |||
statuses=[TaskStatus.SCHEDULED], max_timestamp=self._clock.time_msec() | |||
): | |||
if ( | |||
not self.task_is_running(task.id) | |||
and len(self._running_tasks) | |||
< TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS | |||
): | |||
await self._launch_task(task, first_launch=True) | |||
await self._launch_task(task) | |||
running_tasks_gauge.set(len(self._running_tasks)) | |||
@@ -320,27 +300,27 @@ class TaskScheduler: | |||
statuses=[TaskStatus.FAILED, TaskStatus.COMPLETE] | |||
): | |||
# FAILED and COMPLETE tasks should never be running | |||
assert not self.task_is_running(task.id) | |||
assert task.id not in self._running_tasks | |||
if ( | |||
self._clock.time_msec() | |||
> task.timestamp + TaskScheduler.KEEP_TASKS_FOR_MS | |||
): | |||
await self._store.delete_scheduled_task(task.id) | |||
async def _launch_task(self, task: ScheduledTask, first_launch: bool) -> None: | |||
async def _launch_task(self, task: ScheduledTask) -> None: | |||
"""Launch a scheduled task now. | |||
Args: | |||
task: the task to launch | |||
first_launch: `True` if it's the first time is launched, `False` otherwise | |||
""" | |||
assert task.action in self._actions | |||
assert self._run_background_tasks | |||
assert task.action in self._actions | |||
function = self._actions[task.action] | |||
async def wrapper() -> None: | |||
try: | |||
(status, result, error) = await function(task, first_launch) | |||
(status, result, error) = await function(task) | |||
except Exception: | |||
f = Failure() | |||
logger.error( | |||
@@ -360,6 +340,20 @@ class TaskScheduler: | |||
) | |||
self._running_tasks.remove(task.id) | |||
if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: | |||
return | |||
if ( | |||
self._clock.time_msec() | |||
> task.timestamp + TaskScheduler.LAST_UPDATE_BEFORE_WARNING_MS | |||
): | |||
logger.warn( | |||
f"Task {task.id} (action {task.action}) has seen no update for more than 24h and may be stuck" | |||
) | |||
if task.id in self._running_tasks: | |||
return | |||
self._running_tasks.add(task.id) | |||
await self.update_task(task.id, status=TaskStatus.ACTIVE) | |||
description = f"{task.id}-{task.action}" | |||
@@ -22,10 +22,11 @@ from synapse.types import JsonMapping, ScheduledTask, TaskStatus | |||
from synapse.util import Clock | |||
from synapse.util.task_scheduler import TaskScheduler | |||
from tests import unittest | |||
from tests.replication._base import BaseMultiWorkerStreamTestCase | |||
from tests.unittest import HomeserverTestCase, override_config | |||
class TestTaskScheduler(unittest.HomeserverTestCase): | |||
class TestTaskScheduler(HomeserverTestCase): | |||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: | |||
self.task_scheduler = hs.get_task_scheduler() | |||
self.task_scheduler.register_action(self._test_task, "_test_task") | |||
@@ -34,7 +35,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
self.task_scheduler.register_action(self._resumable_task, "_resumable_task") | |||
async def _test_task( | |||
self, task: ScheduledTask, first_launch: bool | |||
self, task: ScheduledTask | |||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | |||
# This test task will copy the parameters to the result | |||
result = None | |||
@@ -77,7 +78,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
self.assertIsNone(task) | |||
async def _sleeping_task( | |||
self, task: ScheduledTask, first_launch: bool | |||
self, task: ScheduledTask | |||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | |||
# Sleep for a second | |||
await deferLater(self.reactor, 1, lambda: None) | |||
@@ -85,24 +86,18 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
def test_schedule_lot_of_tasks(self) -> None: | |||
"""Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior.""" | |||
timestamp = self.clock.time_msec() + 30 * 1000 | |||
task_ids = [] | |||
for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1): | |||
task_ids.append( | |||
self.get_success( | |||
self.task_scheduler.schedule_task( | |||
"_sleeping_task", | |||
timestamp=timestamp, | |||
params={"val": i}, | |||
) | |||
) | |||
) | |||
# The timestamp being 30s after now, the task should have been executed | |||
# after the first scheduling loop has run | |||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) | |||
# This is to give the time to the sleeping tasks to finish | |||
# This is to give the time to the active tasks to finish | |||
self.reactor.advance(1) | |||
# Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one | |||
@@ -120,10 +115,11 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
) | |||
scheduled_tasks = [ | |||
t for t in tasks if t is not None and t.status == TaskStatus.SCHEDULED | |||
t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE | |||
] | |||
self.assertEquals(len(scheduled_tasks), 1) | |||
# We need to wait for the next run of the scheduler loop | |||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) | |||
self.reactor.advance(1) | |||
@@ -138,7 +134,7 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
) | |||
async def _raising_task( | |||
self, task: ScheduledTask, first_launch: bool | |||
self, task: ScheduledTask | |||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | |||
raise Exception("raising") | |||
@@ -146,15 +142,13 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
"""Schedule a task raising an exception and check it runs to failure and report exception content.""" | |||
task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task")) | |||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) | |||
task = self.get_success(self.task_scheduler.get_task(task_id)) | |||
assert task is not None | |||
self.assertEqual(task.status, TaskStatus.FAILED) | |||
self.assertEqual(task.error, "raising") | |||
async def _resumable_task( | |||
self, task: ScheduledTask, first_launch: bool | |||
self, task: ScheduledTask | |||
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: | |||
if task.result and "in_progress" in task.result: | |||
return TaskStatus.COMPLETE, {"success": True}, None | |||
@@ -169,8 +163,6 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
"""Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart.""" | |||
task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task")) | |||
self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) | |||
task = self.get_success(self.task_scheduler.get_task(task_id)) | |||
assert task is not None | |||
self.assertEqual(task.status, TaskStatus.ACTIVE) | |||
@@ -184,3 +176,33 @@ class TestTaskScheduler(unittest.HomeserverTestCase): | |||
self.assertEqual(task.status, TaskStatus.COMPLETE) | |||
assert task.result is not None | |||
self.assertTrue(task.result.get("success")) | |||
class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase):
    """Task scheduler tests where background tasks are configured to run on a worker."""

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        # Keep a handle on the scheduler of the homeserver under test and
        # register the test action on it.
        self.task_scheduler = hs.get_task_scheduler()
        self.task_scheduler.register_action(self._test_task, "_test_task")

    async def _test_task(
        self, task: ScheduledTask
    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
        # Trivial action: complete immediately, with no result and no error.
        status, result, error = TaskStatus.COMPLETE, None, None
        return (status, result, error)

    @override_config({"run_background_tasks_on": "worker1"})
    def test_schedule_task(self) -> None:
        """Check that a task scheduled to run now is launched right away on the background worker."""
        worker_config = {"worker_name": "worker1"}
        bg_worker_hs = self.make_worker_hs(
            "synapse.app.generic_worker",
            extra_config=worker_config,
        )
        # The action must also be registered on the worker's own scheduler.
        worker_scheduler = bg_worker_hs.get_task_scheduler()
        worker_scheduler.register_action(self._test_task, "_test_task")

        # Schedule with no timestamp, then read the task back straight away.
        scheduling = self.task_scheduler.schedule_task("_test_task")
        task_id = self.get_success(scheduling)

        task = self.get_success(self.task_scheduler.get_task(task_id))
        assert task is not None
        self.assertEqual(task.status, TaskStatus.COMPLETE)