@@ -0,0 +1 @@ | |||
Refactor some code to simplify and better type receipts stream adjacent code. |
@@ -216,7 +216,7 @@ class ApplicationServicesHandler: | |||
def notify_interested_services_ephemeral( | |||
self, | |||
stream_key: str, | |||
stream_key: StreamKeyType, | |||
new_token: Union[int, RoomStreamToken], | |||
users: Collection[Union[str, UserID]], | |||
) -> None: | |||
@@ -326,7 +326,7 @@ class ApplicationServicesHandler: | |||
async def _notify_interested_services_ephemeral( | |||
self, | |||
services: List[ApplicationService], | |||
stream_key: str, | |||
stream_key: StreamKeyType, | |||
new_token: int, | |||
users: Collection[Union[str, UserID]], | |||
) -> None: | |||
@@ -19,7 +19,7 @@ from synapse.api.errors import SynapseError, UnrecognizedRequestError | |||
from synapse.push.clientformat import format_push_rules_for_user | |||
from synapse.storage.push_rule import RuleNotFoundException | |||
from synapse.synapse_rust.push import get_base_rule_ids | |||
from synapse.types import JsonDict, UserID | |||
from synapse.types import JsonDict, StreamKeyType, UserID | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -114,7 +114,9 @@ class PushRulesHandler: | |||
user_id: the user ID the change is for. | |||
""" | |||
stream_id = self._main_store.get_max_push_rules_stream_id() | |||
self._notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) | |||
self._notifier.on_new_event( | |||
StreamKeyType.PUSH_RULES, stream_id, users=[user_id] | |||
) | |||
async def push_rules_for_user( | |||
self, user: UserID | |||
@@ -130,11 +130,10 @@ class ReceiptsHandler: | |||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: | |||
"""Takes a list of receipts, stores them and informs the notifier.""" | |||
min_batch_id: Optional[int] = None | |||
max_batch_id: Optional[int] = None | |||
receipts_persisted: List[ReadReceipt] = [] | |||
for receipt in receipts: | |||
res = await self.store.insert_receipt( | |||
stream_id = await self.store.insert_receipt( | |||
receipt.room_id, | |||
receipt.receipt_type, | |||
receipt.user_id, | |||
@@ -143,30 +142,26 @@ class ReceiptsHandler: | |||
receipt.data, | |||
) | |||
if not res: | |||
# res will be None if this receipt is 'old' | |||
if stream_id is None: | |||
# stream_id will be None if this receipt is 'old' | |||
continue | |||
stream_id, max_persisted_id = res | |||
receipts_persisted.append(receipt) | |||
if min_batch_id is None or stream_id < min_batch_id: | |||
min_batch_id = stream_id | |||
if max_batch_id is None or max_persisted_id > max_batch_id: | |||
max_batch_id = max_persisted_id | |||
# Either both of these should be None or neither. | |||
if min_batch_id is None or max_batch_id is None: | |||
if not receipts_persisted: | |||
# no new receipts | |||
return False | |||
affected_room_ids = list({r.room_id for r in receipts}) | |||
max_batch_id = self.store.get_max_receipt_stream_id() | |||
affected_room_ids = list({r.room_id for r in receipts_persisted}) | |||
self.notifier.on_new_event( | |||
StreamKeyType.RECEIPT, max_batch_id, rooms=affected_room_ids | |||
) | |||
# Note that the min here shouldn't be relied upon to be accurate. | |||
await self.hs.get_pusherpool().on_new_receipts( | |||
min_batch_id, max_batch_id, affected_room_ids | |||
{r.user_id for r in receipts_persisted} | |||
) | |||
return True | |||
@@ -126,7 +126,7 @@ class _NotifierUserStream: | |||
def notify( | |||
self, | |||
stream_key: str, | |||
stream_key: StreamKeyType, | |||
stream_id: Union[int, RoomStreamToken], | |||
time_now_ms: int, | |||
) -> None: | |||
@@ -454,7 +454,7 @@ class Notifier: | |||
def on_new_event( | |||
self, | |||
stream_key: str, | |||
stream_key: StreamKeyType, | |||
new_token: Union[int, RoomStreamToken], | |||
users: Optional[Collection[Union[str, UserID]]] = None, | |||
rooms: Optional[StrCollection] = None, | |||
@@ -655,30 +655,29 @@ class Notifier: | |||
events: List[Union[JsonDict, EventBase]] = [] | |||
end_token = from_token | |||
for name, source in self.event_sources.sources.get_sources(): | |||
keyname = "%s_key" % name | |||
before_id = getattr(before_token, keyname) | |||
after_id = getattr(after_token, keyname) | |||
for keyname, source in self.event_sources.sources.get_sources(): | |||
before_id = before_token.get_field(keyname) | |||
after_id = after_token.get_field(keyname) | |||
if before_id == after_id: | |||
continue | |||
new_events, new_key = await source.get_new_events( | |||
user=user, | |||
from_key=getattr(from_token, keyname), | |||
from_key=from_token.get_field(keyname), | |||
limit=limit, | |||
is_guest=is_peeking, | |||
room_ids=room_ids, | |||
explicit_room_id=explicit_room_id, | |||
) | |||
if name == "room": | |||
if keyname == StreamKeyType.ROOM: | |||
new_events = await filter_events_for_client( | |||
self._storage_controllers, | |||
user.to_string(), | |||
new_events, | |||
is_peeking=is_peeking, | |||
) | |||
elif name == "presence": | |||
elif keyname == StreamKeyType.PRESENCE: | |||
now = self.clock.time_msec() | |||
new_events[:] = [ | |||
{ | |||
@@ -182,7 +182,7 @@ class Pusher(metaclass=abc.ABCMeta): | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: | |||
def on_new_receipts(self) -> None: | |||
raise NotImplementedError() | |||
@abc.abstractmethod | |||
@@ -99,7 +99,7 @@ class EmailPusher(Pusher): | |||
pass | |||
self.timed_call = None | |||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: | |||
def on_new_receipts(self) -> None: | |||
# We could wake up and cancel the timer but there tend to be quite a | |||
# lot of read receipts so it's probably less work to just let the | |||
# timer fire | |||
@@ -160,7 +160,7 @@ class HttpPusher(Pusher): | |||
if should_check_for_notifs: | |||
self._start_processing() | |||
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: | |||
def on_new_receipts(self) -> None: | |||
# Note that the min here shouldn't be relied upon to be accurate. | |||
# We could check the receipts are actually m.read receipts here, | |||
@@ -292,20 +292,12 @@ class PusherPool: | |||
except Exception: | |||
logger.exception("Exception in pusher on_new_notifications") | |||
async def on_new_receipts( | |||
self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str] | |||
) -> None: | |||
async def on_new_receipts(self, users_affected: StrCollection) -> None: | |||
if not self.pushers: | |||
# nothing to do here. | |||
return | |||
try: | |||
# Need to subtract 1 from the minimum because the lower bound here | |||
# is not inclusive | |||
users_affected = await self.store.get_users_sent_receipts_between( | |||
min_stream_id - 1, max_stream_id | |||
) | |||
for u in users_affected: | |||
# Don't push if the user account has expired | |||
expired = await self._account_validity_handler.is_user_expired(u) | |||
@@ -314,7 +306,7 @@ class PusherPool: | |||
if u in self.pushers: | |||
for p in self.pushers[u].values(): | |||
p.on_new_receipts(min_stream_id, max_stream_id) | |||
p.on_new_receipts() | |||
except Exception: | |||
logger.exception("Exception in pusher on_new_receipts") | |||
@@ -129,9 +129,7 @@ class ReplicationDataHandler: | |||
self.notifier.on_new_event( | |||
StreamKeyType.RECEIPT, token, rooms=[row.room_id for row in rows] | |||
) | |||
await self._pusher_pool.on_new_receipts( | |||
token, token, {row.room_id for row in rows} | |||
) | |||
await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) | |||
elif stream_name == ToDeviceStream.NAME: | |||
entities = [row.entity for row in rows if row.entity.startswith("@")] | |||
if entities: | |||
@@ -208,7 +208,7 @@ class EndToEndRoomKeyStore(EndToEndRoomKeyBackgroundStore): | |||
"message": "Set room key", | |||
"room_id": room_id, | |||
"session_id": session_id, | |||
StreamKeyType.ROOM: room_key, | |||
StreamKeyType.ROOM.value: room_key, | |||
} | |||
) | |||
@@ -742,7 +742,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
event_ids: List[str], | |||
thread_id: Optional[str], | |||
data: dict, | |||
) -> Optional[Tuple[int, int]]: | |||
) -> Optional[int]: | |||
"""Insert a receipt, either from local client or remote server. | |||
Automatically does conversion between linearized and graph | |||
@@ -804,9 +804,7 @@ class ReceiptsWorkerStore(SQLBaseStore): | |||
data, | |||
) | |||
max_persisted_id = self._receipts_id_gen.get_current_token() | |||
return stream_id, max_persisted_id | |||
return stream_id | |||
async def _insert_graph_receipt( | |||
self, | |||
@@ -12,7 +12,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING, Iterator, Tuple | |||
from typing import TYPE_CHECKING, Sequence, Tuple | |||
import attr | |||
@@ -23,7 +23,7 @@ from synapse.handlers.room import RoomEventSource | |||
from synapse.handlers.typing import TypingNotificationEventSource | |||
from synapse.logging.opentracing import trace | |||
from synapse.streams import EventSource | |||
from synapse.types import StreamToken | |||
from synapse.types import StreamKeyType, StreamToken | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
@@ -37,9 +37,14 @@ class _EventSourcesInner: | |||
receipt: ReceiptEventSource | |||
account_data: AccountDataEventSource | |||
def get_sources(self) -> Iterator[Tuple[str, EventSource]]: | |||
for attribute in attr.fields(_EventSourcesInner): | |||
yield attribute.name, getattr(self, attribute.name) | |||
def get_sources(self) -> Sequence[Tuple[StreamKeyType, EventSource]]: | |||
return [ | |||
(StreamKeyType.ROOM, self.room), | |||
(StreamKeyType.PRESENCE, self.presence), | |||
(StreamKeyType.TYPING, self.typing), | |||
(StreamKeyType.RECEIPT, self.receipt), | |||
(StreamKeyType.ACCOUNT_DATA, self.account_data), | |||
] | |||
class EventSources: | |||
@@ -22,8 +22,8 @@ from typing import ( | |||
Any, | |||
ClassVar, | |||
Dict, | |||
Final, | |||
List, | |||
Literal, | |||
Mapping, | |||
Match, | |||
MutableMapping, | |||
@@ -34,6 +34,7 @@ from typing import ( | |||
Type, | |||
TypeVar, | |||
Union, | |||
overload, | |||
) | |||
import attr | |||
@@ -649,20 +650,20 @@ class RoomStreamToken: | |||
return "s%d" % (self.stream,) | |||
class StreamKeyType: | |||
class StreamKeyType(Enum): | |||
"""Known stream types. | |||
A stream is a list of entities ordered by an incrementing "stream token". | |||
""" | |||
ROOM: Final = "room_key" | |||
PRESENCE: Final = "presence_key" | |||
TYPING: Final = "typing_key" | |||
RECEIPT: Final = "receipt_key" | |||
ACCOUNT_DATA: Final = "account_data_key" | |||
PUSH_RULES: Final = "push_rules_key" | |||
TO_DEVICE: Final = "to_device_key" | |||
DEVICE_LIST: Final = "device_list_key" | |||
ROOM = "room_key" | |||
PRESENCE = "presence_key" | |||
TYPING = "typing_key" | |||
RECEIPT = "receipt_key" | |||
ACCOUNT_DATA = "account_data_key" | |||
PUSH_RULES = "push_rules_key" | |||
TO_DEVICE = "to_device_key" | |||
DEVICE_LIST = "device_list_key" | |||
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key" | |||
@@ -784,7 +785,7 @@ class StreamToken: | |||
def room_stream_id(self) -> int: | |||
return self.room_key.stream | |||
def copy_and_advance(self, key: str, new_value: Any) -> "StreamToken": | |||
def copy_and_advance(self, key: StreamKeyType, new_value: Any) -> "StreamToken": | |||
"""Advance the given key in the token to a new value if and only if the | |||
new value is after the old value. | |||
@@ -797,16 +798,44 @@ class StreamToken: | |||
return new_token | |||
new_token = self.copy_and_replace(key, new_value) | |||
new_id = int(getattr(new_token, key)) | |||
old_id = int(getattr(self, key)) | |||
new_id = new_token.get_field(key) | |||
old_id = self.get_field(key) | |||
if old_id < new_id: | |||
return new_token | |||
else: | |||
return self | |||
def copy_and_replace(self, key: str, new_value: Any) -> "StreamToken": | |||
return attr.evolve(self, **{key: new_value}) | |||
def copy_and_replace(self, key: StreamKeyType, new_value: Any) -> "StreamToken": | |||
return attr.evolve(self, **{key.value: new_value}) | |||
@overload | |||
def get_field(self, key: Literal[StreamKeyType.ROOM]) -> RoomStreamToken: | |||
... | |||
@overload | |||
def get_field( | |||
self, | |||
key: Literal[ | |||
StreamKeyType.ACCOUNT_DATA, | |||
StreamKeyType.DEVICE_LIST, | |||
StreamKeyType.PRESENCE, | |||
StreamKeyType.PUSH_RULES, | |||
StreamKeyType.RECEIPT, | |||
StreamKeyType.TO_DEVICE, | |||
StreamKeyType.TYPING, | |||
StreamKeyType.UN_PARTIAL_STATED_ROOMS, | |||
], | |||
) -> int: | |||
... | |||
@overload | |||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||
... | |||
def get_field(self, key: StreamKeyType) -> Union[int, RoomStreamToken]: | |||
"""Returns the stream ID for the given key.""" | |||
return getattr(self, key.value) | |||
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0, 0) | |||
@@ -31,7 +31,7 @@ from synapse.appservice import ( | |||
from synapse.handlers.appservice import ApplicationServicesHandler | |||
from synapse.rest.client import login, receipts, register, room, sendtodevice | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, RoomStreamToken | |||
from synapse.types import JsonDict, RoomStreamToken, StreamKeyType | |||
from synapse.util import Clock | |||
from synapse.util.stringutils import random_string | |||
@@ -304,7 +304,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
"receipt_key", 580, ["@fakerecipient:example.com"] | |||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||
) | |||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | |||
interested_service, ephemeral=[event] | |||
@@ -332,7 +332,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): | |||
) | |||
self.handler.notify_interested_services_ephemeral( | |||
"receipt_key", 580, ["@fakerecipient:example.com"] | |||
StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] | |||
) | |||
# This method will be called, but with an empty list of events | |||
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( | |||
@@ -634,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): | |||
self.get_success( | |||
self.hs.get_application_service_handler()._notify_interested_services_ephemeral( | |||
services=[interested_appservice], | |||
stream_key="receipt_key", | |||
stream_key=StreamKeyType.RECEIPT, | |||
new_token=stream_token, | |||
users=[self.exclusive_as_user], | |||
) | |||
@@ -28,7 +28,7 @@ from synapse.federation.transport.server import TransportLayerServer | |||
from synapse.handlers.typing import TypingWriterHandler | |||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent | |||
from synapse.server import HomeServer | |||
from synapse.types import JsonDict, Requester, UserID, create_requester | |||
from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester | |||
from synapse.util import Clock | |||
from tests import unittest | |||
@@ -203,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] | |||
) | |||
self.assertEqual(self.event_source.get_current_key(), 1) | |||
events = self.get_success( | |||
@@ -273,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
) | |||
self.assertEqual(channel.code, 200) | |||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] | |||
) | |||
self.assertEqual(self.event_source.get_current_key(), 1) | |||
events = self.get_success( | |||
@@ -349,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] | |||
) | |||
self.mock_federation_client.put_json.assert_called_once_with( | |||
"farm", | |||
@@ -399,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] | |||
) | |||
self.on_new_event.reset_mock() | |||
self.assertEqual(self.event_source.get_current_key(), 1) | |||
@@ -425,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
self.reactor.pump([16]) | |||
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])] | |||
) | |||
self.assertEqual(self.event_source.get_current_key(), 2) | |||
events = self.get_success( | |||
@@ -459,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): | |||
) | |||
) | |||
self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) | |||
self.on_new_event.assert_has_calls( | |||
[call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])] | |||
) | |||
self.on_new_event.reset_mock() | |||
self.assertEqual(self.event_source.get_current_key(), 3) | |||