You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

81 lines
3.2 KiB

  1. # Copyright 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Any, Iterable
  15. from synapse.replication.slave.storage._base import BaseSlavedStore
  16. from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
  17. from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
  18. from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
  19. from synapse.storage.databases.main.devices import DeviceWorkerStore
  20. if TYPE_CHECKING:
  21. from synapse.server import HomeServer
  22. class SlavedDeviceStore(DeviceWorkerStore, BaseSlavedStore):
  23. def __init__(
  24. self,
  25. database: DatabasePool,
  26. db_conn: LoggingDatabaseConnection,
  27. hs: "HomeServer",
  28. ):
  29. self.hs = hs
  30. self._device_list_id_gen = SlavedIdTracker(
  31. db_conn,
  32. "device_lists_stream",
  33. "stream_id",
  34. extra_tables=[
  35. ("user_signature_stream", "stream_id"),
  36. ("device_lists_outbound_pokes", "stream_id"),
  37. ("device_lists_changes_in_room", "stream_id"),
  38. ],
  39. )
  40. super().__init__(database, db_conn, hs)
  41. def get_device_stream_token(self) -> int:
  42. return self._device_list_id_gen.get_current_token()
  43. def process_replication_rows(
  44. self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
  45. ) -> None:
  46. if stream_name == DeviceListsStream.NAME:
  47. self._device_list_id_gen.advance(instance_name, token)
  48. self._invalidate_caches_for_devices(token, rows)
  49. elif stream_name == UserSignatureStream.NAME:
  50. self._device_list_id_gen.advance(instance_name, token)
  51. for row in rows:
  52. self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
  53. return super().process_replication_rows(stream_name, instance_name, token, rows)
  54. def _invalidate_caches_for_devices(
  55. self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
  56. ) -> None:
  57. for row in rows:
  58. # The entities are either user IDs (starting with '@') whose devices
  59. # have changed, or remote servers that we need to tell about
  60. # changes.
  61. if row.entity.startswith("@"):
  62. self._device_list_stream_cache.entity_has_changed(row.entity, token)
  63. self.get_cached_devices_for_user.invalidate((row.entity,))
  64. self._get_cached_user_device.invalidate((row.entity,))
  65. self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
  66. else:
  67. self._device_list_federation_stream_cache.entity_has_changed(
  68. row.entity, token
  69. )