Add CQL support for Confluence connector (#2679)

* Added CQL support for Confluence

* Changed string substitutions for CQL

* Final cleanup

* Updated string fixes

* Removed print statements

* Updated description
hagen-danswer
2024-10-10 12:16:56 -07:00
committed by GitHub
parent 101b010c5c
commit 1f4fe42f4b
5 changed files with 339 additions and 198 deletions
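
For context, a minimal usage sketch of the new parameter (hedged: the module path and credential values are illustrative assumptions; the constructor and load_credentials arguments are taken from the diff below):

from danswer.connectors.confluence.connector import ConfluenceConnector

# Index only pages matching a caller-supplied CQL query. Per the diff below,
# a non-empty cql_query forces page_id to "" and disables recursive indexing.
connector = ConfluenceConnector(
    wiki_base="https://example.atlassian.net/wiki",  # illustrative
    is_cloud=True,
    cql_query="type=page and space=ENG and label='runbook'",
)
connector.load_credentials(
    {
        "confluence_username": "bot@example.com",  # placeholder
        "confluence_access_token": "<api-token>",  # placeholder
    }
)
for doc_batch in connector.poll_source(start=1728518400, end=1728604800):
    ...  # poll_source narrows the CQL query to this time window (see below)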


@@ -1,5 +1,6 @@
import io
import os
+import re
from collections.abc import Callable
from collections.abc import Collection
from datetime import datetime
@@ -56,8 +57,101 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
)
+class DanswerConfluence(Confluence):
+    """
+    This is a custom Confluence class that overrides the default Confluence
+    class to add a custom CQL method. This is necessary because the default
+    Confluence class does not properly support CQL expansions.
+    """
+
+    def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
+        super(DanswerConfluence, self).__init__(url, *args, **kwargs)
+
+    def danswer_cql(
+        self,
+        cql: str,
+        expand: str | None = None,
+        start: int = 0,
+        limit: int = 500,
+        include_archived_spaces: bool = False,
+    ) -> list[dict[str, Any]]:
+        # Performs the query expansion and start/limit URL additions
+        url_suffix = f"rest/api/content/search?cql={cql}"
+        if expand:
+            url_suffix += f"&expand={expand}"
+        url_suffix += f"&start={start}&limit={limit}"
+        if include_archived_spaces:
+            url_suffix += "&includeArchivedSpaces=true"
+        try:
+            response = self.get(url_suffix)
+            return response.get("results", [])
+        except Exception as e:
+            raise e
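
A minimal sketch of the URL suffix danswer_cql assembles, mirroring the logic above (illustrative values; note the CQL string is interpolated into the query string as-is):

cql = "type=page and space=ENG"
expand = "body.storage.value,version"
url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
    url_suffix += f"&expand={expand}"
url_suffix += f"&start={0}&limit={50}"
assert url_suffix == (
    "rest/api/content/search?cql=type=page and space=ENG"
    "&expand=body.storage.value,version&start=0&limit=50"
)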
+def _replace_cql_time_filter(
+    cql_query: str, start_time: datetime, end_time: datetime
+) -> str:
+    """
+    This function replaces the lastmodified filter in the CQL query with the
+    start and end times, selecting the more restrictive time range on each side.
+    """
+    # Extract existing lastmodified >= and <= filters
+    existing_start_match = re.search(
+        r'lastmodified\s*>=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+    existing_end_match = re.search(
+        r'lastmodified\s*<=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+    # Remove all existing lastmodified and updated filters
+    cql_query = re.sub(
+        r'\s*AND\s+(lastmodified|updated)\s*[<>=]+\s*["\']?[\d-]+(?:\s+[\d:]+)?["\']?',
+        "",
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+    # Determine the start time to use
+    if existing_start_match:
+        existing_start_str = existing_start_match.group(1)
+        existing_start = datetime.strptime(
+            existing_start_str,
+            "%Y-%m-%d %H:%M" if " " in existing_start_str else "%Y-%m-%d",
+        )
+        existing_start = existing_start.replace(
+            tzinfo=timezone.utc
+        )  # Make offset-aware
+        start_time_to_use = max(start_time.astimezone(timezone.utc), existing_start)
+    else:
+        start_time_to_use = start_time.astimezone(timezone.utc)
+    # Determine the end time to use
+    if existing_end_match:
+        existing_end_str = existing_end_match.group(1)
+        existing_end = datetime.strptime(
+            existing_end_str,
+            "%Y-%m-%d %H:%M" if " " in existing_end_str else "%Y-%m-%d",
+        )
+        existing_end = existing_end.replace(tzinfo=timezone.utc)  # Make offset-aware
+        end_time_to_use = min(end_time.astimezone(timezone.utc), existing_end)
+    else:
+        end_time_to_use = end_time.astimezone(timezone.utc)
+    # Add new time filters
+    cql_query += (
+        f" and lastmodified >= '{start_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
+    )
+    cql_query += f" and lastmodified <= '{end_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
+    return cql_query.strip()
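
A hedged worked example of the narrowing behaviour (illustrative values): an existing start bound that is older than the poll window is discarded in favour of the window's start, and both bounds are re-appended in UTC at minute precision:

from datetime import datetime, timezone

narrowed = _replace_cql_time_filter(
    "type=page and space=ENG and lastmodified >= '2024-01-01'",
    start_time=datetime(2024, 6, 1, tzinfo=timezone.utc),
    end_time=datetime(2024, 10, 1, tzinfo=timezone.utc),
)
assert narrowed == (
    "type=page and space=ENG"
    " and lastmodified >= '2024-06-01 00:00'"
    " and lastmodified <= '2024-10-01 00:00'"
)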
@lru_cache()
-def _get_user(user_id: str, confluence_client: Confluence) -> str:
+def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -81,7 +175,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
return user_not_found
-def parse_html_page(text: str, confluence_client: Confluence) -> str:
+def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -112,7 +206,7 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
-    confluence_client: Confluence,
+    confluence_client: DanswerConfluence,
) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
@@ -159,7 +253,7 @@ class RecursiveIndexer:
def __init__(
self,
batch_size: int,
-        confluence_client: Confluence,
+        confluence_client: DanswerConfluence,
index_recursively: bool,
origin_page_id: str,
) -> None:
@@ -285,8 +379,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_base: str,
-        space: str,
is_cloud: bool,
+        space: str = "",
page_id: str = "",
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
@@ -295,35 +389,44 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
+        cql_query: str | None = None,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
-        self.index_recursively = index_recursively
+        self.index_recursively = False if cql_query else index_recursively
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.space = space
-        self.page_id = page_id
+        self.page_id = "" if cql_query else page_id
+        self.space_level_scan = bool(not self.page_id)
self.is_cloud = is_cloud
-        self.space_level_scan = False
-        self.confluence_client: Confluence | None = None
+        self.confluence_client: DanswerConfluence | None = None
-        if self.page_id is None or self.page_id == "":
-            self.space_level_scan = True
+        # if a cql_query is provided, we will use it to fetch the pages
+        # if no cql_query is provided, we will use the space to fetch the pages
+        # if no space is provided, we will default to fetching all pages, regardless of space
+        if cql_query:
+            self.cql_query = cql_query
+        elif self.space:
+            self.cql_query = f"type=page and space={self.space}"
+        else:
+            self.cql_query = "type=page"
logger.info(
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+ f" cql_query: {self.cql_query}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
-        self.confluence_client = Confluence(
+        self.confluence_client = DanswerConfluence(
url=self.wiki_base,
# passing in username causes issues for Confluence data center
username=username if self.is_cloud else None,
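
The cql_query fallback chain above resolves, for example, as follows (a sketch with illustrative arguments):

# ConfluenceConnector(wiki_base, is_cloud=True, cql_query="type=page and label='runbook'")
#     -> self.cql_query == "type=page and label='runbook'"
# ConfluenceConnector(wiki_base, is_cloud=True, space="ENG")
#     -> self.cql_query == "type=page and space=ENG"
# ConfluenceConnector(wiki_base, is_cloud=True)
#     -> self.cql_query == "type=page"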
@@ -334,26 +437,33 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def _fetch_pages(
self,
-        confluence_client: Confluence,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
-            get_all_pages_from_space = make_confluence_call_handle_rate_limit(
-                confluence_client.get_all_pages_from_space
+            if self.confluence_client is None:
+                raise ConnectorMissingCredentialError("Confluence")
+            get_all_pages = make_confluence_call_handle_rate_limit(
+                self.confluence_client.danswer_cql
)
+            include_archived_spaces = (
+                CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
+                if not self.is_cloud
+                else False
+            )
try:
-                return get_all_pages_from_space(
-                    self.space,
+                return get_all_pages(
+                    cql=self.cql_query,
start=start_ind,
limit=batch_size,
-                    status=(
-                        None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
-                    ),
expand="body.storage.value,version",
+                    include_archived_spaces=include_archived_spaces,
)
except Exception:
logger.warning(
f"Batch failed with space {self.space} at offset {start_ind} "
f"Batch failed with cql {self.cql_query} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
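
With the default space-derived query, the batch fetch above amounts to roughly this call (a sketch; confluence_client is assumed to be an authenticated DanswerConfluence instance and the values are illustrative):

pages = confluence_client.danswer_cql(
    cql="type=page and space=ENG",
    expand="body.storage.value,version",
    start=0,
    limit=50,
    include_archived_spaces=False,  # forced to False on cloud instances above
)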
@@ -363,27 +473,23 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# Could be that one of the pages here failed due to this bug:
# https://jira.atlassian.com/browse/CONFCLOUD-76433
view_pages.extend(
-                        get_all_pages_from_space(
-                            self.space,
+                        get_all_pages(
+                            cql=self.cql_query,
start=start_ind + i,
limit=1,
-                            status=(
-                                None
-                                if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
-                                else "current"
-                            ),
expand="body.storage.value,version",
+                            include_archived_spaces=include_archived_spaces,
)
)
except HTTPError as e:
logger.warning(
f"Page failed with space {self.space} at offset {start_ind + i}, "
f"Page failed with cql {self.cql_query} at offset {start_ind + i}, "
f"trying alternative expand option: {e}"
)
# Use view instead, which captures most info but is less complete
view_pages.extend(
-                        get_all_pages_from_space(
-                            self.space,
+                        get_all_pages(
+                            cql=self.cql_query,
start=start_ind + i,
limit=1,
expand="body.view.value,version",
@@ -393,6 +499,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return view_pages
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
+            if self.confluence_client is None:
+                raise ConnectorMissingCredentialError("Confluence")
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
@@ -421,7 +530,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
raise e
# error checking phase, only reachable if `self.continue_on_failure=True`
-        for i in range(self.batch_size):
+        for _ in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
@@ -437,7 +546,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return pages
-    def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
+    def _fetch_comments(
+        self, confluence_client: DanswerConfluence, page_id: str
+    ) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
)
@@ -463,7 +574,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return ""
-    def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
+    def _fetch_labels(
+        self, confluence_client: DanswerConfluence, page_id: str
+    ) -> list[str]:
get_page_labels = make_confluence_call_handle_rate_limit(
confluence_client.get_page_labels
)
@@ -577,22 +690,20 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return "\n".join(files_attachment_content), unused_attachments
def _get_doc_batch(
-        self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
+        self, start_ind: int
) -> tuple[list[Document], list[dict[str, Any]], int]:
+        if self.confluence_client is None:
+            raise ConnectorMissingCredentialError("Confluence")
doc_batch: list[Document] = []
unused_attachments: list[dict[str, Any]] = []
-        if self.confluence_client is None:
-            raise ConnectorMissingCredentialError("Confluence")
-        batch = self._fetch_pages(self.confluence_client, start_ind)
+        batch = self._fetch_pages(start_ind)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
author = cast(str | None, page["version"].get("by", {}).get("email"))
-            if time_filter and not time_filter(last_modified):
-                continue
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -715,17 +826,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return doc_batch, end_ind - start_ind
def load_from_state(self) -> GenerateDocumentsOutput:
-        unused_attachments = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
+        unused_attachments: list[dict[str, Any]] = []
start_ind = 0
while True:
-            doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
-                start_ind
-            )
-            unused_attachments.extend(unused_attachments_batch)
+            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
+            unused_attachments.extend(unused_attachments)
start_ind += num_pages
if doc_batch:
yield doc_batch
@@ -748,7 +854,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
-        unused_attachments = []
+        unused_attachments: list[dict[str, Any]] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -756,12 +862,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
+        self.cql_query = _replace_cql_time_filter(self.cql_query, start_time, end_time)
start_ind = 0
while True:
-            doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
-                start_ind, time_filter=lambda t: start_time <= t <= end_time
-            )
-            unused_attachments.extend(unused_attachments_batch)
+            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
+            unused_attachments.extend(unused_attachments)
start_ind += num_pages
if doc_batch: