c.f. #13476
@@ -0,0 +1 @@
Fix bug where purging history and paginating simultaneously could lead to database corruption when using workers.
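The changelog entry is terse, so a word on the mechanism: the pagination handler previously guarded purges and pagination with an in-process ReadWriteLock, which cannot coordinate requests handled by different worker processes. This PR replaces it with database-backed read/write locks shared across workers. The toy below is not Synapse code; it is a minimal, self-contained illustration of the reader/writer semantics those locks provide (many concurrent readers, or exactly one writer, with no fairness guarantees), which is why pagination and event paths take the locks with write=False while purge/delete paths take them with write=True.

import asyncio


class ToyReadWriteLock:
    """Allows many concurrent readers or a single writer (illustration only)."""

    def __init__(self) -> None:
        self._readers = 0
        self._writer = False
        self._cond = asyncio.Condition()

    async def acquire(self, write: bool) -> None:
        async with self._cond:
            if write:
                await self._cond.wait_for(
                    lambda: self._readers == 0 and not self._writer
                )
                self._writer = True
            else:
                await self._cond.wait_for(lambda: not self._writer)
                self._readers += 1

    async def release(self, write: bool) -> None:
        async with self._cond:
            if write:
                self._writer = False
            else:
                self._readers -= 1
            self._cond.notify_all()


async def main() -> None:
    lock = ToyReadWriteLock()

    async def paginate(name: str) -> None:
        # Corresponds to acquire_read_write_lock(..., write=False) in the hunks below.
        await lock.acquire(write=False)
        print(name, "paginating")
        await asyncio.sleep(0.1)
        await lock.release(write=False)

    async def purge() -> None:
        # Corresponds to acquire_read_write_lock(..., write=True) in the hunks below.
        await lock.acquire(write=True)
        print("purging: no reader can overlap this")
        await asyncio.sleep(0.1)
        await lock.release(write=True)

    await asyncio.gather(paginate("req-1"), paginate("req-2"), purge())


asyncio.run(main())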
@@ -63,6 +63,7 @@ from synapse.federation.federation_base import (
)
from synapse.federation.persistence import TransactionActions
from synapse.federation.units import Edu, Transaction
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import (
make_deferred_yieldable,
@@ -137,6 +138,7 @@ class FederationServer(FederationBase):
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
self._e2e_keys_handler = hs.get_e2e_keys_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -1236,9 +1238,18 @@ class FederationServer(FederationBase):
logger.info("handling received PDU in room %s: %s", room_id, event)
try:
with nested_logging_context(event.event_id):
await self._federation_event_handler.on_receive_pdu(
origin, event
)
# We're taking out a lock within a lock, which could
# lead to deadlocks if we're not careful. However, it is
# safe on this occasion as we only ever take a write
# lock when deleting a room, which we would never do
# while holding the `_INBOUND_EVENT_HANDLING_LOCK_NAME`
# lock.
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
await self._federation_event_handler.on_receive_pdu(
origin, event
)
except FederationError as e:
# XXX: Ideally we'd inform the remote we failed to process
# the event, but we can't return an error in the transaction
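The comment above is the whole deadlock argument in miniature: nesting locks is only dangerous if two code paths can end up waiting on the same pair of locks in opposite orders. A hedged sketch of the rule, where inbound_lock and on_receive_pdu are stand-ins for the inbound-event linearizer and the PDU handler, and "delete_room_lock" is the value of DELETE_ROOM_LOCK_NAME introduced in this PR:

async def handle_pdu_sketch(inbound_lock, worker_locks, room_id, on_receive_pdu):
    # Sketch only: everything except the lock name is a stand-in.
    async with inbound_lock:
        # Safe nesting: the only code that takes "delete_room_lock" in write
        # mode is room deletion, and it never runs while holding the
        # inbound-event lock, so no cycle of waiters can form.
        async with worker_locks.acquire_read_write_lock(
            "delete_room_lock", room_id, write=False
        ):
            await on_receive_pdu()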
@@ -53,6 +53,7 @@ from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -485,6 +486,7 @@ class EventCreationHandler:
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self._notifier = hs.get_notifier()
self._worker_lock_handler = hs.get_worker_locks_handler()
self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
@@ -1010,6 +1012,37 @@ class EventCreationHandler:
event.internal_metadata.stream_ordering,
)
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
return await self._create_and_send_nonmember_event_locked(
requester=requester,
event_dict=event_dict,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
ratelimit=ratelimit,
txn_id=txn_id,
ignore_shadow_ban=ignore_shadow_ban,
outlier=outlier,
depth=depth,
)
async def _create_and_send_nonmember_event_locked(
self,
requester: Requester,
event_dict: dict,
allow_no_prev_events: bool = False,
prev_event_ids: Optional[List[str]] = None,
state_event_ids: Optional[List[str]] = None,
ratelimit: bool = True,
txn_id: Optional[str] = None,
ignore_shadow_ban: bool = False,
outlier: bool = False,
depth: Optional[int] = None,
) -> Tuple[EventBase, int]:
room_id = event_dict["room_id"]
# If we don't have any prev event IDs specified then we need to
# check that the host is in the room (as otherwise populating the
# prev events will fail), at which point we may as well check the
@@ -1923,7 +1956,10 @@ class EventCreationHandler:
)
for room_id in room_ids:
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
dummy_event_sent = await self._send_dummy_event_for_room(room_id)
if not dummy_event_sent:
# Did not find a valid user in the room, so remove from future attempts
@@ -46,6 +46,11 @@ logger = logging.getLogger(__name__)
BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
PURGE_HISTORY_LOCK_NAME = "purge_history_lock"
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
@attr.s(slots=True, auto_attribs=True)
class PurgeStatus:
"""Object tracking the status of a purge request
@@ -142,6 +147,7 @@ class PaginationHandler:
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
self._relations_handler = hs.get_relations_handler()
self._worker_locks = hs.get_worker_locks_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -356,7 +362,9 @@ class PaginationHandler:
"""
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True
):
await self._storage_controllers.purge_events.purge_history(
room_id, token, delete_local_events
)
@@ -412,7 +420,10 @@ class PaginationHandler:
room_id: room to be purged
force: set true to skip checking for joined users.
"""
async with self.pagination_lock.write(room_id):
async with self._worker_locks.acquire_multi_read_write_lock(
[(PURGE_HISTORY_LOCK_NAME, room_id), (DELETE_ROOM_LOCK_NAME, room_id)],
write=True,
):
# first check that we have no users in this room
if not force:
joined = await self.store.is_host_joined(room_id, self._server_name)
@@ -471,7 +482,9 @@ class PaginationHandler:
room_token = from_token.room_key
async with self.pagination_lock.read(room_id):
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=False
):
(membership, member_event_id) = (None, None)
if not use_admin_priviledge:
(
@@ -747,7 +760,9 @@ class PaginationHandler:
self._purges_in_progress_by_room.add(room_id)
try:
async with self.pagination_lock.write(room_id):
async with self._worker_locks.acquire_read_write_lock(
PURGE_HISTORY_LOCK_NAME, room_id, write=True
):
self._delete_by_id[delete_id].status = DeleteStatus.STATUS_SHUTTING_DOWN
self._delete_by_id[
delete_id
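Read together with the read locks taken elsewhere in this PR (event sending, membership changes, inbound federation PDUs, room upgrades and event persistence), the intended division of labour between the two new lock names is roughly the following; this summary is an editorial gloss on the hunks, not text from the PR itself:

purge history                      -> purge_history_lock (write)
delete/shut down room (admin API)  -> purge_history_lock + delete_room_lock (write, all-or-nothing)
/messages pagination               -> purge_history_lock (read)
event send, membership, PDUs,      -> delete_room_lock (read)
room upgrade, event persistence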
@@ -39,6 +39,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging import opentracing
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -94,6 +95,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.event_creation_handler = hs.get_event_creation_handler()
self.account_data_handler = hs.get_account_data_handler()
self.event_auth_handler = hs.get_event_auth_handler()
self._worker_lock_handler = hs.get_worker_locks_handler()
self.member_linearizer: Linearizer = Linearizer(name="member")
self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
@@ -638,26 +640,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# by application services), and then by room ID.
async with self.member_as_limiter.queue(as_id):
async with self.member_linearizer.queue(key):
with opentracing.start_active_span("update_membership_locked"):
result = await self.update_membership_locked(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
new_room=new_room,
require_consent=require_consent,
outlier=outlier,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
depth=depth,
origin_server_ts=origin_server_ts,
)
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
with opentracing.start_active_span("update_membership_locked"):
result = await self.update_membership_locked(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
new_room=new_room,
require_consent=require_consent,
outlier=outlier,
allow_no_prev_events=allow_no_prev_events,
prev_event_ids=prev_event_ids,
state_event_ids=state_event_ids,
depth=depth,
origin_server_ts=origin_server_ts,
)
return result
@@ -0,0 +1,333 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
from types import TracebackType
from typing import (
TYPE_CHECKING,
AsyncContextManager,
Collection,
Dict,
Optional,
Tuple,
Type,
Union,
)
from weakref import WeakSet
import attr
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTime
from synapse.logging.context import PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage.databases.main.lock import Lock, LockStore
from synapse.util.async_helpers import timeout_deferred
if TYPE_CHECKING:
from synapse.logging.opentracing import opentracing
from synapse.server import HomeServer
DELETE_ROOM_LOCK_NAME = "delete_room_lock"
class WorkerLocksHandler:
"""A class for waiting on taking out locks, rather than using the storage
functions directly (which don't support awaiting).
"""
def __init__(self, hs: "HomeServer") -> None:
self._reactor = hs.get_reactor()
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._notifier = hs.get_notifier()
self._instance_name = hs.get_instance_name()
# Map from lock name/key to set of `WaitingLock` that are active for
# that lock.
self._locks: Dict[
Tuple[str, str], WeakSet[Union[WaitingLock, WaitingMultiLock]]
] = {}
self._clock.looping_call(self._cleanup_locks, 30_000)
self._notifier.add_lock_released_callback(self._on_lock_released)
def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
"""Acquire a standard lock, returns a context manager that will block
until the lock is acquired.
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
Usage:
async with handler.acquire_lock(name, key):
# Do work while holding the lock...
"""
lock = WaitingLock(
reactor=self._reactor,
store=self._store,
handler=self,
lock_name=lock_name,
lock_key=lock_key,
write=None,
)
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def acquire_read_write_lock(
self,
lock_name: str,
lock_key: str,
*,
write: bool,
) -> "WaitingLock":
"""Acquire a read/write lock, returns a context manager that will block
until the lock is acquired.
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
Usage:
async with handler.acquire_read_write_lock(name, key, write=True):
# Do work while holding the lock...
"""
lock = WaitingLock(
reactor=self._reactor,
store=self._store,
handler=self,
lock_name=lock_name,
lock_key=lock_key,
write=write,
)
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
*,
write: bool,
) -> "WaitingMultiLock":
"""Acquires multi read/write locks at once, returns a context manager | |||||
that will block until all the locks are acquired.
This will try and acquire all locks at once, and will never hold on to a
subset of the locks. (This avoids accidentally creating deadlocks).
Note: Care must be taken to avoid deadlocks. In particular, this
function does *not* timeout.
"""
lock = WaitingMultiLock(
lock_names=lock_names,
write=write,
reactor=self._reactor,
store=self._store,
handler=self,
)
for lock_name, lock_key in lock_names:
self._locks.setdefault((lock_name, lock_key), WeakSet()).add(lock)
return lock
def notify_lock_released(self, lock_name: str, lock_key: str) -> None:
"""Notify that a lock has been released.
Pokes both the notifier and replication.
"""
self._notifier.notify_lock_released(self._instance_name, lock_name, lock_key)
def _on_lock_released(
self, instance_name: str, lock_name: str, lock_key: str
) -> None:
"""Called when a lock has been released.
Wakes up any locks that might be waiting on this.
"""
locks = self._locks.get((lock_name, lock_key))
if not locks:
return
def _wake_deferred(deferred: defer.Deferred) -> None:
if not deferred.called:
deferred.callback(None)
for lock in locks:
self._clock.call_later(0, _wake_deferred, lock.deferred)
@wrap_as_background_process("_cleanup_locks")
async def _cleanup_locks(self) -> None:
"""Periodically cleans out stale entries in the locks map"""
self._locks = {key: value for key, value in self._locks.items() if value}
@attr.s(auto_attribs=True, eq=False)
class WaitingLock:
reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler
lock_name: str
lock_key: str
write: Optional[bool]
deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
_inner_lock: Optional[Lock] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)
async def __aenter__(self) -> None:
self._lock_span.__enter__()
with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock is None:
self.deferred = defer.Deferred()
if self.write is not None:
lock = await self.store.try_acquire_read_write_lock(
self.lock_name, self.lock_key, write=self.write
)
else:
lock = await self.store.try_acquire_lock(
self.lock_name, self.lock_key
)
if lock:
self._inner_lock = lock
break
try:
# Wait until we get notified the lock might have been
# released (by the deferred being resolved). We also
# periodically wake up in case the lock was released but we
# weren't notified.
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass
return await self._inner_lock.__aenter__()
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock
self.handler.notify_lock_released(self.lock_name, self.lock_key)
try:
r = await self._inner_lock.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)
return r
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
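The retry schedule above is easy to misread: the interval doubles each time, but the max(5, ...) floor means the second wait already jumps to roughly five seconds. A standalone reproduction of the same arithmetic, for illustration only:

import random


def retry_intervals(n: int, interval: float = 0.1) -> list:
    """Mirror WaitingLock._get_next_retry_interval for n successive calls."""
    out = []
    for _ in range(n):
        out.append(interval * random.uniform(0.9, 1.1))
        interval = max(5, interval * 2)
    return out


print(retry_intervals(4))  # roughly [0.1, 5, 10, 20], each with +/-10% jitter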
@attr.s(auto_attribs=True, eq=False)
class WaitingMultiLock:
lock_names: Collection[Tuple[str, str]]
write: bool
reactor: IReactorTime
store: LockStore
handler: WorkerLocksHandler
deferred: "defer.Deferred[None]" = attr.Factory(defer.Deferred)
_inner_lock_cm: Optional[AsyncContextManager] = None
_retry_interval: float = 0.1
_lock_span: "opentracing.Scope" = attr.Factory(
lambda: start_active_span("WaitingLock.lock")
)
async def __aenter__(self) -> None:
self._lock_span.__enter__()
with start_active_span("WaitingLock.waiting_for_lock"):
while self._inner_lock_cm is None:
self.deferred = defer.Deferred()
lock_cm = await self.store.try_acquire_multi_read_write_lock(
self.lock_names, write=self.write
)
if lock_cm:
self._inner_lock_cm = lock_cm
break
try:
# Wait until we get notified the lock might have been
# released (by the deferred being resolved). We also
# periodically wake up in case the lock was released but we
# weren't notified.
with PreserveLoggingContext():
await timeout_deferred(
deferred=self.deferred,
timeout=self._get_next_retry_interval(),
reactor=self.reactor,
)
except Exception:
pass
assert self._inner_lock_cm
await self._inner_lock_cm.__aenter__()
return
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> Optional[bool]:
assert self._inner_lock_cm
for lock_name, lock_key in self.lock_names:
self.handler.notify_lock_released(lock_name, lock_key)
try:
r = await self._inner_lock_cm.__aexit__(exc_type, exc, tb)
finally:
self._lock_span.__exit__(exc_type, exc, tb)
return r
def _get_next_retry_interval(self) -> float:
next = self._retry_interval
self._retry_interval = max(5, next * 2)
return next * random.uniform(0.9, 1.1)
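For orientation, a sketch of how the three acquisition styles in this new handler are intended to be used, assuming a running HomeServer hs; "example_lock" is a placeholder name, while "purge_history_lock" and "delete_room_lock" are the constants introduced elsewhere in this PR:

async def locking_sketch(hs, room_id: str) -> None:
    locks = hs.get_worker_locks_handler()

    # Plain lock: a single holder at a time.
    async with locks.acquire_lock("example_lock", room_id):
        ...

    # Read lock: shared with other readers, blocked only by a writer.
    async with locks.acquire_read_write_lock(
        "delete_room_lock", room_id, write=False
    ):
        ...  # e.g. send an event or paginate

    # Multi lock, all-or-nothing: how room deletion takes both locks.
    async with locks.acquire_multi_read_write_lock(
        [("purge_history_lock", room_id), ("delete_room_lock", room_id)],
        write=True,
    ):
        ...  # e.g. shut down and purge the room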
@@ -234,6 +234,9 @@ class Notifier:
self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
# List of callbacks to be notified when a lock is released
self._lock_released_callback: List[Callable[[str, str, str], None]] = []
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self._pusher_pool = hs.get_pusherpool()
@@ -785,6 +788,19 @@ class Notifier:
# that any in flight requests can be immediately retried.
self._federation_client.wake_destination(server)
def add_lock_released_callback(
self, callback: Callable[[str, str, str], None]
) -> None:
"""Add a function to be called whenever we are notified about a released lock."""
self._lock_released_callback.append(callback)
def notify_lock_released(
self, instance_name: str, lock_name: str, lock_key: str
) -> None:
"""Notify the callbacks that a lock has been released."""
for cb in self._lock_released_callback:
cb(instance_name, lock_name, lock_key)
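These two small hooks are the glue for cross-worker wake-ups: WaitingLock.__aexit__ calls WorkerLocksHandler.notify_lock_released, which pokes this notifier; the replication command handler (further down in this diff) forwards that as a LOCK_RELEASED command, and the receiving worker's notifier fires the same callbacks so its WorkerLocksHandler can wake any waiting deferreds. A minimal illustration of the hook itself, assuming a HomeServer hs:

def log_lock_release(instance_name: str, lock_name: str, lock_key: str) -> None:
    # Matches the Callable[[str, str, str], None] shape the notifier expects.
    print(f"lock released on {instance_name}: {lock_name}/{lock_key}")


hs.get_notifier().add_lock_released_callback(log_lock_release)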
@attr.s(auto_attribs=True)
class ReplicationNotifier:
@@ -422,6 +422,36 @@ class RemoteServerUpCommand(_SimpleCommand):
NAME = "REMOTE_SERVER_UP"
class LockReleasedCommand(Command):
"""Sent to inform other instances that a given lock has been dropped.
Format::
LOCK_RELEASED ["<instance_name>", "<lock_name>", "<lock_key>"]
"""
NAME = "LOCK_RELEASED"
def __init__(
self,
instance_name: str,
lock_name: str,
lock_key: str,
):
self.instance_name = instance_name
self.lock_name = lock_name
self.lock_key = lock_key
@classmethod
def from_line(cls: Type["LockReleasedCommand"], line: str) -> "LockReleasedCommand":
instance_name, lock_name, lock_key = json_decoder.decode(line)
return cls(instance_name, lock_name, lock_key)
def to_line(self) -> str:
return json_encoder.encode([self.instance_name, self.lock_name, self.lock_key])
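An example of the new command in use; the values are placeholders, and the framing layer prepends the command NAME when the line goes over the wire, as the docstring's Format block shows:

cmd = LockReleasedCommand("worker1", "purge_history_lock", "!abc:example.com")
line = cmd.to_line()  # a JSON array of [instance_name, lock_name, lock_key]
parsed = LockReleasedCommand.from_line(line)
assert (parsed.instance_name, parsed.lock_name, parsed.lock_key) == (
    "worker1",
    "purge_history_lock",
    "!abc:example.com",
)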
_COMMANDS: Tuple[Type[Command], ...] = (
ServerCommand,
RdataCommand,
@@ -435,6 +465,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
UserIpCommand,
RemoteServerUpCommand,
ClearUserSyncsCommand,
LockReleasedCommand,
)
# Map of command name to command type.
@@ -448,6 +479,7 @@ VALID_SERVER_COMMANDS = (
ErrorCommand.NAME,
PingCommand.NAME,
RemoteServerUpCommand.NAME,
LockReleasedCommand.NAME,
)
# The commands the client is allowed to send
@@ -461,6 +493,7 @@ VALID_CLIENT_COMMANDS = (
UserIpCommand.NAME,
ErrorCommand.NAME,
RemoteServerUpCommand.NAME,
LockReleasedCommand.NAME,
)
@@ -39,6 +39,7 @@ from synapse.replication.tcp.commands import (
ClearUserSyncsCommand,
Command,
FederationAckCommand,
LockReleasedCommand,
PositionCommand,
RdataCommand,
RemoteServerUpCommand,
@@ -248,6 +249,9 @@ class ReplicationCommandHandler:
if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")
if hs.config.redis.redis_enabled:
self._notifier.add_lock_released_callback(self.on_lock_released)
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
@@ -648,6 +652,17 @@ class ReplicationCommandHandler:
self._notifier.notify_remote_server_up(cmd.data)
def on_LOCK_RELEASED(
self, conn: IReplicationConnection, cmd: LockReleasedCommand
) -> None:
"""Called when we get a new LOCK_RELEASED command."""
if cmd.instance_name == self._instance_name:
return
self._notifier.notify_lock_released(
cmd.instance_name, cmd.lock_name, cmd.lock_key
)
def new_connection(self, connection: IReplicationConnection) -> None:
"""Called when we have a new connection."""
self._connections.append(connection)
@@ -754,6 +769,13 @@ class ReplicationCommandHandler:
"""
self.send_command(RdataCommand(stream_name, self._instance_name, token, data))
def on_lock_released(
self, instance_name: str, lock_name: str, lock_key: str
) -> None:
"""Called when we released a lock and should notify other instances."""
if instance_name == self._instance_name:
self.send_command(LockReleasedCommand(instance_name, lock_name, lock_key))
UpdateToken = TypeVar("UpdateToken")
UpdateRow = TypeVar("UpdateRow")
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, ShadowBanError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
@@ -60,6 +61,7 @@ class RoomUpgradeRestServlet(RestServlet):
self._hs = hs
self._room_creation_handler = hs.get_room_creation_handler()
self._auth = hs.get_auth()
self._worker_lock_handler = hs.get_worker_locks_handler()
async def on_POST(
self, request: SynapseRequest, room_id: str
@@ -78,9 +80,12 @@ class RoomUpgradeRestServlet(RestServlet):
)
try:
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
async with self._worker_lock_handler.acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
new_room_id = await self._room_creation_handler.upgrade_room(
requester, room_id, new_version
)
except ShadowBanError:
# Generate a random room ID.
new_room_id = stringutils.random_string(18)
@@ -107,6 +107,7 @@ from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.handlers.worker_lock import WorkerLocksHandler
from synapse.http.client import (
InsecureInterceptableContextFactory,
ReplicationClient,
@@ -912,3 +913,7 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_common_usage_metrics_manager(self) -> CommonUsageMetricsManager:
"""Usage metrics shared between phone home stats and the prometheus exporter."""
return CommonUsageMetricsManager(self)
@cache_in_self
def get_worker_locks_handler(self) -> WorkerLocksHandler:
return WorkerLocksHandler(self)
@@ -45,6 +45,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.worker_lock import DELETE_ROOM_LOCK_NAME
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.logging.opentracing import (
SynapseTags,
@@ -338,6 +339,7 @@ class EventsPersistenceStorageController:
)
self._state_resolution_handler = hs.get_state_resolution_handler()
self._state_controller = state_controller
self.hs = hs
async def _process_event_persist_queue_task(
self,
@@ -350,15 +352,22 @@ class EventsPersistenceStorageController:
A dictionary of event ID to event ID we didn't persist as we already
had another event persisted with the same TXN ID.
"""
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)
elif isinstance(task, _UpdateCurrentStateTask):
await self._update_current_state(room_id, task)
return {}
else:
raise AssertionError(
f"Found an unexpected task type in event persistence queue: {task}"
)
# Ensure that the room can't be deleted while we're persisting events to
# it. We might already have taken out the lock, but since this is just a
# "read" lock its inherently reentrant. | |||||
async with self.hs.get_worker_locks_handler().acquire_read_write_lock(
DELETE_ROOM_LOCK_NAME, room_id, write=False
):
if isinstance(task, _PersistEventsTask):
return await self._persist_event_batch(room_id, task)
elif isinstance(task, _UpdateCurrentStateTask):
await self._update_current_state(room_id, task)
return {}
else:
raise AssertionError(
f"Found an unexpected task type in event persistence queue: {task}"
)
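The reentrancy claim in the comment above holds because read acquisitions only conflict with a writer, never with each other, so re-taking the lock in read mode while a caller (for example the event creation handler earlier in this diff) already holds it simply adds a second reader. A sketch, assuming a WorkerLocksHandler locks and using the literal value of DELETE_ROOM_LOCK_NAME:

async def nested_read_locks(locks, room_id: str) -> None:
    async with locks.acquire_read_write_lock(
        "delete_room_lock", room_id, write=False
    ):
        # A second read acquisition succeeds; only a write lock (room
        # deletion) would have to wait for both readers to finish.
        async with locks.acquire_read_write_lock(
            "delete_room_lock", room_id, write=False
        ):
            ...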
@trace
async def persist_events(
@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import AsyncExitStack
from types import TracebackType
from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
from typing import TYPE_CHECKING, Collection, Optional, Set, Tuple, Type
from weakref import WeakValueDictionary
from twisted.internet.interfaces import IReactorCore
@@ -208,76 +209,85 @@ class LockStore(SQLBaseStore):
used (otherwise the lock will leak).
"""
try:
lock = await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
self._try_acquire_read_write_lock_txn,
lock_name,
lock_key,
write,
)
except self.database_engine.module.IntegrityError:
return None
return lock
def _try_acquire_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_name: str,
lock_key: str,
write: bool,
) -> "Lock":
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.
now = self._clock.time_msec()
token = random_string(6)
def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
# We attempt to acquire the lock by inserting into
# `worker_read_write_locks` and seeing if that fails any
# constraints. If it doesn't then we have acquired the lock,
# otherwise we haven't.
#
# Before that though we clear the table of any stale locks.
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""
insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""
if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
delete_sql = """
DELETE FROM worker_read_write_locks
WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
"""
return
insert_sql = """
INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
VALUES (?, ?, ?, ?, ?, ?)
"""
try:
await self.db_pool.runInteraction(
"try_acquire_read_write_lock",
_try_acquire_read_write_lock_txn,
if isinstance(self.database_engine, PostgresEngine):
# For Postgres we can send these queries at the same time.
txn.execute(
delete_sql + ";" + insert_sql,
(
# DELETE args
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
# UPSERT args
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
else:
# For SQLite these need to be two queries.
txn.execute(
delete_sql,
(
now - _LOCK_TIMEOUT_MS,
lock_name,
lock_key,
),
)
txn.execute(
insert_sql,
(
lock_name,
lock_key,
write,
self._instance_name,
token,
now,
),
)
)
except self.database_engine.module.IntegrityError:
return None
lock = Lock(
self._reactor,
@@ -289,10 +299,58 @@ class LockStore(SQLBaseStore):
token=token,
)
self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
def set_lock() -> None:
self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
txn.call_after(set_lock)
return lock
async def try_acquire_multi_read_write_lock(
self,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Optional[AsyncExitStack]:
"""Try to acquire multiple locks for the given names/keys. Will return
an async context manager if the locks are successfully acquired, which
*must* be used (otherwise the lock will leak).
If only a subset of the locks can be acquired then it will immediately
drop them and return `None`.
"""
try:
locks = await self.db_pool.runInteraction(
"try_acquire_multi_read_write_lock",
self._try_acquire_multi_read_write_lock_txn,
lock_names,
write,
)
except self.database_engine.module.IntegrityError:
return None
stack = AsyncExitStack()
for lock in locks:
await stack.enter_async_context(lock)
return stack
def _try_acquire_multi_read_write_lock_txn(
self,
txn: LoggingTransaction,
lock_names: Collection[Tuple[str, str]],
write: bool,
) -> Collection["Lock"]:
locks = []
for lock_name, lock_key in lock_names:
lock = self._try_acquire_read_write_lock_txn(
txn, lock_name, lock_key, write
)
locks.append(lock)
return locks
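The acquisition strategy in this store is "insert a row and let the table's constraints arbitrate": if an incompatible lock row already exists the INSERT fails, the database raises an integrity error, and try_acquire_*() returns None. The real worker_read_write_locks constraints (which must admit many readers but only one writer) live in a schema delta outside this diff; the toy below only demonstrates the insert-then-IntegrityError pattern, using a deliberately simplified unique constraint:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE toy_locks ("
    "  lock_name TEXT NOT NULL,"
    "  lock_key TEXT NOT NULL,"
    "  token TEXT NOT NULL,"
    "  UNIQUE (lock_name, lock_key)"
    ")"
)


def try_acquire(token: str) -> bool:
    try:
        with conn:
            conn.execute(
                "INSERT INTO toy_locks VALUES (?, ?, ?)",
                ("purge_history_lock", "!abc:example.com", token),
            )
        return True
    except sqlite3.IntegrityError:
        # Someone else holds the lock: the caller backs off and retries.
        return False


print(try_acquire("token-1"))  # True
print(try_acquire("token-2"))  # False until the first row is deleted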
class Lock:
"""An async context manager that manages an acquired lock, ensuring it is
@@ -0,0 +1,74 @@
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.replication._base import BaseMultiWorkerStreamTestCase
class WorkerLockTestCase(unittest.HomeserverTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.worker_lock_handler = self.hs.get_worker_locks_handler()
def test_wait_for_lock_locally(self) -> None:
"""Test waiting for a lock on a single worker"""
lock1 = self.worker_lock_handler.acquire_lock("name", "key")
self.get_success(lock1.__aenter__())
lock2 = self.worker_lock_handler.acquire_lock("name", "key")
d2 = defer.ensureDeferred(lock2.__aenter__())
self.assertNoResult(d2)
self.get_success(lock1.__aexit__(None, None, None))
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))
class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase):
def prepare(
self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
) -> None:
self.main_worker_lock_handler = self.hs.get_worker_locks_handler()
def test_wait_for_lock_worker(self) -> None:
"""Test waiting for a lock on another worker"""
worker = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"redis": {"enabled": True},
},
)
worker_lock_handler = worker.get_worker_locks_handler()
lock1 = self.main_worker_lock_handler.acquire_lock("name", "key")
self.get_success(lock1.__aenter__())
lock2 = worker_lock_handler.acquire_lock("name", "key")
d2 = defer.ensureDeferred(lock2.__aenter__())
self.assertNoResult(d2)
self.get_success(lock1.__aexit__(None, None, None))
self.get_success(d2)
self.get_success(lock2.__aexit__(None, None, None))
@@ -711,7 +711,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(30, channel.resource_usage.db_txn_count)
self.assertEqual(32, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -724,7 +724,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(32, channel.resource_usage.db_txn_count)
self.assertEqual(34, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id
@@ -448,3 +448,55 @@ class ReadWriteLockTestCase(unittest.HomeserverTestCase):
self.get_success(self.store._on_shutdown())
self.assertEqual(self.store._live_read_write_lock_tokens, {})
def test_acquire_multiple_locks(self) -> None:
"""Tests that acquiring multiple locks at once works."""
# Take out multiple locks and ensure that we can't get those locks out
# again.
lock = self.get_success(
self.store.try_acquire_multi_read_write_lock(
[("name1", "key1"), ("name2", "key2")], write=True
)
)
self.assertIsNotNone(lock)
assert lock is not None
self.get_success(lock.__aenter__())
lock2 = self.get_success(
self.store.try_acquire_read_write_lock("name1", "key1", write=True)
)
self.assertIsNone(lock2)
lock3 = self.get_success(
self.store.try_acquire_read_write_lock("name2", "key2", write=False)
)
self.assertIsNone(lock3)
# Overlapping locks attempts will fail, and won't lock any locks.
lock4 = self.get_success(
self.store.try_acquire_multi_read_write_lock(
[("name1", "key1"), ("name3", "key3")], write=True
)
)
self.assertIsNone(lock4)
lock5 = self.get_success(
self.store.try_acquire_read_write_lock("name3", "key3", write=True)
)
self.assertIsNotNone(lock5)
assert lock5 is not None
self.get_success(lock5.__aenter__())
self.get_success(lock5.__aexit__(None, None, None))
# Once we release the lock we can take out the locks again.
self.get_success(lock.__aexit__(None, None, None))
lock6 = self.get_success(
self.store.try_acquire_read_write_lock("name1", "key1", write=True)
)
self.assertIsNotNone(lock6)
assert lock6 is not None
self.get_success(lock6.__aenter__())
self.get_success(lock6.__aexit__(None, None, None))