@@ -9,4 +9,3 @@ source= | |||
[report] | |||
precision = 2 | |||
ignore_errors = True |
@@ -59,7 +59,5 @@ env/ | |||
.vscode/ | |||
.ropeproject/ | |||
*.deb | |||
/debs |
@@ -0,0 +1 @@ | |||
Getting URL previews of IP addresses no longer fails on Python 3. |
@@ -21,28 +21,25 @@ from six.moves import urllib | |||
import treq | |||
from canonicaljson import encode_canonical_json, json | |||
from netaddr import IPAddress | |||
from prometheus_client import Counter | |||
from zope.interface import implementer, provider | |||
from OpenSSL import SSL | |||
from OpenSSL.SSL import VERIFY_NONE | |||
from twisted.internet import defer, protocol, reactor, ssl | |||
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS | |||
from twisted.web._newclient import ResponseDone | |||
from twisted.web.client import ( | |||
Agent, | |||
BrowserLikeRedirectAgent, | |||
ContentDecoderAgent, | |||
GzipDecoder, | |||
HTTPConnectionPool, | |||
PartialDownloadError, | |||
readBody, | |||
from twisted.internet import defer, protocol, ssl | |||
from twisted.internet.interfaces import ( | |||
IReactorPluggableNameResolver, | |||
IResolutionReceiver, | |||
) | |||
from twisted.python.failure import Failure | |||
from twisted.web._newclient import ResponseDone | |||
from twisted.web.client import Agent, HTTPConnectionPool, PartialDownloadError, readBody | |||
from twisted.web.http import PotentialDataLoss | |||
from twisted.web.http_headers import Headers | |||
from synapse.api.errors import Codes, HttpResponseException, SynapseError | |||
from synapse.http import cancelled_to_request_timed_out_error, redact_uri | |||
from synapse.http.endpoint import SpiderEndpoint | |||
from synapse.util.async_helpers import timeout_deferred | |||
from synapse.util.caches import CACHE_SIZE_FACTOR | |||
from synapse.util.logcontext import make_deferred_yieldable | |||
@@ -50,8 +47,125 @@ from synapse.util.logcontext import make_deferred_yieldable | |||
logger = logging.getLogger(__name__) | |||
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) | |||
incoming_responses_counter = Counter("synapse_http_client_responses", "", | |||
["method", "code"]) | |||
incoming_responses_counter = Counter( | |||
"synapse_http_client_responses", "", ["method", "code"] | |||
) | |||
def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): | |||
""" | |||
Args: | |||
ip_address (netaddr.IPAddress) | |||
ip_whitelist (netaddr.IPSet) | |||
ip_blacklist (netaddr.IPSet) | |||
""" | |||
if ip_address in ip_blacklist: | |||
if ip_whitelist is None or ip_address not in ip_whitelist: | |||
return True | |||
return False | |||
class IPBlacklistingResolver(object): | |||
""" | |||
A proxy for reactor.nameResolver which only produces non-blacklisted IP | |||
addresses, preventing DNS rebinding attacks on URL preview. | |||
""" | |||
def __init__(self, reactor, ip_whitelist, ip_blacklist): | |||
""" | |||
Args: | |||
reactor (twisted.internet.reactor) | |||
ip_whitelist (netaddr.IPSet) | |||
ip_blacklist (netaddr.IPSet) | |||
""" | |||
self._reactor = reactor | |||
self._ip_whitelist = ip_whitelist | |||
self._ip_blacklist = ip_blacklist | |||
def resolveHostName(self, recv, hostname, portNumber=0): | |||
r = recv() | |||
d = defer.Deferred() | |||
addresses = [] | |||
@provider(IResolutionReceiver) | |||
class EndpointReceiver(object): | |||
@staticmethod | |||
def resolutionBegan(resolutionInProgress): | |||
pass | |||
@staticmethod | |||
def addressResolved(address): | |||
ip_address = IPAddress(address.host) | |||
if check_against_blacklist( | |||
ip_address, self._ip_whitelist, self._ip_blacklist | |||
): | |||
logger.info( | |||
"Dropped %s from DNS resolution to %s" % (ip_address, hostname) | |||
) | |||
raise SynapseError(403, "IP address blocked by IP blacklist entry") | |||
addresses.append(address) | |||
@staticmethod | |||
def resolutionComplete(): | |||
d.callback(addresses) | |||
self._reactor.nameResolver.resolveHostName( | |||
EndpointReceiver, hostname, portNumber=portNumber | |||
) | |||
def _callback(addrs): | |||
r.resolutionBegan(None) | |||
for i in addrs: | |||
r.addressResolved(i) | |||
r.resolutionComplete() | |||
d.addCallback(_callback) | |||
return r | |||
class BlacklistingAgentWrapper(Agent): | |||
""" | |||
An Agent wrapper which will prevent access to IP addresses being accessed | |||
directly (without an IP address lookup). | |||
""" | |||
def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): | |||
""" | |||
Args: | |||
agent (twisted.web.client.Agent): The Agent to wrap. | |||
reactor (twisted.internet.reactor) | |||
ip_whitelist (netaddr.IPSet) | |||
ip_blacklist (netaddr.IPSet) | |||
""" | |||
self._agent = agent | |||
self._ip_whitelist = ip_whitelist | |||
self._ip_blacklist = ip_blacklist | |||
def request(self, method, uri, headers=None, bodyProducer=None): | |||
h = urllib.parse.urlparse(uri.decode('ascii')) | |||
try: | |||
ip_address = IPAddress(h.hostname) | |||
if check_against_blacklist( | |||
ip_address, self._ip_whitelist, self._ip_blacklist | |||
): | |||
logger.info( | |||
"Blocking access to %s because of blacklist" % (ip_address,) | |||
) | |||
e = SynapseError(403, "IP address blocked by IP blacklist entry") | |||
return defer.fail(Failure(e)) | |||
except Exception: | |||
# Not an IP | |||
pass | |||
return self._agent.request( | |||
method, uri, headers=headers, bodyProducer=bodyProducer | |||
) | |||
class SimpleHttpClient(object): | |||
@@ -59,14 +173,54 @@ class SimpleHttpClient(object): | |||
A simple, no-frills HTTP client with methods that wrap up common ways of | |||
using HTTP in Matrix | |||
""" | |||
def __init__(self, hs): | |||
def __init__(self, hs, treq_args={}, ip_whitelist=None, ip_blacklist=None): | |||
""" | |||
Args: | |||
hs (synapse.server.HomeServer) | |||
treq_args (dict): Extra keyword arguments to be given to treq.request. | |||
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that | |||
we may not request. | |||
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can | |||
request if it were otherwise caught in a blacklist. | |||
""" | |||
self.hs = hs | |||
pool = HTTPConnectionPool(reactor) | |||
self._ip_whitelist = ip_whitelist | |||
self._ip_blacklist = ip_blacklist | |||
self._extra_treq_args = treq_args | |||
self.user_agent = hs.version_string | |||
self.clock = hs.get_clock() | |||
if hs.config.user_agent_suffix: | |||
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix) | |||
self.user_agent = self.user_agent.encode('ascii') | |||
if self._ip_blacklist: | |||
real_reactor = hs.get_reactor() | |||
# If we have an IP blacklist, we need to use a DNS resolver which | |||
# filters out blacklisted IP addresses, to prevent DNS rebinding. | |||
nameResolver = IPBlacklistingResolver( | |||
real_reactor, self._ip_whitelist, self._ip_blacklist | |||
) | |||
@implementer(IReactorPluggableNameResolver) | |||
class Reactor(object): | |||
def __getattr__(_self, attr): | |||
if attr == "nameResolver": | |||
return nameResolver | |||
else: | |||
return getattr(real_reactor, attr) | |||
self.reactor = Reactor() | |||
else: | |||
self.reactor = hs.get_reactor() | |||
# the pusher makes lots of concurrent SSL connections to sygnal, and | |||
# tends to do so in batches, so we need to allow the pool to keep lots | |||
# of idle connections around. | |||
# tends to do so in batches, so we need to allow the pool to keep | |||
# lots of idle connections around. | |||
pool = HTTPConnectionPool(self.reactor) | |||
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) | |||
pool.cachedConnectionTimeout = 2 * 60 | |||
@@ -74,20 +228,35 @@ class SimpleHttpClient(object): | |||
# BrowserLikePolicyForHTTPS which will do regular cert validation | |||
# 'like a browser' | |||
self.agent = Agent( | |||
reactor, | |||
self.reactor, | |||
connectTimeout=15, | |||
contextFactory=hs.get_http_client_context_factory(), | |||
contextFactory=self.hs.get_http_client_context_factory(), | |||
pool=pool, | |||
) | |||
self.user_agent = hs.version_string | |||
self.clock = hs.get_clock() | |||
if hs.config.user_agent_suffix: | |||
self.user_agent = "%s %s" % (self.user_agent, hs.config.user_agent_suffix,) | |||
self.user_agent = self.user_agent.encode('ascii') | |||
if self._ip_blacklist: | |||
# If we have an IP blacklist, we then install the blacklisting Agent | |||
# which prevents direct access to IP addresses, that are not caught | |||
# by the DNS resolution. | |||
self.agent = BlacklistingAgentWrapper( | |||
self.agent, | |||
self.reactor, | |||
ip_whitelist=self._ip_whitelist, | |||
ip_blacklist=self._ip_blacklist, | |||
) | |||
@defer.inlineCallbacks | |||
def request(self, method, uri, data=b'', headers=None): | |||
""" | |||
Args: | |||
method (str): HTTP method to use. | |||
uri (str): URI to query. | |||
data (bytes): Data to send in the request body, if applicable. | |||
headers (t.w.http_headers.Headers): Request headers. | |||
Raises: | |||
SynapseError: If the IP is blacklisted. | |||
""" | |||
# A small wrapper around self.agent.request() so we can easily attach | |||
# counters to it | |||
outgoing_requests_counter.labels(method).inc() | |||
@@ -97,25 +266,34 @@ class SimpleHttpClient(object): | |||
try: | |||
request_deferred = treq.request( | |||
method, uri, agent=self.agent, data=data, headers=headers | |||
method, | |||
uri, | |||
agent=self.agent, | |||
data=data, | |||
headers=headers, | |||
**self._extra_treq_args | |||
) | |||
request_deferred = timeout_deferred( | |||
request_deferred, 60, self.hs.get_reactor(), | |||
request_deferred, | |||
60, | |||
self.hs.get_reactor(), | |||
cancelled_to_request_timed_out_error, | |||
) | |||
response = yield make_deferred_yieldable(request_deferred) | |||
incoming_responses_counter.labels(method, response.code).inc() | |||
logger.info( | |||
"Received response to %s %s: %s", | |||
method, redact_uri(uri), response.code | |||
"Received response to %s %s: %s", method, redact_uri(uri), response.code | |||
) | |||
defer.returnValue(response) | |||
except Exception as e: | |||
incoming_responses_counter.labels(method, "ERR").inc() | |||
logger.info( | |||
"Error sending request to %s %s: %s %s", | |||
method, redact_uri(uri), type(e).__name__, e.args[0] | |||
method, | |||
redact_uri(uri), | |||
type(e).__name__, | |||
e.args[0], | |||
) | |||
raise | |||
@@ -140,8 +318,9 @@ class SimpleHttpClient(object): | |||
# TODO: Do we ever want to log message contents? | |||
logger.debug("post_urlencoded_get_json args: %s", args) | |||
query_bytes = urllib.parse.urlencode( | |||
encode_urlencode_args(args), True).encode("utf8") | |||
query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( | |||
"utf8" | |||
) | |||
actual_headers = { | |||
b"Content-Type": [b"application/x-www-form-urlencoded"], | |||
@@ -151,10 +330,7 @@ class SimpleHttpClient(object): | |||
actual_headers.update(headers) | |||
response = yield self.request( | |||
"POST", | |||
uri, | |||
headers=Headers(actual_headers), | |||
data=query_bytes | |||
"POST", uri, headers=Headers(actual_headers), data=query_bytes | |||
) | |||
if 200 <= response.code < 300: | |||
@@ -193,10 +369,7 @@ class SimpleHttpClient(object): | |||
actual_headers.update(headers) | |||
response = yield self.request( | |||
"POST", | |||
uri, | |||
headers=Headers(actual_headers), | |||
data=json_str | |||
"POST", uri, headers=Headers(actual_headers), data=json_str | |||
) | |||
body = yield make_deferred_yieldable(readBody(response)) | |||
@@ -264,10 +437,7 @@ class SimpleHttpClient(object): | |||
actual_headers.update(headers) | |||
response = yield self.request( | |||
"PUT", | |||
uri, | |||
headers=Headers(actual_headers), | |||
data=json_str | |||
"PUT", uri, headers=Headers(actual_headers), data=json_str | |||
) | |||
body = yield make_deferred_yieldable(readBody(response)) | |||
@@ -299,17 +469,11 @@ class SimpleHttpClient(object): | |||
query_bytes = urllib.parse.urlencode(args, True) | |||
uri = "%s?%s" % (uri, query_bytes) | |||
actual_headers = { | |||
b"User-Agent": [self.user_agent], | |||
} | |||
actual_headers = {b"User-Agent": [self.user_agent]} | |||
if headers: | |||
actual_headers.update(headers) | |||
response = yield self.request( | |||
"GET", | |||
uri, | |||
headers=Headers(actual_headers), | |||
) | |||
response = yield self.request("GET", uri, headers=Headers(actual_headers)) | |||
body = yield make_deferred_yieldable(readBody(response)) | |||
@@ -334,22 +498,18 @@ class SimpleHttpClient(object): | |||
headers, absolute URI of the response and HTTP response code. | |||
""" | |||
actual_headers = { | |||
b"User-Agent": [self.user_agent], | |||
} | |||
actual_headers = {b"User-Agent": [self.user_agent]} | |||
if headers: | |||
actual_headers.update(headers) | |||
response = yield self.request( | |||
"GET", | |||
url, | |||
headers=Headers(actual_headers), | |||
) | |||
response = yield self.request("GET", url, headers=Headers(actual_headers)) | |||
resp_headers = dict(response.headers.getAllRawHeaders()) | |||
if (b'Content-Length' in resp_headers and | |||
int(resp_headers[b'Content-Length']) > max_size): | |||
if ( | |||
b'Content-Length' in resp_headers | |||
and int(resp_headers[b'Content-Length'][0]) > max_size | |||
): | |||
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,)) | |||
raise SynapseError( | |||
502, | |||
@@ -359,26 +519,20 @@ class SimpleHttpClient(object): | |||
if response.code > 299: | |||
logger.warn("Got %d when downloading %s" % (response.code, url)) | |||
raise SynapseError( | |||
502, | |||
"Got error %d" % (response.code,), | |||
Codes.UNKNOWN, | |||
) | |||
raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN) | |||
# TODO: if our Content-Type is HTML or something, just read the first | |||
# N bytes into RAM rather than saving it all to disk only to read it | |||
# straight back in again | |||
try: | |||
length = yield make_deferred_yieldable(_readBodyToFile( | |||
response, output_stream, max_size, | |||
)) | |||
length = yield make_deferred_yieldable( | |||
_readBodyToFile(response, output_stream, max_size) | |||
) | |||
except Exception as e: | |||
logger.exception("Failed to download body") | |||
raise SynapseError( | |||
502, | |||
("Failed to download remote body: %s" % e), | |||
Codes.UNKNOWN, | |||
502, ("Failed to download remote body: %s" % e), Codes.UNKNOWN | |||
) | |||
defer.returnValue( | |||
@@ -387,13 +541,14 @@ class SimpleHttpClient(object): | |||
resp_headers, | |||
response.request.absoluteURI.decode('ascii'), | |||
response.code, | |||
), | |||
) | |||
) | |||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. | |||
# The two should be factored out. | |||
class _ReadBodyToFileProtocol(protocol.Protocol): | |||
def __init__(self, stream, deferred, max_size): | |||
self.stream = stream | |||
@@ -405,11 +560,13 @@ class _ReadBodyToFileProtocol(protocol.Protocol): | |||
self.stream.write(data) | |||
self.length += len(data) | |||
if self.max_size is not None and self.length >= self.max_size: | |||
self.deferred.errback(SynapseError( | |||
502, | |||
"Requested file is too large > %r bytes" % (self.max_size,), | |||
Codes.TOO_LARGE, | |||
)) | |||
self.deferred.errback( | |||
SynapseError( | |||
502, | |||
"Requested file is too large > %r bytes" % (self.max_size,), | |||
Codes.TOO_LARGE, | |||
) | |||
) | |||
self.deferred = defer.Deferred() | |||
self.transport.loseConnection() | |||
@@ -427,6 +584,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): | |||
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. | |||
# The two should be factored out. | |||
def _readBodyToFile(response, stream, max_size): | |||
d = defer.Deferred() | |||
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) | |||
@@ -449,10 +607,12 @@ class CaptchaServerHttpClient(SimpleHttpClient): | |||
"POST", | |||
url, | |||
data=query_bytes, | |||
headers=Headers({ | |||
b"Content-Type": [b"application/x-www-form-urlencoded"], | |||
b"User-Agent": [self.user_agent], | |||
}) | |||
headers=Headers( | |||
{ | |||
b"Content-Type": [b"application/x-www-form-urlencoded"], | |||
b"User-Agent": [self.user_agent], | |||
} | |||
), | |||
) | |||
try: | |||
@@ -463,57 +623,6 @@ class CaptchaServerHttpClient(SimpleHttpClient): | |||
defer.returnValue(e.response) | |||
class SpiderEndpointFactory(object): | |||
def __init__(self, hs): | |||
self.blacklist = hs.config.url_preview_ip_range_blacklist | |||
self.whitelist = hs.config.url_preview_ip_range_whitelist | |||
self.policyForHTTPS = hs.get_http_client_context_factory() | |||
def endpointForURI(self, uri): | |||
logger.info("Getting endpoint for %s", uri.toBytes()) | |||
if uri.scheme == b"http": | |||
endpoint_factory = HostnameEndpoint | |||
elif uri.scheme == b"https": | |||
tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port) | |||
def endpoint_factory(reactor, host, port, **kw): | |||
return wrapClientTLS( | |||
tlsCreator, | |||
HostnameEndpoint(reactor, host, port, **kw)) | |||
else: | |||
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme) | |||
return None | |||
return SpiderEndpoint( | |||
reactor, uri.host, uri.port, self.blacklist, self.whitelist, | |||
endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15), | |||
) | |||
class SpiderHttpClient(SimpleHttpClient): | |||
""" | |||
Separate HTTP client for spidering arbitrary URLs. | |||
Special in that it follows retries and has a UA that looks | |||
like a browser. | |||
used by the preview_url endpoint in the content repo. | |||
""" | |||
def __init__(self, hs): | |||
SimpleHttpClient.__init__(self, hs) | |||
# clobber the base class's agent and UA: | |||
self.agent = ContentDecoderAgent( | |||
BrowserLikeRedirectAgent( | |||
Agent.usingEndpointFactory( | |||
reactor, | |||
SpiderEndpointFactory(hs) | |||
) | |||
), [(b'gzip', GzipDecoder)] | |||
) | |||
# We could look like Chrome: | |||
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko) | |||
# Chrome Safari" % hs.version_string) | |||
def encode_urlencode_args(args): | |||
return {k: encode_urlencode_arg(v) for k, v in args.items()} | |||
@@ -218,41 +218,6 @@ class _WrappedConnection(object): | |||
return d | |||
class SpiderEndpoint(object): | |||
"""An endpoint which refuses to connect to blacklisted IP addresses | |||
Implements twisted.internet.interfaces.IStreamClientEndpoint. | |||
""" | |||
def __init__(self, reactor, host, port, blacklist, whitelist, | |||
endpoint=HostnameEndpoint, endpoint_kw_args={}): | |||
self.reactor = reactor | |||
self.host = host | |||
self.port = port | |||
self.blacklist = blacklist | |||
self.whitelist = whitelist | |||
self.endpoint = endpoint | |||
self.endpoint_kw_args = endpoint_kw_args | |||
@defer.inlineCallbacks | |||
def connect(self, protocolFactory): | |||
address = yield self.reactor.resolve(self.host) | |||
from netaddr import IPAddress | |||
ip_address = IPAddress(address) | |||
if ip_address in self.blacklist: | |||
if self.whitelist is None or ip_address not in self.whitelist: | |||
raise ConnectError( | |||
"Refusing to spider blacklisted IP address %s" % address | |||
) | |||
logger.info("Connecting to %s:%s", address, self.port) | |||
endpoint = self.endpoint( | |||
self.reactor, address, self.port, **self.endpoint_kw_args | |||
) | |||
connection = yield endpoint.connect(protocolFactory) | |||
defer.returnValue(connection) | |||
class SRVClientEndpoint(object): | |||
"""An endpoint which looks up SRV records for a service. | |||
Cycles through the list of servers starting with each call to connect | |||
@@ -35,7 +35,7 @@ from twisted.web.resource import Resource | |||
from twisted.web.server import NOT_DONE_YET | |||
from synapse.api.errors import Codes, SynapseError | |||
from synapse.http.client import SpiderHttpClient | |||
from synapse.http.client import SimpleHttpClient | |||
from synapse.http.server import ( | |||
respond_with_json, | |||
respond_with_json_bytes, | |||
@@ -69,7 +69,12 @@ class PreviewUrlResource(Resource): | |||
self.max_spider_size = hs.config.max_spider_size | |||
self.server_name = hs.hostname | |||
self.store = hs.get_datastore() | |||
self.client = SpiderHttpClient(hs) | |||
self.client = SimpleHttpClient( | |||
hs, | |||
treq_args={"browser_like_redirects": True}, | |||
ip_whitelist=hs.config.url_preview_ip_range_whitelist, | |||
ip_blacklist=hs.config.url_preview_ip_range_blacklist, | |||
) | |||
self.media_repo = media_repo | |||
self.primary_base_path = media_repo.primary_base_path | |||
self.media_storage = media_storage | |||
@@ -318,6 +323,11 @@ class PreviewUrlResource(Resource): | |||
length, headers, uri, code = yield self.client.get_file( | |||
url, output_stream=f, max_size=self.max_spider_size, | |||
) | |||
except SynapseError: | |||
# Pass SynapseErrors through directly, so that the servlet | |||
# handler will return a SynapseError to the client instead of | |||
# blank data or a 500. | |||
raise | |||
except Exception as e: | |||
# FIXME: pass through 404s and other error messages nicely | |||
logger.warn("Error downloading %s: %r", url, e) | |||
@@ -15,21 +15,55 @@ | |||
import os | |||
from mock import Mock | |||
import attr | |||
from netaddr import IPSet | |||
from twisted.internet.defer import Deferred | |||
from twisted.internet._resolver import HostResolution | |||
from twisted.internet.address import IPv4Address, IPv6Address | |||
from twisted.internet.error import DNSLookupError | |||
from twisted.python.failure import Failure | |||
from twisted.test.proto_helpers import AccumulatingProtocol | |||
from twisted.web._newclient import ResponseDone | |||
from synapse.config.repository import MediaStorageProviderConfig | |||
from synapse.util.logcontext import make_deferred_yieldable | |||
from synapse.util.module_loader import load_module | |||
from tests import unittest | |||
from tests.server import FakeTransport | |||
@attr.s | |||
class FakeResponse(object): | |||
version = attr.ib() | |||
code = attr.ib() | |||
phrase = attr.ib() | |||
headers = attr.ib() | |||
body = attr.ib() | |||
absoluteURI = attr.ib() | |||
@property | |||
def request(self): | |||
@attr.s | |||
class FakeTransport(object): | |||
absoluteURI = self.absoluteURI | |||
return FakeTransport() | |||
def deliverBody(self, protocol): | |||
protocol.dataReceived(self.body) | |||
protocol.connectionLost(Failure(ResponseDone())) | |||
class URLPreviewTests(unittest.HomeserverTestCase): | |||
hijack_auth = True | |||
user_id = "@test:user" | |||
end_content = ( | |||
b'<html><head>' | |||
b'<meta property="og:title" content="~matrix~" />' | |||
b'<meta property="og:description" content="hi" />' | |||
b'</head></html>' | |||
) | |||
def make_homeserver(self, reactor, clock): | |||
@@ -39,6 +73,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): | |||
config = self.default_config() | |||
config.url_preview_enabled = True | |||
config.max_spider_size = 9999999 | |||
config.url_preview_ip_range_blacklist = IPSet( | |||
( | |||
"192.168.1.1", | |||
"1.0.0.0/8", | |||
"3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", | |||
"2001:800::/21", | |||
) | |||
) | |||
config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",)) | |||
config.url_preview_url_blacklist = [] | |||
config.media_store_path = self.storage_path | |||
@@ -62,63 +105,50 @@ class URLPreviewTests(unittest.HomeserverTestCase): | |||
def prepare(self, reactor, clock, hs): | |||
self.fetches = [] | |||
self.media_repo = hs.get_media_repository_resource() | |||
self.preview_url = self.media_repo.children[b'preview_url'] | |||
def get_file(url, output_stream, max_size): | |||
""" | |||
Returns tuple[int,dict,str,int] of file length, response headers, | |||
absolute URI, and response code. | |||
""" | |||
self.lookups = {} | |||
def write_to(r): | |||
data, response = r | |||
output_stream.write(data) | |||
return response | |||
class Resolver(object): | |||
def resolveHostName( | |||
_self, | |||
resolutionReceiver, | |||
hostName, | |||
portNumber=0, | |||
addressTypes=None, | |||
transportSemantics='TCP', | |||
): | |||
d = Deferred() | |||
d.addCallback(write_to) | |||
self.fetches.append((d, url)) | |||
return make_deferred_yieldable(d) | |||
resolution = HostResolution(hostName) | |||
resolutionReceiver.resolutionBegan(resolution) | |||
if hostName not in self.lookups: | |||
raise DNSLookupError("OH NO") | |||
client = Mock() | |||
client.get_file = get_file | |||
for i in self.lookups[hostName]: | |||
resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber)) | |||
resolutionReceiver.resolutionComplete() | |||
return resolutionReceiver | |||
self.media_repo = hs.get_media_repository_resource() | |||
preview_url = self.media_repo.children[b'preview_url'] | |||
preview_url.client = client | |||
self.preview_url = preview_url | |||
self.reactor.nameResolver = Resolver() | |||
def test_cache_returns_correct_type(self): | |||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=matrix.org", shorthand=False | |||
"GET", "url_preview?url=http://matrix.org", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# We've made one fetch | |||
self.assertEqual(len(self.fetches), 1) | |||
end_content = ( | |||
b'<html><head>' | |||
b'<meta property="og:title" content="~matrix~" />' | |||
b'<meta property="og:description" content="hi" />' | |||
b'</head></html>' | |||
) | |||
self.fetches[0][0].callback( | |||
( | |||
end_content, | |||
( | |||
len(end_content), | |||
{ | |||
b"Content-Length": [b"%d" % (len(end_content))], | |||
b"Content-Type": [b'text/html; charset="utf8"'], | |||
}, | |||
"https://example.com", | |||
200, | |||
), | |||
) | |||
client = self.reactor.tcpClients[0][2].buildProtocol(None) | |||
server = AccumulatingProtocol() | |||
server.makeConnection(FakeTransport(client, self.reactor)) | |||
client.makeConnection(FakeTransport(server, self.reactor)) | |||
client.dataReceived( | |||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" | |||
% (len(self.end_content),) | |||
+ self.end_content | |||
) | |||
self.pump() | |||
@@ -129,14 +159,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): | |||
# Check the cache returns the correct response | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=matrix.org", shorthand=False | |||
"GET", "url_preview?url=http://matrix.org", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# Only one fetch, still, since we'll lean on the cache | |||
self.assertEqual(len(self.fetches), 1) | |||
# Check the cache response has the same content | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual( | |||
@@ -144,20 +171,17 @@ class URLPreviewTests(unittest.HomeserverTestCase): | |||
) | |||
# Clear the in-memory cache | |||
self.assertIn("matrix.org", self.preview_url._cache) | |||
self.preview_url._cache.pop("matrix.org") | |||
self.assertNotIn("matrix.org", self.preview_url._cache) | |||
self.assertIn("http://matrix.org", self.preview_url._cache) | |||
self.preview_url._cache.pop("http://matrix.org") | |||
self.assertNotIn("http://matrix.org", self.preview_url._cache) | |||
# Check the database cache returns the correct response | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=matrix.org", shorthand=False | |||
"GET", "url_preview?url=http://matrix.org", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# Only one fetch, still, since we'll lean on the cache | |||
self.assertEqual(len(self.fetches), 1) | |||
# Check the cache response has the same content | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual( | |||
@@ -165,78 +189,282 @@ class URLPreviewTests(unittest.HomeserverTestCase): | |||
) | |||
def test_non_ascii_preview_httpequiv(self): | |||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] | |||
end_content = ( | |||
b'<html><head>' | |||
b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>' | |||
b'<meta property="og:title" content="\xe4\xea\xe0" />' | |||
b'<meta property="og:description" content="hi" />' | |||
b'</head></html>' | |||
) | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=matrix.org", shorthand=False | |||
"GET", "url_preview?url=http://matrix.org", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# We've made one fetch | |||
self.assertEqual(len(self.fetches), 1) | |||
client = self.reactor.tcpClients[0][2].buildProtocol(None) | |||
server = AccumulatingProtocol() | |||
server.makeConnection(FakeTransport(client, self.reactor)) | |||
client.makeConnection(FakeTransport(server, self.reactor)) | |||
client.dataReceived( | |||
( | |||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" | |||
b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n" | |||
) | |||
% (len(end_content),) | |||
+ end_content | |||
) | |||
self.pump() | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") | |||
def test_non_ascii_preview_content_type(self): | |||
self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")] | |||
end_content = ( | |||
b'<html><head>' | |||
b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>' | |||
b'<meta property="og:title" content="\xe4\xea\xe0" />' | |||
b'<meta property="og:description" content="hi" />' | |||
b'</head></html>' | |||
) | |||
self.fetches[0][0].callback( | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://matrix.org", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
client = self.reactor.tcpClients[0][2].buildProtocol(None) | |||
server = AccumulatingProtocol() | |||
server.makeConnection(FakeTransport(client, self.reactor)) | |||
client.makeConnection(FakeTransport(server, self.reactor)) | |||
client.dataReceived( | |||
( | |||
end_content, | |||
( | |||
len(end_content), | |||
{ | |||
b"Content-Length": [b"%d" % (len(end_content))], | |||
# This charset=utf-8 should be ignored, because the | |||
# document has a meta tag overriding it. | |||
b"Content-Type": [b'text/html; charset="utf8"'], | |||
}, | |||
"https://example.com", | |||
200, | |||
), | |||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" | |||
b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n" | |||
) | |||
% (len(end_content),) | |||
+ end_content | |||
) | |||
self.pump() | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") | |||
def test_non_ascii_preview_content_type(self): | |||
def test_ipaddr(self): | |||
""" | |||
IP addresses can be previewed directly. | |||
""" | |||
self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=matrix.org", shorthand=False | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# We've made one fetch | |||
self.assertEqual(len(self.fetches), 1) | |||
client = self.reactor.tcpClients[0][2].buildProtocol(None) | |||
server = AccumulatingProtocol() | |||
server.makeConnection(FakeTransport(client, self.reactor)) | |||
client.makeConnection(FakeTransport(server, self.reactor)) | |||
client.dataReceived( | |||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" | |||
% (len(self.end_content),) | |||
+ self.end_content | |||
) | |||
end_content = ( | |||
b'<html><head>' | |||
b'<meta property="og:title" content="\xe4\xea\xe0" />' | |||
b'<meta property="og:description" content="hi" />' | |||
b'</head></html>' | |||
self.pump() | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual( | |||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} | |||
) | |||
self.fetches[0][0].callback( | |||
( | |||
end_content, | |||
( | |||
len(end_content), | |||
{ | |||
b"Content-Length": [b"%d" % (len(end_content))], | |||
b"Content-Type": [b'text/html; charset="windows-1251"'], | |||
}, | |||
"https://example.com", | |||
200, | |||
), | |||
) | |||
def test_blacklisted_ip_specific(self): | |||
""" | |||
Blacklisted IP addresses, found via DNS, are not spidered. | |||
""" | |||
self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# No requests made. | |||
self.assertEqual(len(self.reactor.tcpClients), 0) | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ip_range(self): | |||
""" | |||
Blacklisted IP ranges, IPs found over DNS, are not spidered. | |||
""" | |||
self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ip_specific_direct(self): | |||
""" | |||
Blacklisted IP addresses, accessed directly, are not spidered. | |||
""" | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://192.168.1.1", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# No requests made. | |||
self.assertEqual(len(self.reactor.tcpClients), 0) | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ip_range_direct(self): | |||
""" | |||
Blacklisted IP ranges, accessed directly, are not spidered. | |||
""" | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://1.1.1.2", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ip_range_whitelisted_ip(self): | |||
""" | |||
Blacklisted but then subsequently whitelisted IP addresses can be | |||
spidered. | |||
""" | |||
self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
client = self.reactor.tcpClients[0][2].buildProtocol(None) | |||
server = AccumulatingProtocol() | |||
server.makeConnection(FakeTransport(client, self.reactor)) | |||
client.makeConnection(FakeTransport(server, self.reactor)) | |||
client.dataReceived( | |||
b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n" | |||
% (len(self.end_content),) | |||
+ self.end_content | |||
) | |||
self.pump() | |||
self.assertEqual(channel.code, 200) | |||
self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430") | |||
self.assertEqual( | |||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"} | |||
) | |||
def test_blacklisted_ip_with_external_ip(self): | |||
""" | |||
If a hostname resolves a blacklisted IP, even if there's a | |||
non-blacklisted one, it will be rejected. | |||
""" | |||
# Hardcode the URL resolving to the IP we want. | |||
self.lookups[u"example.com"] = [ | |||
(IPv4Address, "1.1.1.2"), | |||
(IPv4Address, "8.8.8.8"), | |||
] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ipv6_specific(self): | |||
""" | |||
Blacklisted IP addresses, found via DNS, are not spidered. | |||
""" | |||
self.lookups["example.com"] = [ | |||
(IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff") | |||
] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
# No requests made. | |||
self.assertEqual(len(self.reactor.tcpClients), 0) | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) | |||
def test_blacklisted_ipv6_range(self): | |||
""" | |||
Blacklisted IP ranges, IPs found over DNS, are not spidered. | |||
""" | |||
self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] | |||
request, channel = self.make_request( | |||
"GET", "url_preview?url=http://example.com", shorthand=False | |||
) | |||
request.render(self.preview_url) | |||
self.pump() | |||
self.assertEqual(channel.code, 403) | |||
self.assertEqual( | |||
channel.json_body, | |||
{ | |||
'errcode': 'M_UNKNOWN', | |||
'error': 'IP address blocked by IP blacklist entry', | |||
}, | |||
) |
@@ -383,8 +383,16 @@ class FakeTransport(object): | |||
self.disconnecting = True | |||
def pauseProducing(self): | |||
if not self.producer: | |||
return | |||
self.producer.pauseProducing() | |||
def resumeProducing(self): | |||
if not self.producer: | |||
return | |||
self.producer.resumeProducing() | |||
def unregisterProducer(self): | |||
if not self.producer: | |||
return | |||