You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

217 lines
7.3 KiB

  1. # Copyright 2020 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict
  15. from unittest.mock import Mock
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.handlers.cas import CasResponse
  18. from synapse.server import HomeServer
  19. from synapse.util import Clock
  20. from tests.test_utils import simple_async_mock
  21. from tests.unittest import HomeserverTestCase, override_config
  22. # These are a few constants that are used as config parameters in the tests.
  23. BASE_URL = "https://synapse/"
  24. SERVER_URL = "https://issuer/"
  25. class CasHandlerTestCase(HomeserverTestCase):
  26. def default_config(self) -> Dict[str, Any]:
  27. config = super().default_config()
  28. config["public_baseurl"] = BASE_URL
  29. cas_config = {
  30. "enabled": True,
  31. "server_url": SERVER_URL,
  32. "service_url": BASE_URL,
  33. }
  34. # Update this config with what's in the default config so that
  35. # override_config works as expected.
  36. cas_config.update(config.get("cas_config", {}))
  37. config["cas_config"] = cas_config
  38. return config
  39. def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
  40. hs = self.setup_test_homeserver()
  41. self.handler = hs.get_cas_handler()
  42. # Reduce the number of attempts when generating MXIDs.
  43. sso_handler = hs.get_sso_handler()
  44. sso_handler._MAP_USERNAME_RETRIES = 3
  45. return hs
  46. def test_map_cas_user_to_user(self) -> None:
  47. """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
  48. # stub out the auth handler
  49. auth_handler = self.hs.get_auth_handler()
  50. auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
  51. cas_response = CasResponse("test_user", {})
  52. request = _mock_request()
  53. self.get_success(
  54. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  55. )
  56. # check that the auth handler got called as expected
  57. auth_handler.complete_sso_login.assert_called_once_with(
  58. "@test_user:test",
  59. "cas",
  60. request,
  61. "redirect_uri",
  62. None,
  63. new_user=True,
  64. auth_provider_session_id=None,
  65. )
  66. def test_map_cas_user_to_existing_user(self) -> None:
  67. """Existing users can log in with CAS account."""
  68. store = self.hs.get_datastores().main
  69. self.get_success(
  70. store.register_user(user_id="@test_user:test", password_hash=None)
  71. )
  72. # stub out the auth handler
  73. auth_handler = self.hs.get_auth_handler()
  74. auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
  75. # Map a user via SSO.
  76. cas_response = CasResponse("test_user", {})
  77. request = _mock_request()
  78. self.get_success(
  79. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  80. )
  81. # check that the auth handler got called as expected
  82. auth_handler.complete_sso_login.assert_called_once_with(
  83. "@test_user:test",
  84. "cas",
  85. request,
  86. "redirect_uri",
  87. None,
  88. new_user=False,
  89. auth_provider_session_id=None,
  90. )
  91. # Subsequent calls should map to the same mxid.
  92. auth_handler.complete_sso_login.reset_mock()
  93. self.get_success(
  94. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  95. )
  96. auth_handler.complete_sso_login.assert_called_once_with(
  97. "@test_user:test",
  98. "cas",
  99. request,
  100. "redirect_uri",
  101. None,
  102. new_user=False,
  103. auth_provider_session_id=None,
  104. )
  105. def test_map_cas_user_to_invalid_localpart(self) -> None:
  106. """CAS automaps invalid characters to base-64 encoding."""
  107. # stub out the auth handler
  108. auth_handler = self.hs.get_auth_handler()
  109. auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
  110. cas_response = CasResponse("föö", {})
  111. request = _mock_request()
  112. self.get_success(
  113. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  114. )
  115. # check that the auth handler got called as expected
  116. auth_handler.complete_sso_login.assert_called_once_with(
  117. "@f=c3=b6=c3=b6:test",
  118. "cas",
  119. request,
  120. "redirect_uri",
  121. None,
  122. new_user=True,
  123. auth_provider_session_id=None,
  124. )
  125. @override_config(
  126. {
  127. "cas_config": {
  128. "required_attributes": {"userGroup": "staff", "department": None}
  129. }
  130. }
  131. )
  132. def test_required_attributes(self) -> None:
  133. """The required attributes must be met from the CAS response."""
  134. # stub out the auth handler
  135. auth_handler = self.hs.get_auth_handler()
  136. auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
  137. # The response doesn't have the proper userGroup or department.
  138. cas_response = CasResponse("test_user", {})
  139. request = _mock_request()
  140. self.get_success(
  141. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  142. )
  143. auth_handler.complete_sso_login.assert_not_called()
  144. # The response doesn't have any department.
  145. cas_response = CasResponse("test_user", {"userGroup": ["staff"]})
  146. request.reset_mock()
  147. self.get_success(
  148. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  149. )
  150. auth_handler.complete_sso_login.assert_not_called()
  151. # Add the proper attributes and it should succeed.
  152. cas_response = CasResponse(
  153. "test_user", {"userGroup": ["staff", "admin"], "department": ["sales"]}
  154. )
  155. request.reset_mock()
  156. self.get_success(
  157. self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
  158. )
  159. # check that the auth handler got called as expected
  160. auth_handler.complete_sso_login.assert_called_once_with(
  161. "@test_user:test",
  162. "cas",
  163. request,
  164. "redirect_uri",
  165. None,
  166. new_user=True,
  167. auth_provider_session_id=None,
  168. )
  169. def _mock_request() -> Mock:
  170. """Returns a mock which will stand in as a SynapseRequest"""
  171. mock = Mock(
  172. spec=[
  173. "finish",
  174. "getClientAddress",
  175. "getHeader",
  176. "setHeader",
  177. "setResponseCode",
  178. "write",
  179. ]
  180. )
  181. # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
  182. mock._disconnected = False
  183. return mock