Browse source

Add type hints for most `HomeServer` parameters (#11095)

tags/v1.46.0rc1
Sean Quah, 2 years ago
committed by GitHub
parent
commit 2b82ec425f
No known key found for this signature in database. GPG key ID: 4AEE18F83AFDEB23
58 changed files with 342 additions and 143 deletions
  1. +1 -0  changelog.d/11095.misc
  2. +4 -4  synapse/app/_base.py
  3. +2 -2  synapse/app/admin_cmd.py
  4. +2 -2  synapse/app/generic_worker.py
  5. +1 -1  synapse/app/homeserver.py
  6. +6 -2  synapse/app/phone_stats_home.py
  7. +2 -1  synapse/appservice/api.py
  8. +8 -1  synapse/config/logger.py
  9. +6 -1  synapse/federation/federation_base.py
  10. +5 -4  synapse/federation/federation_server.py
  11. +6 -2  synapse/http/matrixfederationclient.py
  12. +12 -7  synapse/http/server.py
  13. +7 -2  synapse/replication/http/__init__.py
  14. +5 -3  synapse/replication/http/_base.py
  15. +9 -5  synapse/replication/http/account_data.py
  16. +6 -2  synapse/replication/http/devices.py
  17. +10 -6  synapse/replication/http/federation.py
  18. +6 -2  synapse/replication/http/login.py
  19. +3 -3  synapse/replication/http/membership.py
  20. +1 -1  synapse/replication/http/presence.py
  21. +1 -1  synapse/replication/http/push.py
  22. +7 -3  synapse/replication/http/register.py
  23. +6 -2  synapse/replication/http/send_event.py
  24. +6 -2  synapse/replication/http/streams.py
  25. +5 -2  synapse/replication/slave/storage/_base.py
  26. +6 -1  synapse/replication/slave/storage/client_ips.py
  27. +6 -1  synapse/replication/slave/storage/devices.py
  28. +5 -1  synapse/replication/slave/storage/events.py
  29. +6 -1  synapse/replication/slave/storage/filtering.py
  30. +6 -1  synapse/replication/slave/storage/groups.py
  31. +8 -1  synapse/replication/tcp/external_cache.py
  32. +5 -1  synapse/replication/tcp/handler.py
  33. +6 -2  synapse/replication/tcp/resource.py
  34. +10 -10  synapse/replication/tcp/streams/_base.py
  35. +1 -1  synapse/rest/admin/devices.py
  36. +8 -3  synapse/server.py
  37. +5 -1  synapse/storage/database.py
  38. +22 -6  synapse/storage/databases/__init__.py
  39. +5 -2  synapse/storage/databases/main/__init__.py
  40. +5 -2  synapse/storage/databases/main/account_data.py
  41. +5 -2  synapse/storage/databases/main/cache.py
  42. +6 -3  synapse/storage/databases/main/deviceinbox.py
  43. +17 -4  synapse/storage/databases/main/devices.py
  44. +6 -3  synapse/storage/databases/main/event_federation.py
  45. +6 -3  synapse/storage/databases/main/event_push_actions.py
  46. +5 -2  synapse/storage/databases/main/events_bg_updates.py
  47. +6 -3  synapse/storage/databases/main/media_repository.py
  48. +5 -2  synapse/storage/databases/main/metrics.py
  49. +6 -3  synapse/storage/databases/main/monthly_active_users.py
  50. +5 -2  synapse/storage/databases/main/push_rule.py
  51. +5 -2  synapse/storage/databases/main/receipts.py
  52. +7 -4  synapse/storage/databases/main/room.py
  53. +4 -3  synapse/storage/databases/main/roommember.py
  54. +6 -3  synapse/storage/databases/main/search.py
  55. +7 -4  synapse/storage/databases/main/state.py
  56. +5 -2  synapse/storage/databases/main/stats.py
  57. +5 -2  synapse/storage/databases/main/transactions.py
  58. +5 -1  synapse/storage/persist_events.py

+1 -0  changelog.d/11095.misc

@@ -0,0 +1 @@
Add type hints to most `HomeServer` parameters.
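
Nearly every file below follows the same two-step pattern: import `HomeServer` only under `typing.TYPE_CHECKING`, so no import cycle is created at runtime, and annotate the parameter with the string form `"HomeServer"`, which type checkers resolve as a forward reference. A minimal sketch of that pattern, with a made-up handler class for illustration:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers; never imported at runtime,
    # which avoids a circular import between server.py and its users.
    from synapse.server import HomeServer


class ExampleHandler:
    def __init__(self, hs: "HomeServer") -> None:
        # The quoted annotation is a forward reference that mypy resolves
        # through the TYPE_CHECKING import above.
        self.clock = hs.get_clock()
        self.store = hs.get_datastore()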

+4 -4  synapse/app/_base.py

@@ -294,7 +294,7 @@ def listen_ssl(
return r


def refresh_certificate(hs):
def refresh_certificate(hs: "HomeServer"):
"""
Refresh the TLS certificates that Synapse is using by re-reading them from
disk and updating the TLS context factories to use them.
@@ -419,11 +419,11 @@ async def start(hs: "HomeServer"):
atexit.register(gc.freeze)


def setup_sentry(hs):
def setup_sentry(hs: "HomeServer"):
"""Enable sentry integration, if enabled in configuration

Args:
hs (synapse.server.HomeServer)
hs
"""

if not hs.config.metrics.sentry_enabled:
@@ -449,7 +449,7 @@ def setup_sentry(hs):
scope.set_tag("worker_name", name)


def setup_sdnotify(hs):
def setup_sdnotify(hs: "HomeServer"):
"""Adds process state hooks to tell systemd what we are up to."""

# Tell systemd our state, if we're using it. This will silently fail if


+2 -2  synapse/app/admin_cmd.py

@@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer):
DATASTORE_CLASS = AdminCmdSlavedStore


async def export_data_command(hs, args):
async def export_data_command(hs: HomeServer, args):
"""Export data for a user.

Args:
hs (HomeServer)
hs
args (argparse.Namespace)
"""



+2 -2  synapse/app/generic_worker.py

@@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet):

PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")

def __init__(self, hs):
def __init__(self, hs: HomeServer):
"""
Args:
hs (synapse.server.HomeServer): server
hs: server
"""
super().__init__()
self.auth = hs.get_auth()


+1 -1  synapse/app/homeserver.py

@@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]:
e = e.__cause__


def run(hs):
def run(hs: HomeServer):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:



+6 -2  synapse/app/phone_stats_home.py

@@ -15,11 +15,15 @@ import logging
import math
import resource
import sys
from typing import TYPE_CHECKING

from prometheus_client import Gauge

from synapse.metrics.background_process_metrics import wrap_as_background_process

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger("synapse.app.homeserver")

# Contains the list of processes we will be monitoring
@@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge(


@wrap_as_background_process("phone_stats_home")
async def phone_stats_home(hs, stats, stats_process=_stats_process):
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.warning("Error reporting stats: %s", e)


def start_phone_stats_home(hs):
def start_phone_stats_home(hs: "HomeServer"):
"""
Start the background tasks which report phone home stats.
"""


+2 -1  synapse/appservice/api.py

@@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache

if TYPE_CHECKING:
from synapse.appservice import ApplicationService
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

@@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient):
pushing.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()



+8 -1  synapse/config/logger.py

@@ -18,6 +18,7 @@ import os
import sys
import threading
from string import Template
from typing import TYPE_CHECKING

import yaml
from zope.interface import implementer
@@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string

from ._base import Config, ConfigError

if TYPE_CHECKING:
from synapse.server import HomeServer

DEFAULT_LOG_CONFIG = Template(
"""\
# Log configuration for Synapse.
@@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path):


def setup_logging(
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner
hs: "HomeServer",
config,
use_worker_options=False,
logBeginner: LogBeginner = globalLogBeginner,
) -> None:
"""
Set up the logging subsystem.


+6 -1  synapse/federation/federation_base.py

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from collections import namedtuple
from typing import TYPE_CHECKING

from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
from synapse.api.errors import Codes, SynapseError
@@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson
from synapse.http.servlet import assert_params_in_dict
from synapse.types import JsonDict, get_domain_from_id

if TYPE_CHECKING:
from synapse.server import HomeServer


logger = logging.getLogger(__name__)


class FederationBase:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs

self.server_name = hs.hostname


+5 -4  synapse/federation/federation_server.py

@@ -467,7 +467,7 @@ class FederationServer(FederationBase):

async def on_room_state_request(
self, origin: str, room_id: str, event_id: Optional[str]
) -> Tuple[int, Dict[str, Any]]:
) -> Tuple[int, JsonDict]:
origin_host, _ = parse_server_name(origin)
await self.check_server_matches_acl(origin_host, room_id)

@@ -481,7 +481,7 @@ class FederationServer(FederationBase):
# - but that's non-trivial to get right, and anyway somewhat defeats
# the point of the linearizer.
with (await self._server_linearizer.queue((origin, room_id))):
resp = dict(
resp: JsonDict = dict(
await self._state_resp_cache.wrap(
(room_id, event_id),
self._on_context_state_request_compute,
@@ -1061,11 +1061,12 @@ class FederationServer(FederationBase):

origin, event = next

lock = await self.store.try_acquire_lock(
new_lock = await self.store.try_acquire_lock(
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
)
if not lock:
if not new_lock:
return
lock = new_lock

def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name


+6 -2  synapse/http/matrixfederationclient.py

@@ -21,6 +21,7 @@ import typing
import urllib.parse
from io import BytesIO, StringIO
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Generic,
@@ -73,6 +74,9 @@ from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter(
@@ -319,7 +323,7 @@ class MatrixFederationHttpClient:
requests.
"""

def __init__(self, hs, tls_client_options_factory):
def __init__(self, hs: "HomeServer", tls_client_options_factory):
self.hs = hs
self.signing_key = hs.signing_key
self.server_name = hs.hostname
@@ -711,7 +715,7 @@ class MatrixFederationHttpClient:
Returns:
A list of headers to be added as "Authorization:" headers
"""
request = {
request: JsonDict = {
"method": method.decode("ascii"),
"uri": url_bytes.decode("ascii"),
"origin": self.server_name,


+12 -7  synapse/http/server.py

@@ -22,6 +22,7 @@ import urllib
from http import HTTPStatus
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -61,6 +62,9 @@ from synapse.util import json_encoder
from synapse.util.caches import intern_dict
from synapse.util.iterutils import chunk_seq

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
@@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource):
return_json_error(f, request)


_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)


class JsonResource(DirectServeJsonResource):
"""This implements the HttpServer interface and provides JSON support for
Resources.
@@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource):

isLeaf = True

_PathEntry = collections.namedtuple(
"_PathEntry", ["pattern", "callback", "servlet_classname"]
)

def __init__(self, hs, canonical_json=True, extract_context=False):
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False):
super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs

def register_paths(self, method, path_patterns, callback, servlet_classname):
@@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource):
for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append(
self._PathEntry(path_pattern, callback, servlet_classname)
_PathEntry(path_pattern, callback, servlet_classname)
)

def _get_handler_for_request(
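
Moving `_PathEntry` from the class body to module level is what makes the new `Dict[bytes, List[_PathEntry]]` annotation straightforward: a namedtuple defined as a class attribute is awkward to reference from type hints elsewhere. A small, simplified sketch of the registry shape (the class and method names here are placeholders, not the real JsonResource):

import collections
import re
from typing import Callable, Dict, List, Pattern

# Module level so the name can appear directly in annotations below.
_PathEntry = collections.namedtuple(
    "_PathEntry", ["pattern", "callback", "servlet_classname"]
)


class TinyJsonResource:
    def __init__(self) -> None:
        # HTTP method (e.g. b"GET") -> registered path entries.
        self.path_regexs: Dict[bytes, List[_PathEntry]] = {}

    def register_path(
        self,
        method: bytes,
        pattern: Pattern[str],
        callback: Callable[..., object],
        servlet_classname: str,
    ) -> None:
        self.path_regexs.setdefault(method, []).append(
            _PathEntry(pattern, callback, servlet_classname)
        )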


+7 -2  synapse/replication/http/__init__.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.http.server import JsonResource
from synapse.replication.http import (
account_data,
@@ -26,16 +28,19 @@ from synapse.replication.http import (
streams,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

REPLICATION_PREFIX = "/_synapse/replication"


class ReplicationRestResource(JsonResource):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
# We enable extracting jaeger contexts here as these are internal APIs.
super().__init__(hs, canonical_json=False, extract_context=True)
self.register_servlets(hs)

def register_servlets(self, hs):
def register_servlets(self, hs: "HomeServer"):
send_event.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)


+5 -3  synapse/replication/http/_base.py

@@ -17,7 +17,7 @@ import logging
import re
import urllib
from inspect import signature
from typing import TYPE_CHECKING, Dict, List, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple

from prometheus_client import Counter, Gauge

@@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
pass

@classmethod
def make_client(cls, hs):
def make_client(cls, hs: "HomeServer"):
"""Create a client that makes requests.

Returns a callable that accepts the same parameters as
@@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
url_args.append(txn_id)

if cls.METHOD == "POST":
request_func = client.post_json_get_json
request_func: Callable[
..., Awaitable[Any]
] = client.post_json_get_json
elif cls.METHOD == "PUT":
request_func = client.put_json
elif cls.METHOD == "GET":
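
The explicit annotation on `request_func` is a common mypy idiom for branchy assignments: without it, the variable's type is inferred from `client.post_json_get_json` alone, and the later `put_json` / GET assignments may be rejected as incompatible. A reduced sketch of the idiom (the coroutines are stand-ins):

from typing import Any, Awaitable, Callable


async def post_json_get_json(uri: str, post_json: Any) -> dict:
    return {}


async def put_json(uri: str, json_body: Any) -> dict:
    return {}


def pick_request_func(method: str) -> Callable[..., Awaitable[Any]]:
    if method == "POST":
        # Annotating the first branch widens the variable's type so the
        # other branches assign compatibly.
        request_func: Callable[..., Awaitable[Any]] = post_json_get_json
    elif method == "PUT":
        request_func = put_json
    else:
        raise ValueError("unsupported method: %s" % (method,))
    return request_func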


+9 -5  synapse/replication/http/account_data.py

@@ -13,10 +13,14 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "account_data_type")
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.handler = hs.get_account_data_handler()
@@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "room_id", "account_data_type")
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.handler = hs.get_account_data_handler()
@@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id", "room_id", "tag")
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.handler = hs.get_account_data_handler()
@@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
)
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.handler = hs.get_account_data_handler()
@@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint):
return 200, {"max_stream_id": max_stream_id}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationUserAccountDataRestServlet(hs).register(http_server)
ReplicationRoomAccountDataRestServlet(hs).register(http_server)
ReplicationAddTagRestServlet(hs).register(http_server)


+6 -2  synapse/replication/http/devices.py

@@ -13,9 +13,13 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.replication.http._base import ReplicationEndpoint

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
PATH_ARGS = ("user_id",)
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.device_list_updater = hs.get_device_handler().device_list_updater
@@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint):
return 200, user_devices


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationUserDevicesResyncRestServlet(hs).register(http_server)

+10 -6  synapse/replication/http/federation.py

@@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
NAME = "fed_send_events"
PATH_ARGS = ()

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
@@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
NAME = "fed_send_edu"
PATH_ARGS = ("edu_type",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
@@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
# This is a query, so let's not bother caching
CACHE = False

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
@@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
NAME = "fed_cleanup_room"
PATH_ARGS = ("room_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
@@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.store = hs.get_datastore()
@@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
return 200, {}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationFederationSendEventsRestServlet(hs).register(http_server)
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)


+6 -2  synapse/replication/http/login.py

@@ -13,10 +13,14 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
NAME = "device_check_registered"
PATH_ARGS = ("user_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.registration_handler = hs.get_registration_handler()

@@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
return 200, res


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
RegisterDeviceReplicationServlet(hs).register(http_server)

+3 -3  synapse/replication/http/membership.py

@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
NAME = "remote_join"
PATH_ARGS = ("room_id", "user_id")

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.federation_handler = hs.get_federation_handler()
@@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
PATH_ARGS = ("room_id", "user_id", "change")
CACHE = False # No point caching as should return instantly.

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.registeration_handler = hs.get_registration_handler()
@@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
return 200, {}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationRemoteJoinRestServlet(hs).register(http_server)
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server)
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server)

+1 -1  synapse/replication/http/presence.py

@@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint):
)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationBumpPresenceActiveTime(hs).register(http_server)
ReplicationPresenceSetState(hs).register(http_server)

+1 -1  synapse/replication/http/push.py

@@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationRemovePusherRestServlet(hs).register(http_server)

+7 -3  synapse/replication/http/register.py

@@ -13,10 +13,14 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
NAME = "register_user"
PATH_ARGS = ("user_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
@@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
NAME = "post_register"
PATH_ARGS = ("user_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.store = hs.get_datastore()
self.registration_handler = hs.get_registration_handler()
@@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
return 200, {}


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationRegisterServlet(hs).register(http_server)
ReplicationPostRegisterActionsServlet(hs).register(http_server)

+6 -2  synapse/replication/http/send_event.py

@@ -13,6 +13,7 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import Requester, UserID
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
NAME = "send_event"
PATH_ARGS = ("event_id",)

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.event_creation_handler = hs.get_event_creation_handler()
@@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationSendEventRestServlet(hs).register(http_server)

+6 -2  synapse/replication/http/streams.py

@@ -13,11 +13,15 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING

from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer
from synapse.replication.http._base import ReplicationEndpoint

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
PATH_ARGS = ("stream_name",)
METHOD = "GET"

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self._instance_name = hs.get_instance_name()
@@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint):
)


def register_servlets(hs, http_server):
def register_servlets(hs: "HomeServer", http_server):
ReplicationGetStreamUpdates(hs).register(http_server)

+5 -2  synapse/replication/slave/storage/_base.py

@@ -13,18 +13,21 @@
# limitations under the License.

import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class BaseSlavedStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen: Optional[


+6 -1  synapse/replication/slave/storage/client_ips.py

@@ -12,15 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.lrucache import LruCache

from ._base import BaseSlavedStore

if TYPE_CHECKING:
from synapse.server import HomeServer


class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.client_ip_last_seen: LruCache[tuple, int] = LruCache(


+6 -1  synapse/replication/slave/storage/devices.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
@@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer


class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.hs = hs


+5 -1  synapse/replication/slave/storage/events.py

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore
@@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache

from ._base import BaseSlavedStore

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -54,7 +58,7 @@ class SlavedEventStore(
RelationsWorkerStore,
BaseSlavedStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

events_max = self._stream_id_gen.get_current_token()


+6 -1  synapse/replication/slave/storage/filtering.py

@@ -12,14 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.filtering import FilteringStore

from ._base import BaseSlavedStore

if TYPE_CHECKING:
from synapse.server import HomeServer


class SlavedFilteringStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

# Filters are immutable so this cache doesn't need to be expired


+6 -1  synapse/replication/slave/storage/groups.py

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import GroupServerStream
@@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.group_server import GroupServerWorkerStore
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer


class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.hs = hs


+8 -1  synapse/replication/tcp/external_cache.py

@@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder

if TYPE_CHECKING:
from txredisapi import RedisProtocol

from synapse.server import HomeServer

set_counter = Counter(
@@ -59,7 +61,12 @@ class ExternalCache:
"""

def __init__(self, hs: "HomeServer"):
self._redis_connection = hs.get_outbound_redis_connection()
if hs.config.redis.redis_enabled:
self._redis_connection: Optional[
"RedisProtocol"
] = hs.get_outbound_redis_connection()
else:
self._redis_connection = None

def _get_redis_key(self, cache_name: str, key: str) -> str:
return "cache_v1:%s:%s" % (cache_name, key)


+5 -1  synapse/replication/tcp/handler.py

@@ -294,7 +294,7 @@ class ReplicationCommandHandler:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)

def start_replication(self, hs):
def start_replication(self, hs: "HomeServer"):
"""Helper method to start a replication connection to the remote server
using TCP.
"""
@@ -321,6 +321,8 @@ class ReplicationCommandHandler:
hs.config.redis.redis_host, # type: ignore[arg-type]
hs.config.redis.redis_port,
self._factory,
timeout=30,
bindAddress=None,
)
else:
client_name = hs.get_instance_name()
@@ -331,6 +333,8 @@ class ReplicationCommandHandler:
host, # type: ignore[arg-type]
port,
self._factory,
timeout=30,
bindAddress=None,
)

def get_streams(self) -> Dict[str, Stream]:


+6 -2  synapse/replication/tcp/resource.py

@@ -16,6 +16,7 @@

import logging
import random
from typing import TYPE_CHECKING

from prometheus_client import Counter

@@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
from synapse.replication.tcp.streams import EventsStream
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

stream_updates_counter = Counter(
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"]
)
@@ -37,7 +41,7 @@ logger = logging.getLogger(__name__)
class ReplicationStreamProtocolFactory(Factory):
"""Factory for new replication connections."""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.command_handler = hs.get_tcp_replication()
self.clock = hs.get_clock()
self.server_name = hs.config.server.server_name
@@ -65,7 +69,7 @@ class ReplicationStreamer:
data is available it will propagate to all connected clients.
"""

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()


+10 -10  synapse/replication/tcp/streams/_base.py

@@ -241,7 +241,7 @@ class BackfillStream(Stream):
NAME = "backfill"
ROW_TYPE = BackfillStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -363,7 +363,7 @@ class ReceiptsStream(Stream):
NAME = "receipts"
ROW_TYPE = ReceiptsStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -380,7 +380,7 @@ class PushRulesStream(Stream):
NAME = "push_rules"
ROW_TYPE = PushRulesStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()

super().__init__(
@@ -405,7 +405,7 @@ class PushersStream(Stream):
NAME = "pushers"
ROW_TYPE = PushersStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()

super().__init__(
@@ -438,7 +438,7 @@ class CachesStream(Stream):
NAME = "caches"
ROW_TYPE = CachesStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -459,7 +459,7 @@ class DeviceListsStream(Stream):
NAME = "device_lists"
ROW_TYPE = DeviceListsStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -476,7 +476,7 @@ class ToDeviceStream(Stream):
NAME = "to_device"
ROW_TYPE = ToDeviceStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -495,7 +495,7 @@ class TagAccountDataStream(Stream):
NAME = "tag_account_data"
ROW_TYPE = TagAccountDataStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -582,7 +582,7 @@ class GroupServerStream(Stream):
NAME = "groups"
ROW_TYPE = GroupsStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
@@ -599,7 +599,7 @@ class UserSignatureStream(Stream):
NAME = "user_signature"
ROW_TYPE = UserSignatureStreamRow

def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),


+1 -1  synapse/rest/admin/devices.py

@@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
hs: server
"""
self.hs = hs
self.auth = hs.get_auth()


+8 -3  synapse/server.py

@@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta):
return ExternalCache(self)

@cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
if not self.config.redis.redis_enabled:
return None
def get_outbound_redis_connection(self) -> "RedisProtocol":
"""
The Redis connection used for replication.

Raises:
AssertionError: if Redis is not enabled in the homeserver config.
"""
assert self.config.redis.redis_enabled

# We only want to import redis module if we're using it, as we have
# `txredisapi` as an optional dependency.
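
The return type tightens from `Optional["RedisProtocol"]` to `"RedisProtocol"`: rather than returning None when Redis is off, the getter now asserts that Redis is enabled, and callers such as `ExternalCache` above are expected to check `hs.config.redis.redis_enabled` before calling it. The contract, roughly, in a standalone sketch (names simplified, not the real signatures):

from typing import Any


def get_outbound_redis_connection(redis_enabled: bool) -> Any:
    # New contract: the caller must already know Redis is enabled.
    assert redis_enabled, "Redis support is not enabled in the config"
    return object()  # stand-in for the real RedisProtocol connection


def maybe_get_connection(redis_enabled: bool) -> Any:
    # Caller-side guard, mirroring what ExternalCache now does.
    if not redis_enabled:
        return None
    return get_outbound_redis_connection(redis_enabled)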


+5 -1  synapse/storage/database.py

@@ -19,6 +19,7 @@ from collections import defaultdict
from sys import intern
from time import monotonic as monotonic_time
from typing import (
TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor

if TYPE_CHECKING:
from synapse.server import HomeServer

# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1

@@ -392,7 +396,7 @@ class DatabasePool:

def __init__(
self,
hs,
hs: "HomeServer",
database_config: DatabaseConnectionConfig,
engine: BaseDatabaseEngine,
):


+22 -6  synapse/storage/databases/__init__.py

@@ -13,33 +13,49 @@
# limitations under the License.

import logging
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_conn
from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.databases.state import StateGroupDataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class Databases:
DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True)


class Databases(Generic[DataStoreT]):
"""The various databases.

These are low level interfaces to physical databases.

Attributes:
main (DataStore)
databases
main
state
persist_events
"""

def __init__(self, main_store_class, hs):
databases: List[DatabasePool]
main: DataStoreT
state: StateGroupDataStore
persist_events: Optional[PersistEventsStore]

def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"):
# Note we pass in the main store class here as workers use a different main
# store.

self.databases = []
main = None
state = None
persist_events = None
main: Optional[DataStoreT] = None
state: Optional[StateGroupDataStore] = None
persist_events: Optional[PersistEventsStore] = None

for database_config in hs.config.database.databases:
db_name = database_config.name
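
Making `Databases` generic over the main store class lets the type of the `main` attribute track whichever `main_store_class` a given worker passes in, with the TypeVar bounded by `SQLBaseStore`. A self-contained sketch of the same generic pattern (the store classes here are placeholders):

from typing import Generic, List, Type, TypeVar


class BaseStore:
    """Placeholder for SQLBaseStore."""


class MainStore(BaseStore):
    def get_users(self) -> List[str]:
        return []


DataStoreT = TypeVar("DataStoreT", bound=BaseStore)


class DatabasesSketch(Generic[DataStoreT]):
    main: DataStoreT

    def __init__(self, main_store_class: Type[DataStoreT]) -> None:
        # The concrete type of `main` follows the class the caller passes in.
        self.main = main_store_class()


dbs = DatabasesSketch(MainStore)
users = dbs.main.get_users()  # type checks: `main` is known to be a MainStore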


+5 -2  synapse/storage/databases/main/__init__.py

@@ -15,7 +15,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import DatabasePool
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore
from .user_directory import UserDirectoryStore
from .user_erasure_store import UserErasureStore

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -126,7 +129,7 @@ class DataStore(
LockStore,
SessionStore,
):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self.hs = hs
self._clock = hs.get_clock()
self.database_engine = database.engine


+5 -2  synapse/storage/databases/main/account_data.py

@@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple

from synapse.api.constants import AccountDataTypes
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -28,6 +28,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore):
`get_max_account_data_stream_id` which can be called in the initializer.
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()

if isinstance(database.engine, PostgresEngine):


+5 -2  synapse/storage/databases/main/cache.py

@@ -15,7 +15,7 @@

import itertools
import logging
from typing import Any, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple

from synapse.api.constants import EventTypes
from synapse.replication.tcp.streams import BackfillStream, CachesStream
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake"


class CacheInvalidationWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()


+6 -3  synapse/storage/databases/main/deviceinbox.py

@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import List, Optional, Tuple
from typing import TYPE_CHECKING, List, Optional, Tuple

from synapse.logging import issue9533_logger
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -26,11 +26,14 @@ from synapse.util import json_encoder
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class DeviceInboxWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self._instance_name = hs.get_instance_name()
@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(


+17 -4  synapse/storage/databases/main/devices.py

@@ -15,7 +15,17 @@
# limitations under the License.
import abc
import logging
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
)

from synapse.api.errors import Codes, StoreError
from synapse.logging.opentracing import (
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = (
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes"


class DeviceWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

if hs.config.worker.run_background_tasks:
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore):


class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):


class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

# Map of (user_id, device_id) -> bool. If there is an entry that implies


+6 -3  synapse/storage/databases/main/event_federation.py

@@ -14,7 +14,7 @@
import itertools
import logging
from queue import Empty, PriorityQueue
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple

from prometheus_client import Counter, Gauge

@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
from synapse.server import HomeServer

oldest_pdu_in_federation_staging = Gauge(
"synapse_federation_server_oldest_inbound_pdu_in_staging",
"The age in seconds since we received the oldest pdu in the federation staging area",
@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception):


class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

if hs.config.worker.run_background_tasks:
@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore):

EVENT_AUTH_STATE_ONLY = "event_auth_state_only"

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_update_handler(


+6 -3  synapse/storage/databases/main/event_push_actions.py

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import attr

@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight):


class EventPushActionsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

# These get correctly set by _find_stream_orderings_for_times_txn
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
class EventPushActionsStore(EventPushActionsWorkerStore):
EPA_HIGHLIGHT_INDEX = "epa_highlight_index"

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(


+5 -2  synapse/storage/databases/main/events_bg_updates.py

@@ -13,7 +13,7 @@
# limitations under the License.

import logging
from typing import Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

import attr

@@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore
from synapse.storage.types import Cursor
from synapse.types import JsonDict

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -76,7 +79,7 @@ class _CalculateChainCover:


class EventsBackgroundUpdatesStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_update_handler(


+6 -3  synapse/storage/databases/main/media_repository.py

@@ -13,11 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool

if TYPE_CHECKING:
from synapse.server import HomeServer

BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = (
"media_repository_drop_index_wo_method"
)
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum):


class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.db_pool.updates.register_background_index_update(
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore):
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.server_name = hs.hostname



+5 -2  synapse/storage/databases/main/metrics.py

@@ -14,7 +14,7 @@
import calendar
import logging
import time
from typing import Dict
from typing import TYPE_CHECKING, Dict

from synapse.metrics import GaugeBucketCollector
from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import (
EventPushActionsWorkerStore,
)

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Collect metrics on the number of forward extremities that exist.
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore):
stats and prometheus metrics.
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

# Read the extrems every 60 minutes


+6 -3  synapse/storage/databases/main/monthly_active_users.py

@@ -12,13 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional

from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Number of msec of granularity to store the monthly_active_user timestamp
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000


class MonthlyActiveUsersWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
self.hs = hs
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore):


class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self._mau_stats_only = hs.config.server.mau_stats_only


+5 -2  synapse/storage/databases/main/push_rule.py

@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, List, Tuple, Union

from synapse.api.errors import NotFoundError, StoreError
from synapse.push.baserules import list_with_base_rules
@@ -33,6 +33,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -75,7 +78,7 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer.
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

if hs.config.worker.worker_app is None:


+5 -2  synapse/storage/databases/main/receipts.py

@@ -14,7 +14,7 @@
# limitations under the License.

import logging
from typing import Any, Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple

from twisted.internet import defer

@@ -29,11 +29,14 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


class ReceiptsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
self._instance_name = hs.get_instance_name()

if isinstance(database.engine, PostgresEngine):


+7 -4  synapse/storage/databases/main/room.py

@@ -17,7 +17,7 @@ import collections
import logging
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from synapse.api.constants import EventContentFields, EventTypes, JoinRules
from synapse.api.errors import StoreError
@@ -32,6 +32,9 @@ from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import MXC_REGEX

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -69,7 +72,7 @@ class RoomSortOrder(Enum):


class RoomWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.config = hs.config
@@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (


class RoomBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.config = hs.config
@@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore):


class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.config = hs.config


+4 -3  synapse/storage/databases/main/roommember.py

@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.state import _StateCacheEntry

logger = logging.getLogger(__name__)
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"


class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

# Used by `_get_joined_hosts` to ensure only one thing mutates the cache
@@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):


class RoomMemberBackgroundUpdateStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
@@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):


class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

async def forget(self, user_id: str, room_id: str) -> None:


+6 -3  synapse/storage/databases/main/search.py

@@ -15,7 +15,7 @@
import logging
import re
from collections import namedtuple
from typing import Collection, Iterable, List, Optional, Set
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set

from synapse.api.errors import SynapseError
from synapse.events import EventBase
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

SearchEntry = namedtuple(
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist"
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

if not hs.config.server.enable_search:
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):


class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

async def search_msgs(self, room_ids, search_term, keys):


+7 -4  synapse/storage/databases/main/state.py

@@ -15,7 +15,7 @@
import collections.abc
import logging
from collections import namedtuple
from typing import Iterable, Optional, Set
from typing import TYPE_CHECKING, Iterable, Optional, Set

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -30,6 +30,9 @@ from synapse.types import StateMap
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedList

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@@ -53,7 +56,7 @@ class _GetStateGroupDelta(
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"""The parts of StateGroupStore that can be called from workers."""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

async def get_room_version(self, room_id: str) -> RoomVersion:
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events"

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore):
* `state_groups_state`: Maps state group to state events.
"""

def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

+5 -2  synapse/storage/databases/main/stats.py

@@ -16,7 +16,7 @@
import logging
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from typing_extensions import Counter

@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# these fields track absolutes (e.g. total number of rooms on the server)
@@ -93,7 +96,7 @@ class UserSortOrder(Enum):


class StatsStore(StateDeltasStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

self.server_name = hs.hostname


+5 -2  synapse/storage/databases/main/transactions.py

@@ -14,7 +14,7 @@

import logging
from collections import namedtuple
from typing import Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple

import attr
from canonicaljson import encode_canonical_json
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached

if TYPE_CHECKING:
from synapse.server import HomeServer

db_binary_type = memoryview

logger = logging.getLogger(__name__)
@@ -57,7 +60,7 @@ class DestinationRetryTimings:


class TransactionWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs)

if hs.config.worker.run_background_tasks:


+5 -1  synapse/storage/persist_events.py

@@ -18,6 +18,7 @@ import itertools
import logging
from collections import deque
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -56,6 +57,9 @@ from synapse.types import (
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
from synapse.util.metrics import Measure

if TYPE_CHECKING:
from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# The number of times we are recalculating the current state
@@ -272,7 +276,7 @@ class EventsPersistenceStorage:
current state and forward extremity changes.
"""

def __init__(self, hs, stores: Databases):
def __init__(self, hs: "HomeServer", stores: Databases):
# We ultimately want to split out the state store from the main store,
# so we use separate variables here even though they point to the same
# store for now.

