@@ -0,0 +1 @@ | |||
Add type hints to most `HomeServer` parameters. |
@@ -294,7 +294,7 @@ def listen_ssl( | |||
return r | |||
def refresh_certificate(hs): | |||
def refresh_certificate(hs: "HomeServer"): | |||
""" | |||
Refresh the TLS certificates that Synapse is using by re-reading them from | |||
disk and updating the TLS context factories to use them. | |||
@@ -419,11 +419,11 @@ async def start(hs: "HomeServer"): | |||
atexit.register(gc.freeze) | |||
def setup_sentry(hs): | |||
def setup_sentry(hs: "HomeServer"): | |||
"""Enable sentry integration, if enabled in configuration | |||
Args: | |||
hs (synapse.server.HomeServer) | |||
hs | |||
""" | |||
if not hs.config.metrics.sentry_enabled: | |||
@@ -449,7 +449,7 @@ def setup_sentry(hs): | |||
scope.set_tag("worker_name", name) | |||
def setup_sdnotify(hs): | |||
def setup_sdnotify(hs: "HomeServer"): | |||
"""Adds process state hooks to tell systemd what we are up to.""" | |||
# Tell systemd our state, if we're using it. This will silently fail if | |||
@@ -68,11 +68,11 @@ class AdminCmdServer(HomeServer): | |||
DATASTORE_CLASS = AdminCmdSlavedStore | |||
async def export_data_command(hs, args): | |||
async def export_data_command(hs: HomeServer, args): | |||
"""Export data for a user. | |||
Args: | |||
hs (HomeServer) | |||
hs | |||
args (argparse.Namespace) | |||
""" | |||
@@ -131,10 +131,10 @@ class KeyUploadServlet(RestServlet): | |||
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") | |||
def __init__(self, hs): | |||
def __init__(self, hs: HomeServer): | |||
""" | |||
Args: | |||
hs (synapse.server.HomeServer): server | |||
hs: server | |||
""" | |||
super().__init__() | |||
self.auth = hs.get_auth() | |||
@@ -412,7 +412,7 @@ def format_config_error(e: ConfigError) -> Iterator[str]: | |||
e = e.__cause__ | |||
def run(hs): | |||
def run(hs: HomeServer): | |||
PROFILE_SYNAPSE = False | |||
if PROFILE_SYNAPSE: | |||
@@ -15,11 +15,15 @@ import logging | |||
import math | |||
import resource | |||
import sys | |||
from typing import TYPE_CHECKING | |||
from prometheus_client import Gauge | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger("synapse.app.homeserver") | |||
# Contains the list of processes we will be monitoring | |||
@@ -41,7 +45,7 @@ registered_reserved_users_mau_gauge = Gauge( | |||
@wrap_as_background_process("phone_stats_home") | |||
async def phone_stats_home(hs, stats, stats_process=_stats_process): | |||
async def phone_stats_home(hs: "HomeServer", stats, stats_process=_stats_process): | |||
logger.info("Gathering stats for reporting") | |||
now = int(hs.get_clock().time()) | |||
uptime = int(now - hs.start_time) | |||
@@ -142,7 +146,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): | |||
logger.warning("Error reporting stats: %s", e) | |||
def start_phone_stats_home(hs): | |||
def start_phone_stats_home(hs: "HomeServer"): | |||
""" | |||
Start the background tasks which report phone home stats. | |||
""" | |||
@@ -27,6 +27,7 @@ from synapse.util.caches.response_cache import ResponseCache | |||
if TYPE_CHECKING: | |||
from synapse.appservice import ApplicationService | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -84,7 +85,7 @@ class ApplicationServiceApi(SimpleHttpClient): | |||
pushing. | |||
""" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.clock = hs.get_clock() | |||
@@ -18,6 +18,7 @@ import os | |||
import sys | |||
import threading | |||
from string import Template | |||
from typing import TYPE_CHECKING | |||
import yaml | |||
from zope.interface import implementer | |||
@@ -38,6 +39,9 @@ from synapse.util.versionstring import get_version_string | |||
from ._base import Config, ConfigError | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
DEFAULT_LOG_CONFIG = Template( | |||
"""\ | |||
# Log configuration for Synapse. | |||
@@ -306,7 +310,10 @@ def _reload_logging_config(log_config_path): | |||
def setup_logging( | |||
hs, config, use_worker_options=False, logBeginner: LogBeginner = globalLogBeginner | |||
hs: "HomeServer", | |||
config, | |||
use_worker_options=False, | |||
logBeginner: LogBeginner = globalLogBeginner, | |||
) -> None: | |||
""" | |||
Set up the logging subsystem. | |||
@@ -14,6 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from collections import namedtuple | |||
from typing import TYPE_CHECKING | |||
from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership | |||
from synapse.api.errors import Codes, SynapseError | |||
@@ -25,11 +26,15 @@ from synapse.events.utils import prune_event, validate_canonicaljson | |||
from synapse.http.servlet import assert_params_in_dict | |||
from synapse.types import JsonDict, get_domain_from_id | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class FederationBase: | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.hs = hs | |||
self.server_name = hs.hostname | |||
@@ -467,7 +467,7 @@ class FederationServer(FederationBase): | |||
async def on_room_state_request( | |||
self, origin: str, room_id: str, event_id: Optional[str] | |||
) -> Tuple[int, Dict[str, Any]]: | |||
) -> Tuple[int, JsonDict]: | |||
origin_host, _ = parse_server_name(origin) | |||
await self.check_server_matches_acl(origin_host, room_id) | |||
@@ -481,7 +481,7 @@ class FederationServer(FederationBase): | |||
# - but that's non-trivial to get right, and anyway somewhat defeats | |||
# the point of the linearizer. | |||
with (await self._server_linearizer.queue((origin, room_id))): | |||
resp = dict( | |||
resp: JsonDict = dict( | |||
await self._state_resp_cache.wrap( | |||
(room_id, event_id), | |||
self._on_context_state_request_compute, | |||
@@ -1061,11 +1061,12 @@ class FederationServer(FederationBase): | |||
origin, event = next | |||
lock = await self.store.try_acquire_lock( | |||
new_lock = await self.store.try_acquire_lock( | |||
_INBOUND_EVENT_HANDLING_LOCK_NAME, room_id | |||
) | |||
if not lock: | |||
if not new_lock: | |||
return | |||
lock = new_lock | |||
def __str__(self) -> str: | |||
return "<ReplicationLayer(%s)>" % self.server_name | |||
@@ -21,6 +21,7 @@ import typing | |||
import urllib.parse | |||
from io import BytesIO, StringIO | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Callable, | |||
Dict, | |||
Generic, | |||
@@ -73,6 +74,9 @@ from synapse.util import json_decoder | |||
from synapse.util.async_helpers import timeout_deferred | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
outgoing_requests_counter = Counter( | |||
@@ -319,7 +323,7 @@ class MatrixFederationHttpClient: | |||
requests. | |||
""" | |||
def __init__(self, hs, tls_client_options_factory): | |||
def __init__(self, hs: "HomeServer", tls_client_options_factory): | |||
self.hs = hs | |||
self.signing_key = hs.signing_key | |||
self.server_name = hs.hostname | |||
@@ -711,7 +715,7 @@ class MatrixFederationHttpClient: | |||
Returns: | |||
A list of headers to be added as "Authorization:" headers | |||
""" | |||
request = { | |||
request: JsonDict = { | |||
"method": method.decode("ascii"), | |||
"uri": url_bytes.decode("ascii"), | |||
"origin": self.server_name, | |||
@@ -22,6 +22,7 @@ import urllib | |||
from http import HTTPStatus | |||
from inspect import isawaitable | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Awaitable, | |||
Callable, | |||
@@ -61,6 +62,9 @@ from synapse.util import json_encoder | |||
from synapse.util.caches import intern_dict | |||
from synapse.util.iterutils import chunk_seq | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
HTML_ERROR_TEMPLATE = """<!DOCTYPE html> | |||
@@ -343,6 +347,11 @@ class DirectServeJsonResource(_AsyncResource): | |||
return_json_error(f, request) | |||
_PathEntry = collections.namedtuple( | |||
"_PathEntry", ["pattern", "callback", "servlet_classname"] | |||
) | |||
class JsonResource(DirectServeJsonResource): | |||
"""This implements the HttpServer interface and provides JSON support for | |||
Resources. | |||
@@ -359,14 +368,10 @@ class JsonResource(DirectServeJsonResource): | |||
isLeaf = True | |||
_PathEntry = collections.namedtuple( | |||
"_PathEntry", ["pattern", "callback", "servlet_classname"] | |||
) | |||
def __init__(self, hs, canonical_json=True, extract_context=False): | |||
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): | |||
super().__init__(canonical_json, extract_context) | |||
self.clock = hs.get_clock() | |||
self.path_regexs = {} | |||
self.path_regexs: Dict[bytes, List[_PathEntry]] = {} | |||
self.hs = hs | |||
def register_paths(self, method, path_patterns, callback, servlet_classname): | |||
@@ -391,7 +396,7 @@ class JsonResource(DirectServeJsonResource): | |||
for path_pattern in path_patterns: | |||
logger.debug("Registering for %s %s", method, path_pattern.pattern) | |||
self.path_regexs.setdefault(method, []).append( | |||
self._PathEntry(path_pattern, callback, servlet_classname) | |||
_PathEntry(path_pattern, callback, servlet_classname) | |||
) | |||
def _get_handler_for_request( | |||
@@ -12,6 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.http.server import JsonResource | |||
from synapse.replication.http import ( | |||
account_data, | |||
@@ -26,16 +28,19 @@ from synapse.replication.http import ( | |||
streams, | |||
) | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
REPLICATION_PREFIX = "/_synapse/replication" | |||
class ReplicationRestResource(JsonResource): | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
# We enable extracting jaeger contexts here as these are internal APIs. | |||
super().__init__(hs, canonical_json=False, extract_context=True) | |||
self.register_servlets(hs) | |||
def register_servlets(self, hs): | |||
def register_servlets(self, hs: "HomeServer"): | |||
send_event.register_servlets(hs, self) | |||
federation.register_servlets(hs, self) | |||
presence.register_servlets(hs, self) | |||
@@ -17,7 +17,7 @@ import logging | |||
import re | |||
import urllib | |||
from inspect import signature | |||
from typing import TYPE_CHECKING, Dict, List, Tuple | |||
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Dict, List, Tuple | |||
from prometheus_client import Counter, Gauge | |||
@@ -156,7 +156,7 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): | |||
pass | |||
@classmethod | |||
def make_client(cls, hs): | |||
def make_client(cls, hs: "HomeServer"): | |||
"""Create a client that makes requests. | |||
Returns a callable that accepts the same parameters as | |||
@@ -208,7 +208,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): | |||
url_args.append(txn_id) | |||
if cls.METHOD == "POST": | |||
request_func = client.post_json_get_json | |||
request_func: Callable[ | |||
..., Awaitable[Any] | |||
] = client.post_json_get_json | |||
elif cls.METHOD == "PUT": | |||
request_func = client.put_json | |||
elif cls.METHOD == "GET": | |||
@@ -13,10 +13,14 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.http.servlet import parse_json_object_from_request | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -37,7 +41,7 @@ class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): | |||
PATH_ARGS = ("user_id", "account_data_type") | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@@ -78,7 +82,7 @@ class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): | |||
PATH_ARGS = ("user_id", "room_id", "account_data_type") | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@@ -119,7 +123,7 @@ class ReplicationAddTagRestServlet(ReplicationEndpoint): | |||
PATH_ARGS = ("user_id", "room_id", "tag") | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@@ -162,7 +166,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | |||
) | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.handler = hs.get_account_data_handler() | |||
@@ -183,7 +187,7 @@ class ReplicationRemoveTagRestServlet(ReplicationEndpoint): | |||
return 200, {"max_stream_id": max_stream_id} | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationUserAccountDataRestServlet(hs).register(http_server) | |||
ReplicationRoomAccountDataRestServlet(hs).register(http_server) | |||
ReplicationAddTagRestServlet(hs).register(http_server) | |||
@@ -13,9 +13,13 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -51,7 +55,7 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): | |||
PATH_ARGS = ("user_id",) | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.device_list_updater = hs.get_device_handler().device_list_updater | |||
@@ -68,5 +72,5 @@ class ReplicationUserDevicesResyncRestServlet(ReplicationEndpoint): | |||
return 200, user_devices | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationUserDevicesResyncRestServlet(hs).register(http_server) |
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | |||
from synapse.events import make_event_from_dict | |||
@@ -21,6 +22,9 @@ from synapse.http.servlet import parse_json_object_from_request | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -56,7 +60,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint): | |||
NAME = "fed_send_events" | |||
PATH_ARGS = () | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
@@ -151,7 +155,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint): | |||
NAME = "fed_send_edu" | |||
PATH_ARGS = ("edu_type",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
@@ -194,7 +198,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint): | |||
# This is a query, so let's not bother caching | |||
CACHE = False | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
@@ -238,7 +242,7 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint): | |||
NAME = "fed_cleanup_room" | |||
PATH_ARGS = ("room_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
@@ -273,7 +277,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): | |||
NAME = "store_room_on_outlier_membership" | |||
PATH_ARGS = ("room_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
@@ -289,7 +293,7 @@ class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint): | |||
return 200, {} | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationFederationSendEventsRestServlet(hs).register(http_server) | |||
ReplicationFederationSendEduRestServlet(hs).register(http_server) | |||
ReplicationGetQueryRestServlet(hs).register(http_server) | |||
@@ -13,10 +13,14 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.http.servlet import parse_json_object_from_request | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -30,7 +34,7 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): | |||
NAME = "device_check_registered" | |||
PATH_ARGS = ("user_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.registration_handler = hs.get_registration_handler() | |||
@@ -82,5 +86,5 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint): | |||
return 200, res | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
RegisterDeviceReplicationServlet(hs).register(http_server) |
@@ -45,7 +45,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint): | |||
NAME = "remote_join" | |||
PATH_ARGS = ("room_id", "user_id") | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.federation_handler = hs.get_federation_handler() | |||
@@ -320,7 +320,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): | |||
PATH_ARGS = ("room_id", "user_id", "change") | |||
CACHE = False # No point caching as should return instantly. | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.registeration_handler = hs.get_registration_handler() | |||
@@ -360,7 +360,7 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint): | |||
return 200, {} | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationRemoteJoinRestServlet(hs).register(http_server) | |||
ReplicationRemoteRejectInviteRestServlet(hs).register(http_server) | |||
ReplicationUserJoinedLeftRoomRestServlet(hs).register(http_server) |
@@ -117,6 +117,6 @@ class ReplicationPresenceSetState(ReplicationEndpoint): | |||
) | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationBumpPresenceActiveTime(hs).register(http_server) | |||
ReplicationPresenceSetState(hs).register(http_server) |
@@ -67,5 +67,5 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint): | |||
return 200, {} | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationRemovePusherRestServlet(hs).register(http_server) |
@@ -13,10 +13,14 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.http.servlet import parse_json_object_from_request | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -26,7 +30,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint): | |||
NAME = "register_user" | |||
PATH_ARGS = ("user_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
self.registration_handler = hs.get_registration_handler() | |||
@@ -100,7 +104,7 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): | |||
NAME = "post_register" | |||
PATH_ARGS = ("user_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.store = hs.get_datastore() | |||
self.registration_handler = hs.get_registration_handler() | |||
@@ -130,6 +134,6 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint): | |||
return 200, {} | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationRegisterServlet(hs).register(http_server) | |||
ReplicationPostRegisterActionsServlet(hs).register(http_server) |
@@ -13,6 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS | |||
from synapse.events import make_event_from_dict | |||
@@ -22,6 +23,9 @@ from synapse.replication.http._base import ReplicationEndpoint | |||
from synapse.types import Requester, UserID | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -57,7 +61,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): | |||
NAME = "send_event" | |||
PATH_ARGS = ("event_id",) | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self.event_creation_handler = hs.get_event_creation_handler() | |||
@@ -135,5 +139,5 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint): | |||
) | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationSendEventRestServlet(hs).register(http_server) |
@@ -13,11 +13,15 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.api.errors import SynapseError | |||
from synapse.http.servlet import parse_integer | |||
from synapse.replication.http._base import ReplicationEndpoint | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -46,7 +50,7 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): | |||
PATH_ARGS = ("stream_name",) | |||
METHOD = "GET" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
super().__init__(hs) | |||
self._instance_name = hs.get_instance_name() | |||
@@ -74,5 +78,5 @@ class ReplicationGetStreamUpdates(ReplicationEndpoint): | |||
) | |||
def register_servlets(hs, http_server): | |||
def register_servlets(hs: "HomeServer", http_server): | |||
ReplicationGetStreamUpdates(hs).register(http_server) |
@@ -13,18 +13,21 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Optional | |||
from typing import TYPE_CHECKING, Optional | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.storage.util.id_generators import MultiWriterIdGenerator | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class BaseSlavedStore(CacheInvalidationWorkerStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if isinstance(self.database_engine, PostgresEngine): | |||
self._cache_id_gen: Optional[ | |||
@@ -12,15 +12,20 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY | |||
from synapse.util.caches.lrucache import LruCache | |||
from ._base import BaseSlavedStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
class SlavedClientIpStore(BaseSlavedStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.client_ip_last_seen: LruCache[tuple, int] = LruCache( | |||
@@ -12,6 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream | |||
@@ -20,9 +22,12 @@ from synapse.storage.databases.main.devices import DeviceWorkerStore | |||
from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.hs = hs | |||
@@ -13,6 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.event_federation import EventFederationWorkerStore | |||
@@ -30,6 +31,9 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
from ._base import BaseSlavedStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -54,7 +58,7 @@ class SlavedEventStore( | |||
RelationsWorkerStore, | |||
BaseSlavedStore, | |||
): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
events_max = self._stream_id_gen.get_current_token() | |||
@@ -12,14 +12,19 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.filtering import FilteringStore | |||
from ._base import BaseSlavedStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
class SlavedFilteringStore(BaseSlavedStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
# Filters are immutable so this cache doesn't need to be expired | |||
@@ -12,6 +12,8 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from typing import TYPE_CHECKING | |||
from synapse.replication.slave.storage._base import BaseSlavedStore | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
from synapse.replication.tcp.streams import GroupServerStream | |||
@@ -19,9 +21,12 @@ from synapse.storage.database import DatabasePool | |||
from synapse.storage.databases.main.group_server import GroupServerWorkerStore | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.hs = hs | |||
@@ -21,6 +21,8 @@ from synapse.logging.context import make_deferred_yieldable | |||
from synapse.util import json_decoder, json_encoder | |||
if TYPE_CHECKING: | |||
from txredisapi import RedisProtocol | |||
from synapse.server import HomeServer | |||
set_counter = Counter( | |||
@@ -59,7 +61,12 @@ class ExternalCache: | |||
""" | |||
def __init__(self, hs: "HomeServer"): | |||
self._redis_connection = hs.get_outbound_redis_connection() | |||
if hs.config.redis.redis_enabled: | |||
self._redis_connection: Optional[ | |||
"RedisProtocol" | |||
] = hs.get_outbound_redis_connection() | |||
else: | |||
self._redis_connection = None | |||
def _get_redis_key(self, cache_name: str, key: str) -> str: | |||
return "cache_v1:%s:%s" % (cache_name, key) | |||
@@ -294,7 +294,7 @@ class ReplicationCommandHandler: | |||
# This shouldn't be possible | |||
raise Exception("Unrecognised command %s in stream queue", cmd.NAME) | |||
def start_replication(self, hs): | |||
def start_replication(self, hs: "HomeServer"): | |||
"""Helper method to start a replication connection to the remote server | |||
using TCP. | |||
""" | |||
@@ -321,6 +321,8 @@ class ReplicationCommandHandler: | |||
hs.config.redis.redis_host, # type: ignore[arg-type] | |||
hs.config.redis.redis_port, | |||
self._factory, | |||
timeout=30, | |||
bindAddress=None, | |||
) | |||
else: | |||
client_name = hs.get_instance_name() | |||
@@ -331,6 +333,8 @@ class ReplicationCommandHandler: | |||
host, # type: ignore[arg-type] | |||
port, | |||
self._factory, | |||
timeout=30, | |||
bindAddress=None, | |||
) | |||
def get_streams(self) -> Dict[str, Stream]: | |||
@@ -16,6 +16,7 @@ | |||
import logging | |||
import random | |||
from typing import TYPE_CHECKING | |||
from prometheus_client import Counter | |||
@@ -27,6 +28,9 @@ from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol | |||
from synapse.replication.tcp.streams import EventsStream | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
stream_updates_counter = Counter( | |||
"synapse_replication_tcp_resource_stream_updates", "", ["stream_name"] | |||
) | |||
@@ -37,7 +41,7 @@ logger = logging.getLogger(__name__) | |||
class ReplicationStreamProtocolFactory(Factory): | |||
"""Factory for new replication connections.""" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.command_handler = hs.get_tcp_replication() | |||
self.clock = hs.get_clock() | |||
self.server_name = hs.config.server.server_name | |||
@@ -65,7 +69,7 @@ class ReplicationStreamer: | |||
data is available it will propagate to all connected clients. | |||
""" | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastore() | |||
self.clock = hs.get_clock() | |||
self.notifier = hs.get_notifier() | |||
@@ -241,7 +241,7 @@ class BackfillStream(Stream): | |||
NAME = "backfill" | |||
ROW_TYPE = BackfillStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -363,7 +363,7 @@ class ReceiptsStream(Stream): | |||
NAME = "receipts" | |||
ROW_TYPE = ReceiptsStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -380,7 +380,7 @@ class PushRulesStream(Stream): | |||
NAME = "push_rules" | |||
ROW_TYPE = PushRulesStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
self.store = hs.get_datastore() | |||
super().__init__( | |||
@@ -405,7 +405,7 @@ class PushersStream(Stream): | |||
NAME = "pushers" | |||
ROW_TYPE = PushersStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
@@ -438,7 +438,7 @@ class CachesStream(Stream): | |||
NAME = "caches" | |||
ROW_TYPE = CachesStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -459,7 +459,7 @@ class DeviceListsStream(Stream): | |||
NAME = "device_lists" | |||
ROW_TYPE = DeviceListsStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -476,7 +476,7 @@ class ToDeviceStream(Stream): | |||
NAME = "to_device" | |||
ROW_TYPE = ToDeviceStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -495,7 +495,7 @@ class TagAccountDataStream(Stream): | |||
NAME = "tag_account_data" | |||
ROW_TYPE = TagAccountDataStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -582,7 +582,7 @@ class GroupServerStream(Stream): | |||
NAME = "groups" | |||
ROW_TYPE = GroupsStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -599,7 +599,7 @@ class UserSignatureStream(Stream): | |||
NAME = "user_signature" | |||
ROW_TYPE = UserSignatureStreamRow | |||
def __init__(self, hs): | |||
def __init__(self, hs: "HomeServer"): | |||
store = hs.get_datastore() | |||
super().__init__( | |||
hs.get_instance_name(), | |||
@@ -110,7 +110,7 @@ class DevicesRestServlet(RestServlet): | |||
def __init__(self, hs: "HomeServer"): | |||
""" | |||
Args: | |||
hs (synapse.server.HomeServer): server | |||
hs: server | |||
""" | |||
self.hs = hs | |||
self.auth = hs.get_auth() | |||
@@ -800,9 +800,14 @@ class HomeServer(metaclass=abc.ABCMeta): | |||
return ExternalCache(self) | |||
@cache_in_self | |||
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: | |||
if not self.config.redis.redis_enabled: | |||
return None | |||
def get_outbound_redis_connection(self) -> "RedisProtocol": | |||
""" | |||
The Redis connection used for replication. | |||
Raises: | |||
AssertionError: if Redis is not enabled in the homeserver config. | |||
""" | |||
assert self.config.redis.redis_enabled | |||
# We only want to import redis module if we're using it, as we have | |||
# `txredisapi` as an optional dependency. | |||
@@ -19,6 +19,7 @@ from collections import defaultdict | |||
from sys import intern | |||
from time import monotonic as monotonic_time | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Callable, | |||
Collection, | |||
@@ -52,6 +53,9 @@ from synapse.storage.background_updates import BackgroundUpdater | |||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine | |||
from synapse.storage.types import Connection, Cursor | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
# python 3 does not have a maximum int value | |||
MAX_TXN_ID = 2 ** 63 - 1 | |||
@@ -392,7 +396,7 @@ class DatabasePool: | |||
def __init__( | |||
self, | |||
hs, | |||
hs: "HomeServer", | |||
database_config: DatabaseConnectionConfig, | |||
engine: BaseDatabaseEngine, | |||
): | |||
@@ -13,33 +13,49 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import TYPE_CHECKING, Generic, List, Optional, Type, TypeVar | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import DatabasePool, make_conn | |||
from synapse.storage.databases.main.events import PersistEventsStore | |||
from synapse.storage.databases.state import StateGroupDataStore | |||
from synapse.storage.engines import create_engine | |||
from synapse.storage.prepare_database import prepare_database | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class Databases: | |||
DataStoreT = TypeVar("DataStoreT", bound=SQLBaseStore, covariant=True) | |||
class Databases(Generic[DataStoreT]): | |||
"""The various databases. | |||
These are low level interfaces to physical databases. | |||
Attributes: | |||
main (DataStore) | |||
databases | |||
main | |||
state | |||
persist_events | |||
""" | |||
def __init__(self, main_store_class, hs): | |||
databases: List[DatabasePool] | |||
main: DataStoreT | |||
state: StateGroupDataStore | |||
persist_events: Optional[PersistEventsStore] | |||
def __init__(self, main_store_class: Type[DataStoreT], hs: "HomeServer"): | |||
# Note we pass in the main store class here as workers use a different main | |||
# store. | |||
self.databases = [] | |||
main = None | |||
state = None | |||
persist_events = None | |||
main: Optional[DataStoreT] = None | |||
state: Optional[StateGroupDataStore] = None | |||
persist_events: Optional[PersistEventsStore] = None | |||
for database_config in hs.config.database.databases: | |||
db_name = database_config.name | |||
@@ -15,7 +15,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import List, Optional, Tuple | |||
from typing import TYPE_CHECKING, List, Optional, Tuple | |||
from synapse.config.homeserver import HomeServerConfig | |||
from synapse.storage.database import DatabasePool | |||
@@ -75,6 +75,9 @@ from .ui_auth import UIAuthStore | |||
from .user_directory import UserDirectoryStore | |||
from .user_erasure_store import UserErasureStore | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -126,7 +129,7 @@ class DataStore( | |||
LockStore, | |||
SessionStore, | |||
): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
self.hs = hs | |||
self._clock = hs.get_clock() | |||
self.database_engine = database.engine | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, List, Optional, Set, Tuple | |||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple | |||
from synapse.api.constants import AccountDataTypes | |||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker | |||
@@ -28,6 +28,9 @@ from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -36,7 +39,7 @@ class AccountDataWorkerStore(SQLBaseStore): | |||
`get_max_account_data_stream_id` which can be called in the initializer. | |||
""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
self._instance_name = hs.get_instance_name() | |||
if isinstance(database.engine, PostgresEngine): | |||
@@ -15,7 +15,7 @@ | |||
import itertools | |||
import logging | |||
from typing import Any, Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple | |||
from synapse.api.constants import EventTypes | |||
from synapse.replication.tcp.streams import BackfillStream, CachesStream | |||
@@ -29,6 +29,9 @@ from synapse.storage.database import DatabasePool | |||
from synapse.storage.engines import PostgresEngine | |||
from synapse.util.iterutils import batch_iter | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -38,7 +41,7 @@ CURRENT_STATE_CACHE_NAME = "cs_cache_fake" | |||
class CacheInvalidationWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self._instance_name = hs.get_instance_name() | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import List, Optional, Tuple | |||
from typing import TYPE_CHECKING, List, Optional, Tuple | |||
from synapse.logging import issue9533_logger | |||
from synapse.logging.opentracing import log_kv, set_tag, trace | |||
@@ -26,11 +26,14 @@ from synapse.util import json_encoder | |||
from synapse.util.caches.expiringcache import ExpiringCache | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class DeviceInboxWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self._instance_name = hs.get_instance_name() | |||
@@ -553,7 +556,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): | |||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore): | |||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_index_update( | |||
@@ -15,7 +15,17 @@ | |||
# limitations under the License. | |||
import abc | |||
import logging | |||
from typing import Any, Collection, Dict, Iterable, List, Optional, Set, Tuple | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Collection, | |||
Dict, | |||
Iterable, | |||
List, | |||
Optional, | |||
Set, | |||
Tuple, | |||
) | |||
from synapse.api.errors import Codes, StoreError | |||
from synapse.logging.opentracing import ( | |||
@@ -38,6 +48,9 @@ from synapse.util.caches.lrucache import LruCache | |||
from synapse.util.iterutils import batch_iter | |||
from synapse.util.stringutils import shortstr | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES = ( | |||
@@ -48,7 +61,7 @@ BG_UPDATE_REMOVE_DUP_OUTBOUND_POKES = "remove_dup_outbound_pokes" | |||
class DeviceWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if hs.config.worker.run_background_tasks: | |||
@@ -915,7 +928,7 @@ class DeviceWorkerStore(SQLBaseStore): | |||
class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_index_update( | |||
@@ -1047,7 +1060,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore): | |||
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
# Map of (user_id, device_id) -> bool. If there is an entry that implies | |||
@@ -14,7 +14,7 @@ | |||
import itertools | |||
import logging | |||
from queue import Empty, PriorityQueue | |||
from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple | |||
from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple | |||
from prometheus_client import Counter, Gauge | |||
@@ -34,6 +34,9 @@ from synapse.util.caches.descriptors import cached | |||
from synapse.util.caches.lrucache import LruCache | |||
from synapse.util.iterutils import batch_iter | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
oldest_pdu_in_federation_staging = Gauge( | |||
"synapse_federation_server_oldest_inbound_pdu_in_staging", | |||
"The age in seconds since we received the oldest pdu in the federation staging area", | |||
@@ -59,7 +62,7 @@ class _NoChainCoverIndex(Exception): | |||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if hs.config.worker.run_background_tasks: | |||
@@ -1511,7 +1514,7 @@ class EventFederationStore(EventFederationWorkerStore): | |||
EVENT_AUTH_STATE_ONLY = "event_auth_state_only" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_update_handler( | |||
@@ -13,7 +13,7 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, List, Optional, Tuple, Union | |||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union | |||
import attr | |||
@@ -23,6 +23,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -64,7 +67,7 @@ def _deserialize_action(actions, is_highlight): | |||
class EventPushActionsWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
# These get correctly set by _find_stream_orderings_for_times_txn | |||
@@ -892,7 +895,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): | |||
class EventPushActionsStore(EventPushActionsWorkerStore): | |||
EPA_HIGHLIGHT_INDEX = "epa_highlight_index" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_index_update( | |||
@@ -13,7 +13,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple | |||
import attr | |||
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.events import PersistEventsStore | |||
from synapse.storage.types import Cursor | |||
from synapse.types import JsonDict | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -76,7 +79,7 @@ class _CalculateChainCover: | |||
class EventsBackgroundUpdatesStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_update_handler( | |||
@@ -13,11 +13,14 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from enum import Enum | |||
from typing import Any, Dict, Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import DatabasePool | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
BG_UPDATE_REMOVE_MEDIA_REPO_INDEX_WITHOUT_METHOD = ( | |||
"media_repository_drop_index_wo_method" | |||
) | |||
@@ -43,7 +46,7 @@ class MediaSortOrder(Enum): | |||
class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_index_update( | |||
@@ -123,7 +126,7 @@ class MediaRepositoryBackgroundUpdateStore(SQLBaseStore): | |||
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): | |||
"""Persistence for attachments and avatars""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.server_name = hs.hostname | |||
@@ -14,7 +14,7 @@ | |||
import calendar | |||
import logging | |||
import time | |||
from typing import Dict | |||
from typing import TYPE_CHECKING, Dict | |||
from synapse.metrics import GaugeBucketCollector | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
@@ -24,6 +24,9 @@ from synapse.storage.databases.main.event_push_actions import ( | |||
EventPushActionsWorkerStore, | |||
) | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
# Collect metrics on the number of forward extremities that exist. | |||
@@ -52,7 +55,7 @@ class ServerMetricsStore(EventPushActionsWorkerStore, SQLBaseStore): | |||
stats and prometheus metrics. | |||
""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
# Read the extrems every 60 minutes | |||
@@ -12,13 +12,16 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
from typing import Dict, List, Optional | |||
from typing import TYPE_CHECKING, Dict, List, Optional | |||
from synapse.metrics.background_process_metrics import wrap_as_background_process | |||
from synapse.storage._base import SQLBaseStore | |||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
# Number of msec of granularity to store the monthly_active_user timestamp | |||
@@ -27,7 +30,7 @@ LAST_SEEN_GRANULARITY = 60 * 60 * 1000 | |||
class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self._clock = hs.get_clock() | |||
self.hs = hs | |||
@@ -209,7 +212,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): | |||
class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self._mau_stats_only = hs.config.server.mau_stats_only | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import abc | |||
import logging | |||
from typing import Dict, List, Tuple, Union | |||
from typing import TYPE_CHECKING, Dict, List, Tuple, Union | |||
from synapse.api.errors import NotFoundError, StoreError | |||
from synapse.push.baserules import list_with_base_rules | |||
@@ -33,6 +33,9 @@ from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -75,7 +78,7 @@ class PushRulesWorkerStore( | |||
`get_max_push_rules_stream_id` which can be called in the initializer. | |||
""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if hs.config.worker.worker_app is None: | |||
@@ -14,7 +14,7 @@ | |||
# limitations under the License. | |||
import logging | |||
from typing import Any, Dict, Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple | |||
from twisted.internet import defer | |||
@@ -29,11 +29,14 @@ from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
from synapse.util.caches.stream_change_cache import StreamChangeCache | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
class ReceiptsWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
self._instance_name = hs.get_instance_name() | |||
if isinstance(database.engine, PostgresEngine): | |||
@@ -17,7 +17,7 @@ import collections | |||
import logging | |||
from abc import abstractmethod | |||
from enum import Enum | |||
from typing import Any, Dict, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple | |||
from synapse.api.constants import EventContentFields, EventTypes, JoinRules | |||
from synapse.api.errors import StoreError | |||
@@ -32,6 +32,9 @@ from synapse.util import json_encoder | |||
from synapse.util.caches.descriptors import cached | |||
from synapse.util.stringutils import MXC_REGEX | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -69,7 +72,7 @@ class RoomSortOrder(Enum): | |||
class RoomWorkerStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -1026,7 +1029,7 @@ _REPLACE_ROOM_DEPTH_SQL_COMMANDS = ( | |||
class RoomBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -1411,7 +1414,7 @@ class RoomBackgroundUpdateStore(SQLBaseStore): | |||
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.config = hs.config | |||
@@ -53,6 +53,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
from synapse.state import _StateCacheEntry | |||
logger = logging.getLogger(__name__) | |||
@@ -63,7 +64,7 @@ _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership" | |||
class RoomMemberWorkerStore(EventsWorkerStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
# Used by `_get_joined_hosts` to ensure only one thing mutates the cache | |||
@@ -982,7 +983,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): | |||
class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.db_pool.updates.register_background_update_handler( | |||
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile | |||
@@ -1132,7 +1133,7 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore): | |||
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
async def forget(self, user_id: str, room_id: str) -> None: | |||
@@ -15,7 +15,7 @@ | |||
import logging | |||
import re | |||
from collections import namedtuple | |||
from typing import Collection, Iterable, List, Optional, Set | |||
from typing import TYPE_CHECKING, Collection, Iterable, List, Optional, Set | |||
from synapse.api.errors import SynapseError | |||
from synapse.events import EventBase | |||
@@ -24,6 +24,9 @@ from synapse.storage.database import DatabasePool, LoggingTransaction | |||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour | |||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
SearchEntry = namedtuple( | |||
@@ -102,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): | |||
EVENT_SEARCH_USE_GIST_POSTGRES_NAME = "event_search_postgres_gist" | |||
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if not hs.config.server.enable_search: | |||
@@ -355,7 +358,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): | |||
class SearchStore(SearchBackgroundUpdateStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
async def search_msgs(self, room_ids, search_term, keys): | |||
@@ -15,7 +15,7 @@ | |||
import collections.abc | |||
import logging | |||
from collections import namedtuple | |||
from typing import Iterable, Optional, Set | |||
from typing import TYPE_CHECKING, Iterable, Optional, Set | |||
from synapse.api.constants import EventTypes, Membership | |||
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError | |||
@@ -30,6 +30,9 @@ from synapse.types import StateMap | |||
from synapse.util.caches import intern_string | |||
from synapse.util.caches.descriptors import cached, cachedList | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
@@ -53,7 +56,7 @@ class _GetStateGroupDelta( | |||
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): | |||
"""The parts of StateGroupStore that can be called from workers.""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
async def get_room_version(self, room_id: str) -> RoomVersion: | |||
@@ -346,7 +349,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore): | |||
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index" | |||
DELETE_CURRENT_STATE_UPDATE_NAME = "delete_old_current_state_events" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.server_name = hs.hostname | |||
@@ -533,5 +536,5 @@ class StateStore(StateGroupWorkerStore, MainStateBackgroundUpdateStore): | |||
* `state_groups_state`: Maps state group to state events. | |||
""" | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) |
@@ -16,7 +16,7 @@ | |||
import logging | |||
from enum import Enum | |||
from itertools import chain | |||
from typing import Any, Dict, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple | |||
from typing_extensions import Counter | |||
@@ -29,6 +29,9 @@ from synapse.storage.databases.main.state_deltas import StateDeltasStore | |||
from synapse.types import JsonDict | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
# these fields track absolutes (e.g. total number of rooms on the server) | |||
@@ -93,7 +96,7 @@ class UserSortOrder(Enum): | |||
class StatsStore(StateDeltasStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
self.server_name = hs.hostname | |||
@@ -14,7 +14,7 @@ | |||
import logging | |||
from collections import namedtuple | |||
from typing import Iterable, List, Optional, Tuple | |||
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple | |||
import attr | |||
from canonicaljson import encode_canonical_json | |||
@@ -26,6 +26,9 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore | |||
from synapse.types import JsonDict | |||
from synapse.util.caches.descriptors import cached | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
db_binary_type = memoryview | |||
logger = logging.getLogger(__name__) | |||
@@ -57,7 +60,7 @@ class DestinationRetryTimings: | |||
class TransactionWorkerStore(CacheInvalidationWorkerStore): | |||
def __init__(self, database: DatabasePool, db_conn, hs): | |||
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): | |||
super().__init__(database, db_conn, hs) | |||
if hs.config.worker.run_background_tasks: | |||
@@ -18,6 +18,7 @@ import itertools | |||
import logging | |||
from collections import deque | |||
from typing import ( | |||
TYPE_CHECKING, | |||
Any, | |||
Awaitable, | |||
Callable, | |||
@@ -56,6 +57,9 @@ from synapse.types import ( | |||
from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results | |||
from synapse.util.metrics import Measure | |||
if TYPE_CHECKING: | |||
from synapse.server import HomeServer | |||
logger = logging.getLogger(__name__) | |||
# The number of times we are recalculating the current state | |||
@@ -272,7 +276,7 @@ class EventsPersistenceStorage: | |||
current state and forward extremity changes. | |||
""" | |||
def __init__(self, hs, stores: Databases): | |||
def __init__(self, hs: "HomeServer", stores: Databases): | |||
# We ultimately want to split out the state store from the main store, | |||
# so we use separate variables here even though they point to the same | |||
# store for now. | |||