@@ -0,0 +1 @@ | |||
Implement MSC1708 (.well-known routing for server-server federation) |
@@ -14,6 +14,8 @@ | |||
# limitations under the License. | |||
import json | |||
import logging | |||
import random | |||
import time | |||
import attr | |||
from netaddr import IPAddress | |||
@@ -22,13 +24,29 @@ from zope.interface import implementer | |||
from twisted.internet import defer | |||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS | |||
from twisted.web.client import URI, Agent, HTTPConnectionPool, readBody | |||
from twisted.web.http import stringToDatetime | |||
from twisted.web.http_headers import Headers | |||
from twisted.web.iweb import IAgent | |||
from synapse.http.federation.srv_resolver import SrvResolver, pick_server_from_list | |||
from synapse.util.caches.ttlcache import TTLCache | |||
from synapse.util.logcontext import make_deferred_yieldable | |||
# period to cache .well-known results for by default | |||
WELL_KNOWN_DEFAULT_CACHE_PERIOD = 24 * 3600 | |||
# jitter to add to the .well-known default cache ttl | |||
WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER = 10 * 60 | |||
# period to cache failure to fetch .well-known for | |||
WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600 | |||
# cap for .well-known cache period | |||
WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600 | |||
logger = logging.getLogger(__name__) | |||
well_known_cache = TTLCache('well-known') | |||
@implementer(IAgent) | |||
@@ -57,6 +75,7 @@ class MatrixFederationAgent(object): | |||
self, reactor, tls_client_options_factory, | |||
_well_known_tls_policy=None, | |||
_srv_resolver=None, | |||
_well_known_cache=well_known_cache, | |||
): | |||
self._reactor = reactor | |||
self._tls_client_options_factory = tls_client_options_factory | |||
@@ -77,6 +96,8 @@ class MatrixFederationAgent(object): | |||
_well_known_agent = Agent(self._reactor, pool=self._pool, **agent_args) | |||
self._well_known_agent = _well_known_agent | |||
self._well_known_cache = _well_known_cache | |||
@defer.inlineCallbacks | |||
def request(self, method, uri, headers=None, bodyProducer=None): | |||
""" | |||
@@ -259,7 +280,14 @@ class MatrixFederationAgent(object): | |||
Deferred[bytes|None]: either the new server name, from the .well-known, or | |||
None if there was no .well-known file. | |||
""" | |||
# FIXME: add a cache | |||
try: | |||
cached = self._well_known_cache[server_name] | |||
defer.returnValue(cached) | |||
except KeyError: | |||
pass | |||
# TODO: should we linearise so that we don't end up doing two .well-known requests | |||
# for the same server in parallel? | |||
uri = b"https://%s/.well-known/matrix/server" % (server_name, ) | |||
uri_str = uri.decode("ascii") | |||
@@ -270,12 +298,14 @@ class MatrixFederationAgent(object): | |||
) | |||
except Exception as e: | |||
logger.info("Connection error fetching %s: %s", uri_str, e) | |||
self._well_known_cache.set(server_name, None, WELL_KNOWN_INVALID_CACHE_PERIOD) | |||
defer.returnValue(None) | |||
body = yield make_deferred_yieldable(readBody(response)) | |||
if response.code != 200: | |||
logger.info("Error response %i from %s", response.code, uri_str) | |||
self._well_known_cache.set(server_name, None, WELL_KNOWN_INVALID_CACHE_PERIOD) | |||
defer.returnValue(None) | |||
try: | |||
@@ -287,7 +317,63 @@ class MatrixFederationAgent(object): | |||
raise Exception("Missing key 'm.server'") | |||
except Exception as e: | |||
raise Exception("invalid .well-known response from %s: %s" % (uri_str, e,)) | |||
defer.returnValue(parsed_body["m.server"].encode("ascii")) | |||
result = parsed_body["m.server"].encode("ascii") | |||
cache_period = _cache_period_from_headers( | |||
response.headers, | |||
time_now=self._reactor.seconds, | |||
) | |||
if cache_period is None: | |||
cache_period = WELL_KNOWN_DEFAULT_CACHE_PERIOD | |||
# add some randomness to the TTL to avoid a stampeding herd every hour after | |||
# startup | |||
cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER) | |||
else: | |||
cache_period = min(cache_period, WELL_KNOWN_MAX_CACHE_PERIOD) | |||
if cache_period > 0: | |||
self._well_known_cache.set(server_name, result, cache_period) | |||
defer.returnValue(result) | |||
def _cache_period_from_headers(headers, time_now=time.time):
    """Work out a cache period for a .well-known response, from its headers.

    Args:
        headers (twisted.web.http_headers.Headers): response headers
        time_now (callable): 0-arg callable returning the current time in
            seconds; used to turn an absolute Expires date into a relative
            period. Defaults to time.time.

    Returns:
        int|float|None: cache period in seconds; 0 if the response must not
            be cached; None if the headers give no caching guidance.
    """
    cache_controls = _parse_cache_control(headers)

    if b'no-store' in cache_controls:
        return 0

    if b'max-age' in cache_controls:
        try:
            max_age = int(cache_controls[b'max-age'])
            return max_age
        except (TypeError, ValueError):
            # a bare `max-age` directive (no `=value`) is stored as None,
            # which int() rejects with TypeError; a non-integer value raises
            # ValueError. Either way, fall through to the Expires header.
            pass

    expires = headers.getRawHeaders(b'expires')
    if expires is not None:
        try:
            # if there are multiple Expires headers, the last one wins
            expires_date = stringToDatetime(expires[-1])
            return expires_date - time_now()
        except ValueError:
            # RFC7234 says 'A cache recipient MUST interpret invalid date formats,
            # especially the value "0", as representing a time in the past (i.e.,
            # "already expired").
            return 0

    return None


def _parse_cache_control(headers):
    """Parse any Cache-Control headers into a dict of directives.

    Args:
        headers (twisted.web.http_headers.Headers): response headers

    Returns:
        dict[bytes, bytes|None]: map from lower-cased directive name to its
            value, or None for valueless directives (e.g. `no-store`).
    """
    cache_controls = {}
    for hdr in headers.getRawHeaders(b'cache-control', []):
        for directive in hdr.split(b','):
            splits = [x.strip() for x in directive.split(b'=', 1)]
            k = splits[0].lower()
            v = splits[1] if len(splits) > 1 else None
            cache_controls[k] = v
    return cache_controls
@attr.s | |||
@@ -0,0 +1,161 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright 2015, 2016 OpenMarket Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import logging | |||
import time | |||
import attr | |||
from sortedcontainers import SortedList | |||
from synapse.util.caches import register_cache | |||
logger = logging.getLogger(__name__) | |||
SENTINEL = object() | |||
class TTLCache(object):
    """A key/value cache implementation where each entry has its own TTL"""

    def __init__(self, cache_name, timer=time.time):
        """
        Args:
            cache_name (str): name of this cache, for metrics registration
            timer (callable): 0-arg callable returning the current time in
                seconds. Defaults to time.time; tests can inject a fake clock.
        """
        # map from key to _CacheEntry
        self._data = {}

        # the _CacheEntries, sorted by expiry time
        self._expiry_list = SortedList()

        self._timer = timer

        self._metrics = register_cache("ttl", cache_name, self)

    def set(self, key, value, ttl):
        """Add/update an entry in the cache

        Args:
            key: key for this entry
            value: value for this entry
            ttl (float): TTL for this entry, in seconds
        """
        expiry = self._timer() + ttl

        self.expire()
        e = self._data.pop(key, SENTINEL)
        # identity check: SENTINEL is a unique marker object, and `!=` could
        # be fooled by an entry type with a permissive __eq__
        if e is not SENTINEL:
            self._expiry_list.remove(e)

        entry = _CacheEntry(expiry_time=expiry, key=key, value=value)
        self._data[key] = entry
        self._expiry_list.add(entry)

    def get(self, key, default=SENTINEL):
        """Get a value from the cache

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.get(key, SENTINEL)
        if e is SENTINEL:
            self._metrics.inc_misses()
            # `is` rather than `==`: a caller-supplied default with a custom
            # __eq__ must not spuriously match the sentinel
            if default is SENTINEL:
                raise KeyError(key)
            return default
        self._metrics.inc_hits()
        return e.value

    def get_with_expiry(self, key):
        """Get a value, and its expiry time, from the cache

        Args:
            key: key to look up

        Returns:
            Tuple[Any, float]: the value from the cache, and the expiry time

        Raises:
            KeyError if the entry is not found
        """
        self.expire()
        try:
            e = self._data[key]
        except KeyError:
            self._metrics.inc_misses()
            raise
        self._metrics.inc_hits()
        return e.value, e.expiry_time

    def pop(self, key, default=SENTINEL):
        """Remove a value from the cache

        If key is in the cache, remove it and return its value, else return default.
        If default is not given and key is not in the cache, a KeyError is raised.

        Args:
            key: key to look up
            default: default value to return, if key is not found. If not set, and the
                key is not found, a KeyError will be raised

        Returns:
            value from the cache, or the default
        """
        self.expire()
        e = self._data.pop(key, SENTINEL)
        if e is SENTINEL:
            self._metrics.inc_misses()
            if default is SENTINEL:
                raise KeyError(key)
            return default
        self._expiry_list.remove(e)
        self._metrics.inc_hits()
        return e.value

    def __getitem__(self, key):
        return self.get(key)

    def __delitem__(self, key):
        self.pop(key)

    def __contains__(self, key):
        # expire first, for consistency with get/__len__: a key whose TTL has
        # passed should not report as present
        self.expire()
        return key in self._data

    def __len__(self):
        self.expire()
        return len(self._data)

    def expire(self):
        """Run the expiry on the cache. Any entries whose expiry times are due will
        be removed
        """
        now = self._timer()
        while self._expiry_list:
            first_entry = self._expiry_list[0]
            # entries sort by expiry_time, so once we hit a live entry we can stop
            if first_entry.expiry_time - now > 0.0:
                break
            del self._data[first_entry.key]
            del self._expiry_list[0]
@attr.s(frozen=True, slots=True)
class _CacheEntry(object):
    """TTLCache entry"""
    # expiry_time is the first attribute, so that entries are sorted by expiry.
    expiry_time = attr.ib()  # absolute time (per the cache's timer) at which this entry expires
    key = attr.ib()          # cache key
    value = attr.ib()        # cached value
@@ -24,11 +24,16 @@ from twisted.internet._sslverify import ClientTLSOptions, OpenSSLCertificateOpti | |||
from twisted.internet.protocol import Factory | |||
from twisted.protocols.tls import TLSMemoryBIOFactory | |||
from twisted.web.http import HTTPChannel | |||
from twisted.web.http_headers import Headers | |||
from twisted.web.iweb import IPolicyForHTTPS | |||
from synapse.crypto.context_factory import ClientTLSOptionsFactory | |||
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent | |||
from synapse.http.federation.matrix_federation_agent import ( | |||
MatrixFederationAgent, | |||
_cache_period_from_headers, | |||
) | |||
from synapse.http.federation.srv_resolver import Server | |||
from synapse.util.caches.ttlcache import TTLCache | |||
from synapse.util.logcontext import LoggingContext | |||
from tests.http import ServerTLSContext | |||
@@ -44,11 +49,14 @@ class MatrixFederationAgentTests(TestCase): | |||
self.mock_resolver = Mock() | |||
self.well_known_cache = TTLCache("test_cache", timer=self.reactor.seconds) | |||
self.agent = MatrixFederationAgent( | |||
reactor=self.reactor, | |||
tls_client_options_factory=ClientTLSOptionsFactory(None), | |||
_well_known_tls_policy=TrustingTLSPolicyForHTTPS(), | |||
_srv_resolver=self.mock_resolver, | |||
_well_known_cache=self.well_known_cache, | |||
) | |||
def _make_connection(self, client_factory, expected_sni): | |||
@@ -115,7 +123,9 @@ class MatrixFederationAgentTests(TestCase): | |||
finally: | |||
_check_logcontext(context) | |||
def _handle_well_known_connection(self, client_factory, expected_sni, target_server): | |||
def _handle_well_known_connection( | |||
self, client_factory, expected_sni, target_server, response_headers={}, | |||
): | |||
"""Handle an outgoing HTTPs connection: wire it up to a server, check that the | |||
request is for a .well-known, and send the response. | |||
@@ -124,6 +134,8 @@ class MatrixFederationAgentTests(TestCase): | |||
expected_sni (bytes): SNI that we expect the outgoing connection to send | |||
target_server (bytes): target server that we should redirect to in the | |||
.well-known response. | |||
Returns: | |||
HTTPChannel: server impl | |||
""" | |||
# make the connection for .well-known | |||
well_known_server = self._make_connection( | |||
@@ -133,9 +145,10 @@ class MatrixFederationAgentTests(TestCase): | |||
# check the .well-known request and send a response | |||
self.assertEqual(len(well_known_server.requests), 1) | |||
request = well_known_server.requests[0] | |||
self._send_well_known_response(request, target_server) | |||
self._send_well_known_response(request, target_server, headers=response_headers) | |||
return well_known_server | |||
def _send_well_known_response(self, request, target_server): | |||
def _send_well_known_response(self, request, target_server, headers={}): | |||
"""Check that an incoming request looks like a valid .well-known request, and | |||
send back the response. | |||
""" | |||
@@ -146,6 +159,8 @@ class MatrixFederationAgentTests(TestCase): | |||
[b'testserv'], | |||
) | |||
# send back a response | |||
for k, v in headers.items(): | |||
request.setHeader(k, v) | |||
request.write(b'{ "m.server": "%s" }' % (target_server,)) | |||
request.finish() | |||
@@ -448,6 +463,13 @@ class MatrixFederationAgentTests(TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
self.assertEqual(self.well_known_cache[b"testserv"], b"target-server") | |||
# check the cache expires | |||
self.reactor.pump((25 * 3600,)) | |||
self.well_known_cache.expire() | |||
self.assertNotIn(b"testserv", self.well_known_cache) | |||
def test_get_hostname_srv(self): | |||
""" | |||
Test the behaviour when there is a single SRV record | |||
@@ -661,6 +683,126 @@ class MatrixFederationAgentTests(TestCase): | |||
self.reactor.pump((0.1,)) | |||
self.successResultOf(test_d) | |||
@defer.inlineCallbacks | |||
def do_get_well_known(self, serv): | |||
try: | |||
result = yield self.agent._get_well_known(serv) | |||
logger.info("Result from well-known fetch: %s", result) | |||
except Exception as e: | |||
logger.warning("Error fetching well-known: %s", e) | |||
raise | |||
defer.returnValue(result) | |||
    def test_well_known_cache(self):
        """A .well-known result should be cached for the period given by the
        response's Cache-Control max-age, and re-fetched once that elapses.
        """
        self.reactor.lookups["testserv"] = "1.2.3.4"

        fetch_d = self.do_get_well_known(b'testserv')

        # there should be an attempt to connect on port 443 for the .well-known
        clients = self.reactor.tcpClients
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)

        # serve a .well-known response which is cacheable for 10 seconds
        well_known_server = self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            response_headers={b'Cache-Control': b'max-age=10'},
            target_server=b"target-server",
        )

        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # close the tcp connection
        well_known_server.loseConnection()

        # repeat the request: it should hit the cache
        fetch_d = self.do_get_well_known(b'testserv')
        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'target-server')

        # expire the cache
        self.reactor.pump((10.0,))

        # now it should connect again
        fetch_d = self.do_get_well_known(b'testserv')
        self.assertEqual(len(clients), 1)
        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
        self.assertEqual(host, '1.2.3.4')
        self.assertEqual(port, 443)
        self._handle_well_known_connection(
            client_factory,
            expected_sni=b"testserv",
            target_server=b"other-server",
        )
        r = self.successResultOf(fetch_d)
        self.assertEqual(r, b'other-server')
class TestCachePeriodFromHeaders(TestCase):
    """Tests for the `_cache_period_from_headers` helper."""

    def test_cache_control(self):
        """Cache-Control max-age should determine the period; bogus forms are ignored."""
        # uppercase directive name, with spaces around '='
        self.assertEqual(
            _cache_period_from_headers(
                Headers({b'Cache-Control': [b'foo, Max-Age = 100, bar']}),
            ), 100,
        )

        # missing value
        self.assertIsNone(_cache_period_from_headers(
            Headers({b'Cache-Control': [b'max-age=, bar']}),
        ))

        # hackernews: bogus due to semicolon
        self.assertIsNone(_cache_period_from_headers(
            Headers({b'Cache-Control': [b'private; max-age=0']}),
        ))

        # github
        self.assertEqual(
            _cache_period_from_headers(
                Headers({b'Cache-Control': [b'max-age=0, private, must-revalidate']}),
            ), 0,
        )

        # lower-cased header name also works
        self.assertEqual(
            _cache_period_from_headers(
                Headers({b'cache-control': [b'private, max-age=0']}),
            ), 0,
        )

    def test_expires(self):
        """The Expires header is a fallback, interpreted relative to time_now."""
        self.assertEqual(
            _cache_period_from_headers(
                Headers({b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']}),
                time_now=lambda: 1548833700
            ), 33,
        )

        # cache-control overrides expires
        self.assertEqual(
            _cache_period_from_headers(
                Headers({
                    b'cache-control': [b'max-age=10'],
                    b'Expires': [b'Wed, 30 Jan 2019 07:35:33 GMT']
                }),
                time_now=lambda: 1548833700
            ), 10,
        )

        # invalid expires means immediate expiry
        self.assertEqual(
            _cache_period_from_headers(
                Headers({b'Expires': [b'0']}),
            ), 0,
        )
def _check_logcontext(context): | |||
current = LoggingContext.current_context() | |||
@@ -360,6 +360,7 @@ class FakeTransport(object): | |||
""" | |||
disconnecting = False | |||
disconnected = False | |||
buffer = attr.ib(default=b'') | |||
producer = attr.ib(default=None) | |||
@@ -370,14 +371,16 @@ class FakeTransport(object): | |||
return None | |||
def loseConnection(self, reason=None): | |||
logger.info("FakeTransport: loseConnection(%s)", reason) | |||
if not self.disconnecting: | |||
logger.info("FakeTransport: loseConnection(%s)", reason) | |||
self.disconnecting = True | |||
if self._protocol: | |||
self._protocol.connectionLost(reason) | |||
self.disconnected = True | |||
def abortConnection(self): | |||
self.disconnecting = True | |||
logger.info("FakeTransport: abortConnection()") | |||
self.loseConnection() | |||
def pauseProducing(self): | |||
if not self.producer: | |||
@@ -416,9 +419,16 @@ class FakeTransport(object): | |||
# TLSMemoryBIOProtocol | |||
return | |||
if self.disconnected: | |||
return | |||
logger.info("%s->%s: %s", self._protocol, self.other, self.buffer) | |||
if getattr(self.other, "transport") is not None: | |||
self.other.dataReceived(self.buffer) | |||
self.buffer = b"" | |||
try: | |||
self.other.dataReceived(self.buffer) | |||
self.buffer = b"" | |||
except Exception as e: | |||
logger.warning("Exception writing to protocol: %s", e) | |||
return | |||
self._reactor.callLater(0.0, _write) | |||
@@ -0,0 +1,83 @@ | |||
# -*- coding: utf-8 -*- | |||
# Copyright 2019 New Vector Ltd | |||
# | |||
# Licensed under the Apache License, Version 2.0 (the "License"); | |||
# you may not use this file except in compliance with the License. | |||
# You may obtain a copy of the License at | |||
# | |||
# http://www.apache.org/licenses/LICENSE-2.0 | |||
# | |||
# Unless required by applicable law or agreed to in writing, software | |||
# distributed under the License is distributed on an "AS IS" BASIS, | |||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
from mock import Mock | |||
from synapse.util.caches.ttlcache import TTLCache | |||
from tests import unittest | |||
class CacheTestCase(unittest.TestCase):
    """Tests for synapse.util.caches.ttlcache.TTLCache."""

    def setUp(self):
        # a fake clock, so the tests control the passage of time; starts at t=100
        self.mock_timer = Mock(side_effect=lambda: 100.0)
        self.cache = TTLCache("test_cache", self.mock_timer)

    def test_get(self):
        """simple set/get tests"""
        self.cache.set('one', '1', 10)
        self.cache.set('two', '2', 20)
        self.cache.set('three', '3', 30)

        self.assertEqual(len(self.cache), 3)

        self.assertTrue('one' in self.cache)
        self.assertEqual(self.cache.get('one'), '1')
        self.assertEqual(self.cache['one'], '1')
        # entry was set at t=100 with ttl=10, so expiry is 110
        self.assertEqual(self.cache.get_with_expiry('one'), ('1', 110))
        self.assertEqual(self.cache._metrics.hits, 3)
        self.assertEqual(self.cache._metrics.misses, 0)

        # overwriting an entry replaces its value
        self.cache.set('two', '2.5', 20)
        self.assertEqual(self.cache['two'], '2.5')
        self.assertEqual(self.cache._metrics.hits, 4)

        # non-existent-item tests
        self.assertEqual(self.cache.get('four', '4'), '4')
        self.assertIs(self.cache.get('four', None), None)

        with self.assertRaises(KeyError):
            self.cache['four']

        with self.assertRaises(KeyError):
            self.cache.get('four')

        with self.assertRaises(KeyError):
            self.cache.get_with_expiry('four')

        self.assertEqual(self.cache._metrics.hits, 4)
        self.assertEqual(self.cache._metrics.misses, 5)

    def test_expiry(self):
        """entries should be evicted once the clock passes their expiry time"""
        self.cache.set('one', '1', 10)
        self.cache.set('two', '2', 20)
        self.cache.set('three', '3', 30)
        self.assertEqual(len(self.cache), 3)
        self.assertEqual(self.cache['one'], '1')
        self.assertEqual(self.cache['two'], '2')

        # enough for the first entry to expire, but not the rest
        self.mock_timer.side_effect = lambda: 110.0

        self.assertEqual(len(self.cache), 2)
        self.assertFalse('one' in self.cache)
        self.assertEqual(self.cache['two'], '2')
        self.assertEqual(self.cache['three'], '3')

        self.assertEqual(self.cache.get_with_expiry('two'), ('2', 120))

        self.assertEqual(self.cache._metrics.hits, 5)
        self.assertEqual(self.cache._metrics.misses, 0)