# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import enum
import logging
import re
from collections import deque
from dataclasses import dataclass
from typing import (
    TYPE_CHECKING,
    Any,
    Collection,
    Iterable,
    List,
    Optional,
    Set,
    Tuple,
    Union,
)

import attr

from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import JsonDict

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)


@attr.s(slots=True, frozen=True, auto_attribs=True)
class SearchEntry:
    key: str
    value: str
    event_id: str
    room_id: str
    stream_ordering: Optional[int]
    origin_server_ts: int


def _clean_value_for_search(value: str) -> str:
    """
    Replaces any null code points in the string with spaces as
    Postgres and SQLite do not like the insertion of strings with
    null code points into the full-text search tables.
    """
    return value.replace("\u0000", " ")
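
# For illustration: the cleaning is purely textual, e.g.
#   _clean_value_for_search("foo\u0000bar") == "foo bar"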


class SearchWorkerStore(SQLBaseStore):
    def store_search_entries_txn(
        self, txn: LoggingTransaction, entries: Iterable[SearchEntry]
    ) -> None:
        """Add entries to the search table

        Args:
            txn:
            entries: entries to be added to the table
        """
        if not self.hs.config.server.enable_search:
            return
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
            INSERT INTO event_search
            (event_id, room_id, key, vector, stream_ordering, origin_server_ts)
            VALUES (?,?,?,to_tsvector('english', ?),?,?)
            """

            args1 = (
                (
                    entry.event_id,
                    entry.room_id,
                    entry.key,
                    _clean_value_for_search(entry.value),
                    entry.stream_ordering,
                    entry.origin_server_ts,
                )
                for entry in entries
            )

            txn.execute_batch(sql, args1)

        elif isinstance(self.database_engine, Sqlite3Engine):
            self.db_pool.simple_insert_many_txn(
                txn,
                table="event_search",
                keys=("event_id", "room_id", "key", "value"),
                values=(
                    (
                        entry.event_id,
                        entry.room_id,
                        entry.key,
                        _clean_value_for_search(entry.value),
                    )
                    for entry in entries
                ),
            )

        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")
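
# Illustrative usage sketch, with hypothetical values (not taken from this
# module): a caller already inside a database transaction could write
#
#   entry = SearchEntry(
#       key="content.body",
#       value="hello world",
#       event_id="$event:example.com",
#       room_id="!room:example.com",
#       stream_ordering=42,
#       origin_server_ts=1670000000000,
#   )
#   store.store_search_entries_txn(txn, [entry])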


class SearchBackgroundUpdateStore(SearchWorkerStore):

    EVENT_SEARCH_UPDATE_NAME = "event_search"
    EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
    EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
    EVENT_SEARCH_DELETE_NON_STRINGS = "event_search_sqlite_delete_non_strings"

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME, self._background_reindex_search_order
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
        )
        self.db_pool.updates.register_background_update_handler(
            self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
        )

    async def _background_reindex_search(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        # we work through the events table from highest stream id to lowest
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)

        TYPES = ["m.room.name", "m.room.message", "m.room.topic"]

        def reindex_search_txn(txn: LoggingTransaction) -> int:
            sql = """
            SELECT stream_ordering, event_id, room_id, type, json, origin_server_ts
            FROM events
            JOIN event_json USING (room_id, event_id)
            WHERE ? <= stream_ordering AND stream_ordering < ?
            AND (%s)
            ORDER BY stream_ordering DESC
            LIMIT ?
            """ % (
                " OR ".join("type = '%s'" % (t,) for t in TYPES),
            )

            txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))

            # we could stream straight from the results into
            # store_search_entries_txn with a generator function, but that
            # would mean having two cursors open on the database at once.
            # Instead we just build a list of results.
            rows = self.db_pool.cursor_to_dict(txn)
            if not rows:
                return 0

            min_stream_id = rows[-1]["stream_ordering"]

            event_search_rows = []
            for row in rows:
                try:
                    event_id = row["event_id"]
                    room_id = row["room_id"]
                    etype = row["type"]
                    stream_ordering = row["stream_ordering"]
                    origin_server_ts = row["origin_server_ts"]
                    try:
                        event_json = db_to_json(row["json"])
                        content = event_json["content"]
                    except Exception:
                        continue

                    if etype == "m.room.message":
                        key = "content.body"
                        value = content["body"]
                    elif etype == "m.room.topic":
                        key = "content.topic"
                        value = content["topic"]
                    elif etype == "m.room.name":
                        key = "content.name"
                        value = content["name"]
                    else:
                        raise Exception("unexpected event type %s" % etype)
                except (KeyError, AttributeError):
                    # If the event is missing a necessary field then
                    # skip over it.
                    continue

                if not isinstance(value, str):
                    # If the event body, name or topic isn't a string
                    # then skip over it
                    continue

                event_search_rows.append(
                    SearchEntry(
                        key=key,
                        value=value,
                        event_id=event_id,
                        room_id=room_id,
                        stream_ordering=stream_ordering,
                        origin_server_ts=origin_server_ts,
                    )
                )

            self.store_search_entries_txn(txn, event_search_rows)

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(event_search_rows),
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.EVENT_SEARCH_UPDATE_NAME, progress
            )

            return len(event_search_rows)

        if self.hs.config.server.enable_search:
            result = await self.db_pool.runInteraction(
                self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
            )
        else:
            # Don't index anything if search is not enabled.
            result = 0

        if not result:
            await self.db_pool.updates._end_background_update(
                self.EVENT_SEARCH_UPDATE_NAME
            )

        return result
    async def _background_reindex_gin_search(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """This handles old synapses which used GIST indexes, if any;
        converting them back to be GIN as per the actual schema.
        """

        def create_index(conn: LoggingDatabaseConnection) -> None:
            conn.rollback()

            # we have to set autocommit, because postgres refuses to
            # CREATE INDEX CONCURRENTLY without it.
            conn.set_session(autocommit=True)

            try:
                c = conn.cursor()

                # if we skipped the conversion to GIST, we may already/still
                # have an event_search_fts_idx; unfortunately postgres 9.4
                # doesn't support CREATE INDEX IF NOT EXISTS so we just catch
                # the exception and ignore it.
                import psycopg2

                try:
                    c.execute(
                        """
                        CREATE INDEX CONCURRENTLY event_search_fts_idx
                        ON event_search USING GIN (vector)
                        """
                    )
                except psycopg2.ProgrammingError as e:
                    logger.warning(
                        "Ignoring error %r when trying to switch from GIST to GIN", e
                    )

                # we should now be able to delete the GIST index.
                c.execute("DROP INDEX IF EXISTS event_search_fts_idx_gist")
            finally:
                conn.set_session(autocommit=False)

        if isinstance(self.database_engine, PostgresEngine):
            await self.db_pool.runWithConnection(create_index)

        await self.db_pool.updates._end_background_update(
            self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
        )
        return 1
    async def _background_reindex_search_order(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        target_min_stream_id = progress["target_min_stream_id_inclusive"]
        max_stream_id = progress["max_stream_id_exclusive"]
        rows_inserted = progress.get("rows_inserted", 0)
        have_added_index = progress["have_added_indexes"]

        if not have_added_index:

            def create_index(conn: LoggingDatabaseConnection) -> None:
                conn.rollback()
                conn.set_session(autocommit=True)
                c = conn.cursor()

                # We create with NULLS FIRST so that when we search *backwards*
                # we get the ones with non null origin_server_ts *first*
                c.execute(
                    """
                    CREATE INDEX CONCURRENTLY event_search_room_order
                    ON event_search(room_id, origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
                    """
                )
                c.execute(
                    """
                    CREATE INDEX CONCURRENTLY event_search_order
                    ON event_search(origin_server_ts NULLS FIRST, stream_ordering NULLS FIRST)
                    """
                )
                conn.set_session(autocommit=False)

            await self.db_pool.runWithConnection(create_index)

            pg = dict(progress)
            pg["have_added_indexes"] = True

            await self.db_pool.runInteraction(
                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                self.db_pool.updates._background_update_progress_txn,
                self.EVENT_SEARCH_ORDER_UPDATE_NAME,
                pg,
            )

        def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
            sql = """
            UPDATE event_search AS es
            SET stream_ordering = e.stream_ordering, origin_server_ts = e.origin_server_ts
            FROM events AS e
            WHERE e.event_id = es.event_id
            AND ? <= e.stream_ordering AND e.stream_ordering < ?
            RETURNING es.stream_ordering
            """

            min_stream_id = max_stream_id - batch_size
            txn.execute(sql, (min_stream_id, max_stream_id))
            rows = txn.fetchall()

            if min_stream_id < target_min_stream_id:
                # We've reached the end.
                return len(rows), False

            progress = {
                "target_min_stream_id_inclusive": target_min_stream_id,
                "max_stream_id_exclusive": min_stream_id,
                "rows_inserted": rows_inserted + len(rows),
                "have_added_indexes": True,
            }

            self.db_pool.updates._background_update_progress_txn(
                txn, self.EVENT_SEARCH_ORDER_UPDATE_NAME, progress
            )

            return len(rows), True

        num_rows, finished = await self.db_pool.runInteraction(
            self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
        )

        if not finished:
            await self.db_pool.updates._end_background_update(
                self.EVENT_SEARCH_ORDER_UPDATE_NAME
            )

        return num_rows
    async def _background_delete_non_strings(
        self, progress: JsonDict, batch_size: int
    ) -> int:
        """Deletes rows with non-string `value`s from `event_search` if using sqlite.

        Prior to Synapse 1.44.0, malformed events received over federation could
        cause integers to be inserted into the `event_search` table when using
        sqlite.
        """

        def delete_non_strings_txn(txn: LoggingTransaction) -> None:
            txn.execute("DELETE FROM event_search WHERE typeof(value) != 'text'")

        await self.db_pool.runInteraction(
            self.EVENT_SEARCH_DELETE_NON_STRINGS, delete_non_strings_txn
        )

        await self.db_pool.updates._end_background_update(
            self.EVENT_SEARCH_DELETE_NON_STRINGS
        )
        return 1


class SearchStore(SearchBackgroundUpdateStore):
    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

    async def search_msgs(
        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
    ) -> JsonDict:
        """Performs a full text search over events with given keys.

        Args:
            room_ids: List of room ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports
                "content.body", "content.name", "content.topic"

        Returns:
            Dictionary of results
        """
        clauses = []

        args: List[Any] = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "room_id", room_ids
            )
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses),))

        count_args = args
        count_clauses = clauses

        if isinstance(self.database_engine, PostgresEngine):
            search_query = search_term
            sql = """
            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) AS rank,
            room_id, event_id
            FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?)
            """
            args = [search_query, search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?)
            """
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            search_query = _parse_query_for_sqlite(search_term)

            sql = """
            SELECT rank(matchinfo(event_search)) as rank, room_id, event_id
            FROM event_search
            WHERE value MATCH ?
            """
            args = [search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE value MATCH ?
            """
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        for clause in clauses:
            sql += " AND " + clause
        for clause in count_clauses:
            count_sql += " AND " + clause

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        sql += " ORDER BY rank DESC LIMIT 500"

        results = await self.db_pool.execute(
            "search_msgs", self.db_pool.cursor_to_dict, sql, *args
        )

        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to block here to prevent redacted events being
        # returned in search results (which is a data leak)
        events = await self.get_events_as_list(  # type: ignore[attr-defined]
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.block,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = await self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = await self.db_pool.execute(
            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
        )

        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)

        return {
            "results": [
                {"event": event_map[r["event_id"]], "rank": r["rank"]}
                for r in results
                if r["event_id"] in event_map
            ],
            "highlights": highlights,
            "count": count,
        }
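
    # Illustrative call, using hypothetical room id and search term (not taken
    # from this module):
    #
    #   result = await store.search_msgs(
    #       ["!room:example.com"], "piano", ["content.body"]
    #   )
    #   # result["results"] is a rank-ordered list of
    #   # {"event": ..., "rank": ...} dictionaries.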
    async def search_rooms(
        self,
        room_ids: Collection[str],
        search_term: str,
        keys: Iterable[str],
        limit: int,
        pagination_token: Optional[str] = None,
    ) -> JsonDict:
        """Performs a full text search over events with given keys.

        Args:
            room_ids: The room_ids to search in
            search_term: Search term to search for
            keys: List of keys to search in, currently supports "content.body",
                "content.name", "content.topic"
            pagination_token: A pagination token previously returned

        Returns:
            Each match as a dictionary.
        """
        clauses = []

        args: List[Any] = []

        # Make sure we don't explode because the person is in too many rooms.
        # We filter the results below regardless.
        if len(room_ids) < 500:
            clause, args = make_in_list_sql_clause(
                self.database_engine, "room_id", room_ids
            )
            clauses = [clause]

        local_clauses = []
        for key in keys:
            local_clauses.append("key = ?")
            args.append(key)

        clauses.append("(%s)" % (" OR ".join(local_clauses),))

        # take copies of the current args and clauses lists, before adding
        # pagination clauses to main query.
        count_args = list(args)
        count_clauses = list(clauses)

        if pagination_token:
            try:
                origin_server_ts_str, stream_str = pagination_token.split(",")
                origin_server_ts = int(origin_server_ts_str)
                stream = int(stream_str)
            except Exception:
                raise SynapseError(400, "Invalid pagination token")
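
            # The token has the form "<origin_server_ts>,<stream_ordering>",
            # e.g. "1670000000000,12345", matching how it is built in the
            # return value below.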
            clauses.append(
                """
                (origin_server_ts < ? OR (origin_server_ts = ? AND stream_ordering < ?))
                """
            )
            args.extend([origin_server_ts, origin_server_ts, stream])

        if isinstance(self.database_engine, PostgresEngine):
            search_query = search_term
            sql = """
            SELECT ts_rank_cd(vector, websearch_to_tsquery('english', ?)) as rank,
            origin_server_ts, stream_ordering, room_id, event_id
            FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?) AND
            """
            args = [search_query, search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE vector @@ websearch_to_tsquery('english', ?) AND
            """
            count_args = [search_query] + count_args
        elif isinstance(self.database_engine, Sqlite3Engine):
            # We use CROSS JOIN here to ensure we use the right indexes.
            # https://sqlite.org/optoverview.html#crossjoin
            #
            # We want to use the full text search index on event_search to
            # extract all possible matches first, then lookup those matches
            # in the events table to get the topological ordering. We need
            # to use the indexes in this order because sqlite refuses to
            # MATCH unless it uses the full text search index
            sql = """
            SELECT
                rank(matchinfo) as rank, room_id, event_id, origin_server_ts, stream_ordering
            FROM (
                SELECT key, event_id, matchinfo(event_search) as matchinfo
                FROM event_search
                WHERE value MATCH ?
            )
            CROSS JOIN events USING (event_id)
            WHERE
            """
            search_query = _parse_query_for_sqlite(search_term)
            args = [search_query] + args

            count_sql = """
            SELECT room_id, count(*) as count FROM event_search
            WHERE value MATCH ? AND
            """
            count_args = [search_query] + count_args
        else:
            # This should be unreachable.
            raise Exception("Unrecognized database engine")

        sql += " AND ".join(clauses)
        count_sql += " AND ".join(count_clauses)

        # We add an arbitrary limit here to ensure we don't try to pull the
        # entire table from the database.
        if isinstance(self.database_engine, PostgresEngine):
            sql += """
            ORDER BY origin_server_ts DESC NULLS LAST, stream_ordering DESC NULLS LAST
            LIMIT ?
            """
        elif isinstance(self.database_engine, Sqlite3Engine):
            sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
        else:
            raise Exception("Unrecognized database engine")

        # mypy expects to append only a `str`, not an `int`
        args.append(limit)

        results = await self.db_pool.execute(
            "search_rooms", self.db_pool.cursor_to_dict, sql, *args
        )

        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to block here to prevent redacted events being
        # returned in search results (which is a data leak)
        events = await self.get_events_as_list(  # type: ignore[attr-defined]
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.block,
        )

        event_map = {ev.event_id: ev for ev in events}

        highlights = None
        if isinstance(self.database_engine, PostgresEngine):
            highlights = await self._find_highlights_in_postgres(search_query, events)

        count_sql += " GROUP BY room_id"

        count_results = await self.db_pool.execute(
            "search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
        )

        count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)

        return {
            "results": [
                {
                    "event": event_map[r["event_id"]],
                    "rank": r["rank"],
                    "pagination_token": "%s,%s"
                    % (r["origin_server_ts"], r["stream_ordering"]),
                }
                for r in results
                if r["event_id"] in event_map
            ],
            "highlights": highlights,
            "count": count,
        }
    async def _find_highlights_in_postgres(
        self, search_query: str, events: List[EventBase]
    ) -> Set[str]:
        """Given a list of events and a search term, return a list of words
        that match from the content of the event.

        This is used to give a list of words that clients can match against to
        highlight the matching parts.

        Args:
            search_query
            events: A list of events

        Returns:
            A set of strings.
        """

        def f(txn: LoggingTransaction) -> Set[str]:
            highlight_words = set()
            for event in events:
                # As a hack we simply join values of all possible keys. This is
                # fine since we're only using them to find possible highlights.
                values = []
                for key in ("body", "name", "topic"):
                    v = event.content.get(key, None)
                    if v:
                        v = _clean_value_for_search(v)
                        values.append(v)

                if not values:
                    continue

                value = " ".join(values)

                # We need to find some values for StartSel and StopSel that
                # aren't in the value so that we can pick results out.
                start_sel = "<"
                stop_sel = ">"

                while start_sel in value:
                    start_sel += "<"
                while stop_sel in value:
                    stop_sel += ">"
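
                # e.g. for a value of "a < b", start_sel grows to "<<" so the
                # ts_headline delimiters cannot collide with the text itself.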
                query = (
                    "SELECT ts_headline(?, websearch_to_tsquery('english', ?), %s)"
                    % (
                        _to_postgres_options(
                            {
                                "StartSel": start_sel,
                                "StopSel": stop_sel,
                                "MaxFragments": "50",
                            }
                        )
                    )
                )
                txn.execute(query, (value, search_query))
                (headline,) = txn.fetchall()[0]

                # Now we need to pick the possible highlights out of the
                # headline result.
                matcher_regex = "%s(.*?)%s" % (
                    re.escape(start_sel),
                    re.escape(stop_sel),
                )

                res = re.findall(matcher_regex, headline)
                highlight_words.update([r.lower() for r in res])

            return highlight_words

        return await self.db_pool.runInteraction("_find_highlights", f)


def _to_postgres_options(options_dict: JsonDict) -> str:
    return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
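
# e.g. _to_postgres_options({"StartSel": "<", "StopSel": ">"})
# returns "'StartSel=<,StopSel=>'" (a single-quoted options string).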


@dataclass
class Phrase:
    phrase: List[str]


class SearchToken(enum.Enum):
    Not = enum.auto()
    Or = enum.auto()
    And = enum.auto()


Token = Union[str, Phrase, SearchToken]
TokenList = List[Token]


def _is_stop_word(word: str) -> bool:
    # TODO Pull these out of the dictionary:
    #  https://github.com/postgres/postgres/blob/master/src/backend/snowball/stopwords/english.stop
    return word in {"the", "a", "you", "me", "and", "but"}


def _tokenize_query(query: str) -> TokenList:
    """
    Convert the user-supplied `query` into a TokenList, which can be translated into
    some DB-specific syntax.

    The following constructs are supported:

    - phrase queries using "double quotes"
    - case-insensitive `or` and `and` operators
    - negation of a keyword via a unary hyphen, e.g. 'include -exclude'

    The following differs from websearch_to_tsquery:

    - Stop words are not removed.
    - Unclosed phrases are treated differently.
    """
    tokens: TokenList = []

    # Find phrases.
    in_phrase = False
    parts = deque(query.split('"'))
    for i, part in enumerate(parts):
        # The contents inside double quotes is treated as a phrase.
        in_phrase = bool(i % 2)

        # Pull out the individual words, discarding any non-word characters.
        words = deque(re.findall(r"([\w\-]+)", part, re.UNICODE))

        # Phrases have simplified handling of words.
        if in_phrase:
            # Skip stop words.
            phrase = [word for word in words if not _is_stop_word(word)]

            # Consecutive words are implicitly ANDed together.
            if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
                tokens.append(SearchToken.And)

            # Add the phrase.
            tokens.append(Phrase(phrase))
            continue

        # Otherwise, not in a phrase.
        while words:
            word = words.popleft()

            if word.startswith("-"):
                tokens.append(SearchToken.Not)

                # If there's more to the word, put it back to be processed again.
                word = word[1:]
                if word:
                    words.appendleft(word)
            elif word.lower() == "or":
                tokens.append(SearchToken.Or)
            else:
                # Skip stop words.
                if _is_stop_word(word):
                    continue

                # Consecutive words are implicitly ANDed together.
                if tokens and tokens[-1] not in (SearchToken.Not, SearchToken.Or):
                    tokens.append(SearchToken.And)

                # Add the search term.
                tokens.append(word)

    return tokens
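
# Illustrative example of the tokenizer's output, traced by hand from the
# logic above (note that no And is inserted after a Not):
#
#   _tokenize_query('hello "goodbye world" -bye')
#   == ["hello", SearchToken.And, Phrase(["goodbye", "world"]),
#       SearchToken.Not, "bye"]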


def _tokens_to_sqlite_match_query(tokens: TokenList) -> str:
    """
    Convert the list of tokens to a string suitable for passing to sqlite's MATCH.
    Assume sqlite was compiled with enhanced query syntax.

    Ref: https://www.sqlite.org/fts3.html#full_text_index_queries
    """
    match_query = []
    for token in tokens:
        if isinstance(token, str):
            match_query.append(token)
        elif isinstance(token, Phrase):
            match_query.append('"' + " ".join(token.phrase) + '"')
        elif token == SearchToken.Not:
            # TODO: SQLite treats NOT as a *binary* operator. Hopefully a search
            # term has already been added before this.
            match_query.append(" NOT ")
        elif token == SearchToken.Or:
            match_query.append(" OR ")
        elif token == SearchToken.And:
            match_query.append(" AND ")
        else:
            raise ValueError(f"unknown token {token}")

    return "".join(match_query)
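
# Continuing the example above, the token list for 'hello "goodbye world" -bye'
# renders as the MATCH string:
#
#   'hello AND "goodbye world" NOT bye'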


def _parse_query_for_sqlite(search_term: str) -> str:
    """Takes a plain unicode string from the user and converts it into a form
    that can be passed to sqlite's MATCH operator.
    """
    return _tokens_to_sqlite_match_query(_tokenize_query(search_term))
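
# e.g. _parse_query_for_sqlite('hello "goodbye world"')
# == 'hello AND "goodbye world"'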