DAN-25 Semantic Identifier for Documents (#24)

This commit is contained in:
Yuhong Sun 2023-05-09 22:46:45 -07:00 committed by GitHub
parent 6e59b02c91
commit 25b59217ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 99 additions and 58 deletions

View File

@@ -29,6 +29,7 @@ class EmbeddedIndexChunk(IndexChunk):
class InferenceChunk(BaseChunk):
document_id: str
source_type: str
semantic_identifier: str
@classmethod
def from_dict(cls, init_dict):

View File

@@ -6,6 +6,7 @@ CONTENT = "content"
SOURCE_TYPE = "source_type"
SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"

View File

@@ -7,7 +7,6 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
@@ -129,6 +128,7 @@ class BatchGoogleDriveLoader(BatchLoader):
id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=full_context)],
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
metadata={},
)
)

View File

@@ -17,7 +17,8 @@ class Document:
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
sections: list[Section]
source: DocumentSource
metadata: dict[str, Any]
semantic_identifier: str | None
metadata: dict[str, Any] | None
def get_raw_document_text(document: Document) -> str:

View File

@@ -14,12 +14,15 @@ from danswer.connectors.type_aliases import BatchLoader
def _process_batch_event(
event: dict[str, Any],
slack_event: dict[str, Any],
channel: dict[str, Any],
matching_doc: Document | None,
workspace: str | None = None,
channel_id: str | None = None,
) -> Document | None:
if event["type"] == "message" and event.get("subtype") != "channel_join":
if (
slack_event["type"] == "message"
and slack_event.get("subtype") != "channel_join"
):
if matching_doc:
return Document(
id=matching_doc.id,
@@ -27,26 +30,28 @@ def _process_batch_event(
+ [
Section(
link=get_message_link(
event, workspace=workspace, channel_id=channel_id
slack_event, workspace=workspace, channel_id=channel["id"]
),
text=event["text"],
text=slack_event["text"],
)
],
source=matching_doc.source,
semantic_identifier=matching_doc.semantic_identifier,
metadata=matching_doc.metadata,
)
return Document(
id=event["ts"],
id=slack_event["ts"],
sections=[
Section(
link=get_message_link(
event, workspace=workspace, channel_id=channel_id
slack_event, workspace=workspace, channel_id=channel["id"]
),
text=event["text"],
text=slack_event["text"],
)
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
metadata={},
)
@@ -76,11 +81,13 @@ class BatchSlackLoader(BatchLoader):
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for event in events:
for slack_event in events:
doc = _process_batch_event(
event,
document_batch.get(event.get("thread_ts", "")),
channel_id=channel_info["id"],
slack_event=slack_event,
channel=channel_info,
matching_doc=document_batch.get(
slack_event.get("thread_ts", "")
),
)
if doc:
document_batch[doc.id] = doc

View File

@@ -23,6 +23,7 @@ logger = setup_logger()
SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
@@ -85,7 +86,14 @@ def _make_slack_api_call(
return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)
def get_channels(client: WebClient) -> list[dict[str, Any]]:
def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
"""Get information about a channel. Needed to convert channel ID to channel name"""
return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
"channel"
]
def get_channels(client: WebClient) -> list[ChannelType]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
for result in _make_slack_api_call(client.conversations_list):
@@ -127,44 +135,8 @@ def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType
return threads
def _default_msg_filter(message: MessageType) -> bool:
return message.get("subtype", "") == "channel_join"
def get_all_threads(
client: WebClient,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
oldest: str | None = None,
latest: str | None = None,
) -> dict[str, list[ThreadType]]:
"""Get all threads in the workspace"""
channels = get_channels(client)
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
channel_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items():
final_threads: list[ThreadType] = []
for message in messages:
thread_ts = message.get("thread_ts")
if thread_ts:
thread = get_thread(client, channel_id, thread_ts)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
if filtered_thread:
final_threads.append(filtered_thread)
else:
final_threads.append([message])
channel_to_threads[channel_id] = final_threads
return channel_to_threads
def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
def thread_to_doc(channel: ChannelType, thread: ThreadType) -> Document:
channel_id = channel["id"]
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
sections=[
@@ -175,20 +147,55 @@ def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
for m in thread
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
metadata={},
)
def _default_msg_filter(message: MessageType) -> bool:
return message.get("subtype", "") == "channel_join"
def get_all_docs(
client: WebClient,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> list[Document]:
"""Get all documents in the workspace"""
channel_id_to_threads = get_all_threads(client=client, oldest=oldest, latest=latest)
channels = get_channels(client)
channel_id_to_channel_info = {channel["id"]: channel for channel in channels}
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
)
channel_id_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items():
final_threads: list[ThreadType] = []
for message in messages:
thread_ts = message.get("thread_ts")
if thread_ts:
thread = get_thread(
client=client, channel_id=channel_id, thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
if filtered_thread:
final_threads.append(filtered_thread)
else:
final_threads.append([message])
channel_id_to_threads[channel_id] = final_threads
docs: list[Document] = []
for channel_id, threads in channel_id_to_threads.items():
docs.extend(thread_to_doc(channel_id, thread) for thread in threads)
docs.extend(
thread_to_doc(channel=channel_id_to_channel_info[channel_id], thread=thread)
for thread in threads
)
logger.info(f"Pulled {len(docs)} documents from slack")
return docs

View File

@@ -80,6 +80,11 @@ class BatchWebLoader(BatchLoader):
content = await page.content()
soup = BeautifulSoup(content, "html.parser")
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
# Heuristics based cleaning
for undesired_tag in ["nav", "header", "footer", "meta"]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
@@ -96,6 +101,7 @@ class BatchWebLoader(BatchLoader):
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=title,
metadata={},
)
)
@@ -142,6 +148,11 @@ class BatchWebLoader(BatchLoader):
content = page.content()
soup = BeautifulSoup(content, "html.parser")
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
# Heuristics based cleaning
for undesired_tag in ["nav", "header", "footer", "meta"]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
@@ -158,6 +169,7 @@ class BatchWebLoader(BatchLoader):
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=title,
metadata={},
)
)

View File

@@ -7,6 +7,7 @@ from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM
@@ -60,6 +61,7 @@ def index_chunks(
CONTENT: chunk.content,
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: chunk.source_links,
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: [], # TODO
ALLOWED_GROUPS: [], # TODO

View File

@@ -13,6 +13,7 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
@@ -138,12 +139,14 @@ def match_quotes_to_docs(
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
}
break
quotes_dict[quote] = {
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
}
break
return quotes_dict

View File

@@ -1,5 +1,6 @@
from typing import Any
from danswer.connectors.slack.pull import get_channel_info
from danswer.connectors.slack.pull import get_thread
from danswer.connectors.slack.pull import thread_to_doc
from danswer.connectors.slack.utils import get_client
@@ -42,9 +43,12 @@ def process_slack_event(event: SlackEvent):
channel_id = event.event["channel"]
thread_ts = message.get("thread_ts")
slack_client = get_client()
doc = thread_to_doc(
channel_id,
get_thread(get_client(), channel_id, thread_ts)
channel=get_channel_info(client=slack_client, channel_id=channel_id),
thread=get_thread(
client=slack_client, channel_id=channel_id, thread_id=thread_ts
)
if thread_ts
else [message],
)

View File

@@ -4,6 +4,8 @@ import json
import requests
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
@@ -81,7 +83,8 @@ if __name__ == "__main__":
contents["quotes"].items()
):
print(f"Quote {str(ind + 1)}:\n{quote}")
print(f"Link: {quote_info['link']}")
print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
print(f"Link: {quote_info[SOURCE_LINK]}")
print(f"Source: {quote_info[SOURCE_TYPE]}")
else:
print("No quotes found")