mirror of https://github.com/danswer-ai/danswer.git
Add CQL support for Confluence connector (#2679)
* Added CQL support for Confluence
* changed string substitutions for CQL
* final cleanup
* updated string fixes
* remove print statements
* Update description
@@ -1,5 +1,6 @@
 import io
 import os
+import re
 from collections.abc import Callable
 from collections.abc import Collection
 from datetime import datetime
@@ -56,8 +57,101 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
 )


+class DanswerConfluence(Confluence):
+    """
+    This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
+    This is necessary because the default Confluence class does not properly support cql expansions.
+    """
+
+    def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
+        super(DanswerConfluence, self).__init__(url, *args, **kwargs)
+
+    def danswer_cql(
+        self,
+        cql: str,
+        expand: str | None = None,
+        start: int = 0,
+        limit: int = 500,
+        include_archived_spaces: bool = False,
+    ) -> list[dict[str, Any]]:
+        # Performs the query expansion and start/limit url additions
+        url_suffix = f"rest/api/content/search?cql={cql}"
+        if expand:
+            url_suffix += f"&expand={expand}"
+        url_suffix += f"&start={start}&limit={limit}"
+        if include_archived_spaces:
+            url_suffix += "&includeArchivedSpaces=true"
+        try:
+            response = self.get(url_suffix)
+            return response.get("results", [])
+        except Exception as e:
+            raise e
+
+
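As a sanity check on the method above, a minimal usage sketch. The module path, URL, and credentials are placeholders, not from this diff; `danswer_cql` simply builds the `rest/api/content/search` URL as shown and returns the `results` list.

    from danswer.connectors.confluence.connector import DanswerConfluence  # assumed module path

    # Placeholder URL and credentials, for illustration only
    client = DanswerConfluence(
        url="https://example.atlassian.net/wiki",
        username="user@example.com",
        password="api-token",
    )

    # Fetch the first 50 pages of a space, expanding body and version metadata;
    # this issues GET rest/api/content/search?cql=...&expand=...&start=0&limit=50
    pages = client.danswer_cql(
        cql="type=page and space=ENG",
        expand="body.storage.value,version",
        start=0,
        limit=50,
    )
    for page in pages:
        print(page["id"], page["version"]["when"])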
+def _replace_cql_time_filter(
+    cql_query: str, start_time: datetime, end_time: datetime
+) -> str:
+    """
+    This function replaces the lastmodified filter in the CQL query with the start and end times.
+    This selects the more restrictive time range.
+    """
+    # Extract existing lastmodified >= and <= filters
+    existing_start_match = re.search(
+        r'lastmodified\s*>=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+    existing_end_match = re.search(
+        r'lastmodified\s*<=\s*["\']?(\d{4}-\d{2}-\d{2}(?:\s+\d{2}:\d{2})?)["\']?',
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+
+    # Remove all existing lastmodified and updated filters
+    cql_query = re.sub(
+        r'\s*AND\s+(lastmodified|updated)\s*[<>=]+\s*["\']?[\d-]+(?:\s+[\d:]+)?["\']?',
+        "",
+        cql_query,
+        flags=re.IGNORECASE,
+    )
+
+    # Determine the start time to use
+    if existing_start_match:
+        existing_start_str = existing_start_match.group(1)
+        existing_start = datetime.strptime(
+            existing_start_str,
+            "%Y-%m-%d %H:%M" if " " in existing_start_str else "%Y-%m-%d",
+        )
+        existing_start = existing_start.replace(
+            tzinfo=timezone.utc
+        )  # Make offset-aware
+        start_time_to_use = max(start_time.astimezone(timezone.utc), existing_start)
+    else:
+        start_time_to_use = start_time.astimezone(timezone.utc)
+
+    # Determine the end time to use
+    if existing_end_match:
+        existing_end_str = existing_end_match.group(1)
+        existing_end = datetime.strptime(
+            existing_end_str,
+            "%Y-%m-%d %H:%M" if " " in existing_end_str else "%Y-%m-%d",
+        )
+        existing_end = existing_end.replace(tzinfo=timezone.utc)  # Make offset-aware
+        end_time_to_use = min(end_time.astimezone(timezone.utc), existing_end)
+    else:
+        end_time_to_use = end_time.astimezone(timezone.utc)
+
+    # Add new time filters
+    cql_query += (
+        f" and lastmodified >= '{start_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
+    )
+    cql_query += f" and lastmodified <= '{end_time_to_use.strftime('%Y-%m-%d %H:%M')}'"
+
+    return cql_query.strip()
+
+
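To make the "more restrictive time range" rule concrete, a worked example against the function above; the query string and dates are illustrative:

    from datetime import datetime, timezone

    # A stored query that already carries a lastmodified lower bound
    query = "type=page and space=ENG AND lastmodified >= '2024-01-15'"

    start = datetime(2024, 1, 1, tzinfo=timezone.utc)
    end = datetime(2024, 2, 1, tzinfo=timezone.utc)

    # The existing bound (2024-01-15) is later than the requested start
    # (2024-01-01), so the more restrictive bound wins; the end bound is new.
    print(_replace_cql_time_filter(query, start, end))
    # type=page and space=ENG and lastmodified >= '2024-01-15 00:00' and lastmodified <= '2024-02-01 00:00'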
 @lru_cache()
-def _get_user(user_id: str, confluence_client: Confluence) -> str:
+def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
     """Get Confluence Display Name based on the account-id or userkey value

     Args:
@@ -81,7 +175,7 @@ def _get_user(user_id: str, confluence_client: Confluence) -> str:
     return user_not_found


-def parse_html_page(text: str, confluence_client: Confluence) -> str:
+def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
     """Parse a Confluence html page and replace the 'user Id' by the real
     User Display Name

@@ -112,7 +206,7 @@ def parse_html_page(text: str, confluence_client: Confluence) -> str:
 def _comment_dfs(
     comments_str: str,
     comment_pages: Collection[dict[str, Any]],
-    confluence_client: Confluence,
+    confluence_client: DanswerConfluence,
 ) -> str:
     get_page_child_by_type = make_confluence_call_handle_rate_limit(
         confluence_client.get_page_child_by_type
@@ -159,7 +253,7 @@ class RecursiveIndexer:
     def __init__(
         self,
         batch_size: int,
-        confluence_client: Confluence,
+        confluence_client: DanswerConfluence,
         index_recursively: bool,
         origin_page_id: str,
     ) -> None:
@@ -285,8 +379,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
     def __init__(
         self,
         wiki_base: str,
-        space: str,
         is_cloud: bool,
+        space: str = "",
         page_id: str = "",
         index_recursively: bool = True,
         batch_size: int = INDEX_BATCH_SIZE,
@@ -295,35 +389,44 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         # skip it. This is generally used to avoid indexing extra sensitive
         # pages.
         labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
+        cql_query: str | None = None,
     ) -> None:
         self.batch_size = batch_size
         self.continue_on_failure = continue_on_failure
         self.labels_to_skip = set(labels_to_skip)
         self.recursive_indexer: RecursiveIndexer | None = None
-        self.index_recursively = index_recursively
+        self.index_recursively = False if cql_query else index_recursively

         # Remove trailing slash from wiki_base if present
         self.wiki_base = wiki_base.rstrip("/")
         self.space = space
-        self.page_id = page_id
+        self.page_id = "" if cql_query else page_id
+        self.space_level_scan = bool(not self.page_id)

         self.is_cloud = is_cloud

-        self.space_level_scan = False
-        self.confluence_client: Confluence | None = None
+        self.confluence_client: DanswerConfluence | None = None

-        if self.page_id is None or self.page_id == "":
-            self.space_level_scan = True
+        # if a cql_query is provided, we will use it to fetch the pages
+        # if no cql_query is provided, we will use the space to fetch the pages
+        # if no space is provided, we will default to fetching all pages, regardless of space
+        if cql_query:
+            self.cql_query = cql_query
+        elif self.space:
+            self.cql_query = f"type=page and space={self.space}"
+        else:
+            self.cql_query = "type=page"

         logger.info(
             f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
-            + f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
+            + f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+            + f" cql_query: {self.cql_query}"
         )

     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         username = credentials["confluence_username"]
         access_token = credentials["confluence_access_token"]
-        self.confluence_client = Confluence(
+        self.confluence_client = DanswerConfluence(
             url=self.wiki_base,
             # passing in username causes issues for Confluence data center
             username=username if self.is_cloud else None,
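The fallback chain in `__init__` can be summarized with a few illustrative constructor calls (the URL is a placeholder and other arguments take their defaults):

    # 1. An explicit cql_query wins, and page-based recursion is disabled
    c = ConfluenceConnector("https://example.atlassian.net/wiki", is_cloud=True,
                            cql_query="type=page and label='runbook'")
    # c.cql_query == "type=page and label='runbook'", c.index_recursively == False

    # 2. No cql_query: fall back to the configured space
    c = ConfluenceConnector("https://example.atlassian.net/wiki", is_cloud=True, space="ENG")
    # c.cql_query == "type=page and space=ENG"

    # 3. Neither: index every page the credentials can see
    c = ConfluenceConnector("https://example.atlassian.net/wiki", is_cloud=True)
    # c.cql_query == "type=page"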
@@ -334,26 +437,33 @@ class ConfluenceConnector(LoadConnector, PollConnector):

     def _fetch_pages(
         self,
-        confluence_client: Confluence,
         start_ind: int,
     ) -> list[dict[str, Any]]:
         def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
-            get_all_pages_from_space = make_confluence_call_handle_rate_limit(
-                confluence_client.get_all_pages_from_space
+            if self.confluence_client is None:
+                raise ConnectorMissingCredentialError("Confluence")
+
+            get_all_pages = make_confluence_call_handle_rate_limit(
+                self.confluence_client.danswer_cql
             )

+            include_archived_spaces = (
+                CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
+                if not self.is_cloud
+                else False
+            )
+
             try:
-                return get_all_pages_from_space(
-                    self.space,
+                return get_all_pages(
+                    cql=self.cql_query,
                     start=start_ind,
                     limit=batch_size,
-                    status=(
-                        None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
-                    ),
                     expand="body.storage.value,version",
+                    include_archived_spaces=include_archived_spaces,
                 )
             except Exception:
                 logger.warning(
-                    f"Batch failed with space {self.space} at offset {start_ind} "
+                    f"Batch failed with cql {self.cql_query} at offset {start_ind} "
                     f"with size {batch_size}, processing pages individually..."
                 )

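`_fetch_space` above relies on the `start`/`limit` paging contract of `danswer_cql`. For intuition, a standalone pagination sketch under the same assumptions; this helper is hypothetical and not part of the diff:

    from typing import Any

    def fetch_all_pages(
        client: DanswerConfluence, cql: str, batch_size: int = 100
    ) -> list[dict[str, Any]]:
        # Page through search results until a short batch signals the end
        results: list[dict[str, Any]] = []
        start = 0
        while True:
            batch = client.danswer_cql(
                cql=cql,
                expand="body.storage.value,version",
                start=start,
                limit=batch_size,
            )
            results.extend(batch)
            if len(batch) < batch_size:
                return results
            start += batch_size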
@@ -363,27 +473,23 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                     # Could be that one of the pages here failed due to this bug:
                     # https://jira.atlassian.com/browse/CONFCLOUD-76433
                     view_pages.extend(
-                        get_all_pages_from_space(
-                            self.space,
+                        get_all_pages(
+                            cql=self.cql_query,
                             start=start_ind + i,
                             limit=1,
-                            status=(
-                                None
-                                if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
-                                else "current"
-                            ),
                             expand="body.storage.value,version",
+                            include_archived_spaces=include_archived_spaces,
                         )
                     )
                 except HTTPError as e:
                     logger.warning(
-                        f"Page failed with space {self.space} at offset {start_ind + i}, "
+                        f"Page failed with cql {self.cql_query} at offset {start_ind + i}, "
                         f"trying alternative expand option: {e}"
                     )
                     # Use view instead, which captures most info but is less complete
                     view_pages.extend(
-                        get_all_pages_from_space(
-                            self.space,
+                        get_all_pages(
+                            cql=self.cql_query,
                             start=start_ind + i,
                             limit=1,
                             expand="body.view.value,version",
@@ -393,6 +499,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
             return view_pages

         def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
+            if self.confluence_client is None:
+                raise ConnectorMissingCredentialError("Confluence")
+
             if self.recursive_indexer is None:
                 self.recursive_indexer = RecursiveIndexer(
                     origin_page_id=self.page_id,
@@ -421,7 +530,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                 raise e

         # error checking phase, only reachable if `self.continue_on_failure=True`
-        for i in range(self.batch_size):
+        for _ in range(self.batch_size):
             try:
                 pages = (
                     _fetch_space(start_ind, self.batch_size)
@@ -437,7 +546,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):

         return pages

-    def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
+    def _fetch_comments(
+        self, confluence_client: DanswerConfluence, page_id: str
+    ) -> str:
         get_page_child_by_type = make_confluence_call_handle_rate_limit(
             confluence_client.get_page_child_by_type
         )
@@ -463,7 +574,9 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         )
         return ""

-    def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
+    def _fetch_labels(
+        self, confluence_client: DanswerConfluence, page_id: str
+    ) -> list[str]:
         get_page_labels = make_confluence_call_handle_rate_limit(
             confluence_client.get_page_labels
         )
@@ -577,22 +690,20 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         return "\n".join(files_attachment_content), unused_attachments

     def _get_doc_batch(
-        self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
+        self, start_ind: int
     ) -> tuple[list[Document], list[dict[str, Any]], int]:
+        if self.confluence_client is None:
+            raise ConnectorMissingCredentialError("Confluence")
+
         doc_batch: list[Document] = []
         unused_attachments: list[dict[str, Any]] = []

-        if self.confluence_client is None:
-            raise ConnectorMissingCredentialError("Confluence")
-        batch = self._fetch_pages(self.confluence_client, start_ind)
+        batch = self._fetch_pages(start_ind)

         for page in batch:
             last_modified = _datetime_from_string(page["version"]["when"])
             author = cast(str | None, page["version"].get("by", {}).get("email"))

-            if time_filter and not time_filter(last_modified):
-                continue
-
             page_id = page["id"]

             if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
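Note the direction of this change: the removed `time_filter` lambda was a client-side guard applied after fetching, while the new code pushes the same window into the query itself via `_replace_cql_time_filter` (see `poll_source` below). Roughly, with illustrative timestamps:

    from datetime import datetime, timezone

    start_time = datetime(2024, 1, 1, tzinfo=timezone.utc)
    end_time = datetime(2024, 2, 1, tzinfo=timezone.utc)

    # Before: every fetched page was checked in Python and possibly skipped
    def time_filter(t: datetime) -> bool:
        return start_time <= t <= end_time

    # After: out-of-range pages are never returned by the search API at all
    cql = _replace_cql_time_filter("type=page", start_time, end_time)
    # "type=page and lastmodified >= '2024-01-01 00:00' and lastmodified <= '2024-02-01 00:00'"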
@@ -715,17 +826,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         return doc_batch, end_ind - start_ind

     def load_from_state(self) -> GenerateDocumentsOutput:
-        unused_attachments = []
-
-        if self.confluence_client is None:
-            raise ConnectorMissingCredentialError("Confluence")
+        unused_attachments: list[dict[str, Any]] = []

         start_ind = 0
         while True:
-            doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
-                start_ind
-            )
-            unused_attachments.extend(unused_attachments_batch)
+            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
+            unused_attachments.extend(unused_attachments)
             start_ind += num_pages
             if doc_batch:
                 yield doc_batch
@@ -748,7 +854,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
     def poll_source(
         self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
     ) -> GenerateDocumentsOutput:
-        unused_attachments = []
+        unused_attachments: list[dict[str, Any]] = []

         if self.confluence_client is None:
             raise ConnectorMissingCredentialError("Confluence")
@@ -756,12 +862,12 @@ class ConfluenceConnector(LoadConnector, PollConnector):
         start_time = datetime.fromtimestamp(start, tz=timezone.utc)
         end_time = datetime.fromtimestamp(end, tz=timezone.utc)

+        self.cql_query = _replace_cql_time_filter(self.cql_query, start_time, end_time)
+
         start_ind = 0
         while True:
-            doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
-                start_ind, time_filter=lambda t: start_time <= t <= end_time
-            )
-            unused_attachments.extend(unused_attachments_batch)
+            doc_batch, unused_attachments, num_pages = self._get_doc_batch(start_ind)
+            unused_attachments.extend(unused_attachments)

             start_ind += num_pages
             if doc_batch:
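End to end, `poll_source` now rewrites the stored CQL once per poll and then pages through results. A sketch of polling the last day, with placeholder connector arguments and credentials:

    import time

    connector = ConfluenceConnector(
        "https://example.atlassian.net/wiki",  # placeholder wiki_base
        is_cloud=True,
        cql_query="type=page and space=ENG",
    )
    connector.load_credentials(
        {"confluence_username": "user@example.com", "confluence_access_token": "api-token"}
    )

    end = time.time()
    start = end - 86400  # one day back
    for doc_batch in connector.poll_source(start, end):
        print(f"got {len(doc_batch)} documents")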