Added filter to exclude attachments with unsupported file extensions (#3530)

* Added filter to exclude attachments with unsupported file extensions

* extension
This commit is contained in:
hagen-danswer
2024-12-20 11:48:15 -08:00
committed by GitHub
parent 64b6f15e95
commit 71c5043832

View File

@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000 _SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector): class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__( def __init__(
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
is_cloud: bool, is_cloud: bool,
space: str = "", space: str = "",
page_id: str = "", page_id: str = "",
index_recursively: bool = True, index_recursively: bool = False,
cql_query: str | None = None, cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE, batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -83,22 +100,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
self.wiki_base = wiki_base.rstrip("/") self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages # if nothing is provided, we will fetch all pages
cql_page_query = "type=page" base_cql_page_query = "type=page"
if cql_query: if cql_query:
# if a cql_query is provided, we will use it to fetch the pages # if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query base_cql_page_query = cql_query
elif page_id: elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page # if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively: if index_recursively:
cql_page_query += f" and ancestor='{page_id}'" base_cql_page_query += f" and ancestor='{page_id}'"
else: else:
cql_page_query += f" and id='{page_id}'" base_cql_page_query += f" and id='{page_id}'"
elif space: elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages # if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'" base_cql_page_query += f" and space='{quote(space)}'"
self.cql_page_query = cql_page_query self.base_cql_page_query = base_cql_page_query
self.cql_time_filter = ""
self.cql_label_filter = "" self.cql_label_filter = ""
if labels_to_skip: if labels_to_skip:
@@ -126,6 +142,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
) )
return None return None
def _construct_page_query(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> str:
    """Build the CQL query for pages, optionally bounded by last-modified time.

    Args:
        start: Inclusive lower bound, in seconds since the Unix epoch.
        end: Inclusive upper bound, in seconds since the Unix epoch.

    Returns:
        The base page CQL with the label filter and any time filters appended.
    """
    page_query = self.base_cql_page_query + self.cql_label_filter
    # Use explicit None checks: a plain truthiness test would silently
    # drop a legitimate bound of exactly 0 (the epoch).
    if start is not None:
        formatted_start_time = datetime.fromtimestamp(
            start, tz=self.timezone
        ).strftime("%Y-%m-%d %H:%M")
        page_query += f" and lastmodified >= '{formatted_start_time}'"
    if end is not None:
        formatted_end_time = datetime.fromtimestamp(
            end, tz=self.timezone
        ).strftime("%Y-%m-%d %H:%M")
        page_query += f" and lastmodified <= '{formatted_end_time}'"
    return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
    """Build the CQL query for attachments contained in the given page,
    applying the label filter and excluding unsupported file extensions."""
    return (
        f"type=attachment and container='{confluence_page_id}'"
        + self.cql_label_filter
        + _FULL_EXTENSION_FILTER_STRING
    )
def _get_comment_string_for_page_id(self, page_id: str) -> str: def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = "" comment_string = ""
@@ -205,11 +248,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
metadata=doc_metadata, metadata=doc_metadata,
) )
def _fetch_document_batches(self) -> GenerateDocumentsOutput: def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = [] doc_batch: list[Document] = []
confluence_page_ids: list[str] = [] confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}") logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents # Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval( for page in self.confluence_client.paginated_cql_retrieval(
@@ -228,11 +275,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Fetch attachments as Documents # Fetch attachments as Documents
for confluence_page_id in confluence_page_ids: for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'" attachment_query = self._construct_attachment_query(confluence_page_id)
attachment_cql += self.cql_label_filter
# TODO: maybe should add time filter as well? # TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval( for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql, cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
): ):
doc = self._convert_object_to_document(attachment) doc = self._convert_object_to_document(attachment)
@@ -248,17 +294,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def load_from_state(self) -> GenerateDocumentsOutput: def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches() return self._fetch_document_batches()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput: def poll_source(
# Add time filters self,
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime( start: SecondsSinceUnixEpoch | None = None,
"%Y-%m-%d %H:%M" end: SecondsSinceUnixEpoch | None = None,
) ) -> GenerateDocumentsOutput:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( return self._fetch_document_batches(start, end)
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def retrieve_all_slim_documents( def retrieve_all_slim_documents(
self, self,
@@ -269,7 +310,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS) restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions( for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query, cql=page_query,
expand=restrictions_expand, expand=restrictions_expand,
@@ -294,10 +335,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=page_perm_sync_data, perm_sync_data=page_perm_sync_data,
) )
) )
attachment_cql = f"type=attachment and container='{page['id']}'" attachment_query = self._construct_attachment_query(page["id"])
attachment_cql += self.cql_label_filter
for attachment in self.confluence_client.cql_paginate_all_expansions( for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql, cql=attachment_query,
expand=restrictions_expand, expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE, limit=_SLIM_DOC_BATCH_SIZE,
): ):