... otherwise, we don't get the cookie back.tags/v1.29.0rc1
@@ -0,0 +1 @@ | |||
Fix a bug in single sign-on which could cause a "No session cookie found" error. |
@@ -14,8 +14,9 @@ | |||
# See the License for the specific language governing permissions and | |||
# limitations under the License. | |||
import re | |||
from typing import Union | |||
from twisted.internet import task | |||
from twisted.internet import address, task | |||
from twisted.web.client import FileBodyProducer | |||
from twisted.web.iweb import IRequest | |||
@@ -53,6 +54,40 @@ class QuieterFileBodyProducer(FileBodyProducer): | |||
pass | |||
def get_request_uri(request: IRequest) -> bytes: | |||
"""Return the full URI that was requested by the client""" | |||
return b"%s://%s%s" % ( | |||
b"https" if request.isSecure() else b"http", | |||
_get_requested_host(request), | |||
# despite its name, "request.uri" is only the path and query-string. | |||
request.uri, | |||
) | |||
def _get_requested_host(request: IRequest) -> bytes: | |||
hostname = request.getHeader(b"host") | |||
if hostname: | |||
return hostname | |||
# no Host header, use the address/port that the request arrived on | |||
host = request.getHost() # type: Union[address.IPv4Address, address.IPv6Address] | |||
hostname = host.host.encode("ascii") | |||
if request.isSecure() and host.port == 443: | |||
# default port for https | |||
return hostname | |||
if not request.isSecure() and host.port == 80: | |||
# default port for http | |||
return hostname | |||
return b"%s:%i" % ( | |||
hostname, | |||
host.port, | |||
) | |||
def get_request_user_agent(request: IRequest, default: str = "") -> str: | |||
"""Return the last User-Agent header, or the given default.""" | |||
# There could be raw utf-8 bytes in the User-Agent header. | |||
@@ -20,6 +20,7 @@ from synapse.api.errors import Codes, LoginError, SynapseError | |||
from synapse.api.ratelimiting import Ratelimiter | |||
from synapse.appservice import ApplicationService | |||
from synapse.handlers.sso import SsoIdentityProvider | |||
from synapse.http import get_request_uri | |||
from synapse.http.server import HttpServer, finish_request | |||
from synapse.http.servlet import ( | |||
RestServlet, | |||
@@ -354,6 +355,7 @@ class SsoRedirectServlet(RestServlet): | |||
hs.get_oidc_handler() | |||
self._sso_handler = hs.get_sso_handler() | |||
self._msc2858_enabled = hs.config.experimental.msc2858_enabled | |||
self._public_baseurl = hs.config.public_baseurl | |||
def register(self, http_server: HttpServer) -> None: | |||
super().register(http_server) | |||
@@ -373,6 +375,32 @@ class SsoRedirectServlet(RestServlet): | |||
async def on_GET( | |||
self, request: SynapseRequest, idp_id: Optional[str] = None | |||
) -> None: | |||
if not self._public_baseurl: | |||
raise SynapseError(400, "SSO requires a valid public_baseurl") | |||
# if this isn't the expected hostname, redirect to the right one, so that we | |||
# get our cookies back. | |||
requested_uri = get_request_uri(request) | |||
baseurl_bytes = self._public_baseurl.encode("utf-8") | |||
if not requested_uri.startswith(baseurl_bytes): | |||
# swap out the incorrect base URL for the right one. | |||
# | |||
# The idea here is to redirect from | |||
# https://foo.bar/whatever/_matrix/... | |||
# to | |||
# https://public.baseurl/_matrix/... | |||
# | |||
i = requested_uri.index(b"/_matrix") | |||
new_uri = baseurl_bytes[:-1] + requested_uri[i:] | |||
logger.info( | |||
"Requested URI %s is not canonical: redirecting to %s", | |||
requested_uri.decode("utf-8", errors="replace"), | |||
new_uri.decode("utf-8", errors="replace"), | |||
) | |||
request.redirect(new_uri) | |||
finish_request(request) | |||
return | |||
client_redirect_url = parse_string( | |||
request, "redirectUrl", required=True, encoding=None | |||
) | |||
@@ -15,7 +15,7 @@ | |||
import time | |||
import urllib.parse | |||
from typing import Any, Dict, List, Union | |||
from typing import Any, Dict, List, Optional, Union | |||
from urllib.parse import urlencode | |||
from mock import Mock | |||
@@ -47,8 +47,14 @@ except ImportError: | |||
HAS_JWT = False | |||
# public_base_url used in some tests | |||
BASE_URL = "https://synapse/" | |||
# synapse server name: used to populate public_baseurl in some tests | |||
SYNAPSE_SERVER_PUBLIC_HOSTNAME = "synapse" | |||
# public_baseurl for some tests. It uses an http:// scheme because | |||
# FakeChannel.isSecure() returns False, so synapse will see the requested uri as | |||
# http://..., so using http in the public_baseurl stops Synapse trying to redirect to | |||
# https://.... | |||
BASE_URL = "http://%s/" % (SYNAPSE_SERVER_PUBLIC_HOSTNAME,) | |||
# CAS server used in some tests | |||
CAS_SERVER = "https://fake.test" | |||
@@ -480,11 +486,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
def test_multi_sso_redirect(self): | |||
"""/login/sso/redirect should redirect to an identity picker""" | |||
# first hit the redirect url, which should redirect to our idp picker | |||
channel = self.make_request( | |||
"GET", | |||
"/_matrix/client/r0/login/sso/redirect?redirectUrl=" | |||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), | |||
) | |||
channel = self._make_sso_redirect_request(False, None) | |||
self.assertEqual(channel.code, 302, channel.result) | |||
uri = channel.headers.getRawHeaders("Location")[0] | |||
@@ -628,34 +630,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
def test_client_idp_redirect_msc2858_disabled(self): | |||
"""If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" | |||
channel = self.make_request( | |||
"GET", | |||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" | |||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), | |||
) | |||
channel = self._make_sso_redirect_request(True, "oidc") | |||
self.assertEqual(channel.code, 400, channel.result) | |||
self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") | |||
@override_config({"experimental_features": {"msc2858_enabled": True}}) | |||
def test_client_idp_redirect_to_unknown(self): | |||
"""If the client tries to pick an unknown IdP, return a 404""" | |||
channel = self.make_request( | |||
"GET", | |||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" | |||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), | |||
) | |||
channel = self._make_sso_redirect_request(True, "xxx") | |||
self.assertEqual(channel.code, 404, channel.result) | |||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") | |||
@override_config({"experimental_features": {"msc2858_enabled": True}}) | |||
def test_client_idp_redirect_to_oidc(self): | |||
"""If the client pick a known IdP, redirect to it""" | |||
channel = self.make_request( | |||
"GET", | |||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" | |||
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), | |||
) | |||
channel = self._make_sso_redirect_request(True, "oidc") | |||
self.assertEqual(channel.code, 302, channel.result) | |||
oidc_uri = channel.headers.getRawHeaders("Location")[0] | |||
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) | |||
@@ -663,6 +652,30 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): | |||
# it should redirect us to the auth page of the OIDC server | |||
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) | |||
def _make_sso_redirect_request( | |||
self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None | |||
): | |||
"""Send a request to /_matrix/client/r0/login/sso/redirect | |||
... or the unstable equivalent | |||
... possibly specifying an IDP provider | |||
""" | |||
endpoint = ( | |||
"/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect" | |||
if unstable_endpoint | |||
else "/_matrix/client/r0/login/sso/redirect" | |||
) | |||
if idp_prov is not None: | |||
endpoint += "/" + idp_prov | |||
endpoint += "?redirectUrl=" + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) | |||
return self.make_request( | |||
"GET", | |||
endpoint, | |||
custom_headers=[("Host", SYNAPSE_SERVER_PUBLIC_HOSTNAME)], | |||
) | |||
@staticmethod | |||
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: | |||
prefix = key + " = " | |||
@@ -542,13 +542,30 @@ class RestHelper: | |||
if client_redirect_url: | |||
params["redirectUrl"] = client_redirect_url | |||
# hit the redirect url (which will issue a cookie and state) | |||
# hit the redirect url (which should redirect back to the redirect url. This | |||
# is the easiest way of figuring out what the Host header ought to be set to | |||
# to keep Synapse happy. | |||
channel = make_request( | |||
self.hs.get_reactor(), | |||
self.site, | |||
"GET", | |||
"/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params), | |||
) | |||
assert channel.code == 302 | |||
# hit the redirect url again with the right Host header, which should now issue | |||
# a cookie and redirect to the SSO provider. | |||
location = channel.headers.getRawHeaders("Location")[0] | |||
parts = urllib.parse.urlsplit(location) | |||
channel = make_request( | |||
self.hs.get_reactor(), | |||
self.site, | |||
"GET", | |||
urllib.parse.urlunsplit(("", "") + parts[2:]), | |||
custom_headers=[ | |||
("Host", parts[1]), | |||
], | |||
) | |||
assert channel.code == 302 | |||
channel.extract_cookies(cookies) | |||
@@ -161,7 +161,11 @@ class UIAuthTests(unittest.HomeserverTestCase): | |||
def default_config(self): | |||
config = super().default_config() | |||
config["public_baseurl"] = "https://synapse.test" | |||
# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns | |||
# False, so synapse will see the requested uri as http://..., so using http in | |||
# the public_baseurl stops Synapse trying to redirect to https. | |||
config["public_baseurl"] = "http://synapse.test" | |||
if HAS_OIDC: | |||
# we enable OIDC as a way of testing SSO flows | |||
@@ -124,7 +124,11 @@ class FakeChannel: | |||
return address.IPv4Address("TCP", self._ip, 3423) | |||
def getHost(self): | |||
return None | |||
# this is called by Request.__init__ to configure Request.host. | |||
return address.IPv4Address("TCP", "127.0.0.1", 8888) | |||
def isSecure(self): | |||
return False | |||
@property | |||
def transport(self): | |||