Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-10 04:49:29 +02:00)
DAN-25 Semantic Identifier for Documents (#24)
parent 6e59b02c91
commit 25b59217ef
@@ -29,6 +29,7 @@ class EmbeddedIndexChunk(IndexChunk):
 class InferenceChunk(BaseChunk):
     document_id: str
     source_type: str
+    semantic_identifier: str

     @classmethod
     def from_dict(cls, init_dict):

@@ -6,6 +6,7 @@ CONTENT = "content"
 SOURCE_TYPE = "source_type"
 SOURCE_LINKS = "source_links"
 SOURCE_LINK = "link"
+SEMANTIC_IDENTIFIER = "semantic_identifier"
 SECTION_CONTINUATION = "section_continuation"
 ALLOWED_USERS = "allowed_users"
 ALLOWED_GROUPS = "allowed_groups"

@@ -7,7 +7,6 @@ from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
 from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
 from danswer.configs.app_configs import INDEX_BATCH_SIZE
 from danswer.configs.constants import DocumentSource
-from danswer.configs.constants import SOURCE_TYPE
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.connectors.type_aliases import BatchLoader

@@ -129,6 +128,7 @@ class BatchGoogleDriveLoader(BatchLoader):
                     id=file["webViewLink"],
                     sections=[Section(link=file["webViewLink"], text=full_context)],
                     source=DocumentSource.GOOGLE_DRIVE,
+                    semantic_identifier=file["name"],
                     metadata={},
                 )
             )

@@ -17,7 +17,8 @@ class Document:
     id: str  # This must be unique or during indexing/reindexing, chunks will be overwritten
     sections: list[Section]
     source: DocumentSource
-    metadata: dict[str, Any]
+    semantic_identifier: str | None
+    metadata: dict[str, Any] | None


 def get_raw_document_text(document: Document) -> str:

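As an illustration (not part of the commit): a minimal sketch of constructing a Document under the updated model. Field names come from this diff; the ID, link, and identifier values are hypothetical.

# Sketch only: assumes the Document/Section models and DocumentSource enum as shown in this diff.
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document, Section

doc = Document(
    id="https://docs.google.com/document/d/abc123",  # hypothetical webViewLink
    sections=[Section(link="https://docs.google.com/document/d/abc123", text="...")],
    source=DocumentSource.GOOGLE_DRIVE,
    semantic_identifier="Quarterly Planning Notes",  # human-readable display name
    metadata={},
)

The connector hunks below (Google Drive, Slack, Web) populate the new field with the file name, channel name, and page title respectively.
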
@@ -14,12 +14,15 @@ from danswer.connectors.type_aliases import BatchLoader


 def _process_batch_event(
-    event: dict[str, Any],
+    slack_event: dict[str, Any],
+    channel: dict[str, Any],
     matching_doc: Document | None,
     workspace: str | None = None,
-    channel_id: str | None = None,
 ) -> Document | None:
-    if event["type"] == "message" and event.get("subtype") != "channel_join":
+    if (
+        slack_event["type"] == "message"
+        and slack_event.get("subtype") != "channel_join"
+    ):
         if matching_doc:
             return Document(
                 id=matching_doc.id,

@@ -27,26 +30,28 @@ def _process_batch_event(
                 + [
                     Section(
                         link=get_message_link(
-                            event, workspace=workspace, channel_id=channel_id
+                            slack_event, workspace=workspace, channel_id=channel["id"]
                         ),
-                        text=event["text"],
+                        text=slack_event["text"],
                     )
                 ],
                 source=matching_doc.source,
+                semantic_identifier=matching_doc.semantic_identifier,
                 metadata=matching_doc.metadata,
             )

         return Document(
-            id=event["ts"],
+            id=slack_event["ts"],
             sections=[
                 Section(
                     link=get_message_link(
-                        event, workspace=workspace, channel_id=channel_id
+                        slack_event, workspace=workspace, channel_id=channel["id"]
                     ),
-                    text=event["text"],
+                    text=slack_event["text"],
                 )
             ],
             source=DocumentSource.SLACK,
+            semantic_identifier=channel["name"],
             metadata={},
         )

@@ -76,11 +81,13 @@ class BatchSlackLoader(BatchLoader):
         for path in channel_file_paths:
             with open(path) as f:
                 events = cast(list[dict[str, Any]], json.load(f))
-                for event in events:
+                for slack_event in events:
                     doc = _process_batch_event(
-                        event,
-                        document_batch.get(event.get("thread_ts", "")),
-                        channel_id=channel_info["id"],
+                        slack_event=slack_event,
+                        channel=channel_info,
+                        matching_doc=document_batch.get(
+                            slack_event.get("thread_ts", "")
+                        ),
                     )
                     if doc:
                         document_batch[doc.id] = doc

@@ -23,6 +23,7 @@ logger = setup_logger()
 SLACK_LIMIT = 900


+ChannelType = dict[str, Any]
 MessageType = dict[str, Any]
 # list of messages in a thread
 ThreadType = list[MessageType]

@@ -85,7 +86,14 @@ def _make_slack_api_call(
     return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)


-def get_channels(client: WebClient) -> list[dict[str, Any]]:
+def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
+    """Get information about a channel. Needed to convert channel ID to channel name"""
+    return _make_slack_api_call(client.conversations_info, channel=channel_id)[0][
+        "channel"
+    ]
+
+
+def get_channels(client: WebClient) -> list[ChannelType]:
     """Get all channels in the workspace"""
     channels: list[dict[str, Any]] = []
     for result in _make_slack_api_call(client.conversations_list):

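As an illustration (not part of the commit): a hedged usage sketch of the new helper. The channel ID is hypothetical; get_client comes from danswer.connectors.slack.utils, as imported in the listener hunk further down.

# Sketch only: resolve a channel ID to its channel-info dict.
from danswer.connectors.slack.pull import get_channel_info
from danswer.connectors.slack.utils import get_client

channel = get_channel_info(client=get_client(), channel_id="C0123456789")
print(channel["name"])  # human-readable name, used as the semantic identifier
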
@@ -127,44 +135,8 @@ def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType
     return threads


-def _default_msg_filter(message: MessageType) -> bool:
-    return message.get("subtype", "") == "channel_join"
-
-
-def get_all_threads(
-    client: WebClient,
-    msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
-    oldest: str | None = None,
-    latest: str | None = None,
-) -> dict[str, list[ThreadType]]:
-    """Get all threads in the workspace"""
-    channels = get_channels(client)
-    channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
-    for channel in channels:
-        channel_id_to_messages[channel["id"]] = get_channel_messages(
-            client=client, channel=channel, oldest=oldest, latest=latest
-        )
-
-    channel_to_threads: dict[str, list[ThreadType]] = {}
-    for channel_id, messages in channel_id_to_messages.items():
-        final_threads: list[ThreadType] = []
-        for message in messages:
-            thread_ts = message.get("thread_ts")
-            if thread_ts:
-                thread = get_thread(client, channel_id, thread_ts)
-                filtered_thread = [
-                    message for message in thread if not msg_filter_func(message)
-                ]
-                if filtered_thread:
-                    final_threads.append(filtered_thread)
-            else:
-                final_threads.append([message])
-        channel_to_threads[channel_id] = final_threads
-
-    return channel_to_threads
-
-
-def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
+def thread_to_doc(channel: ChannelType, thread: ThreadType) -> Document:
+    channel_id = channel["id"]
     return Document(
         id=f"{channel_id}__{thread[0]['ts']}",
         sections=[

@@ -175,20 +147,55 @@ def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
             for m in thread
         ],
         source=DocumentSource.SLACK,
+        semantic_identifier=channel["name"],
         metadata={},
     )


+def _default_msg_filter(message: MessageType) -> bool:
+    return message.get("subtype", "") == "channel_join"
+
+
 def get_all_docs(
     client: WebClient,
     oldest: str | None = None,
     latest: str | None = None,
     msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
 ) -> list[Document]:
     """Get all documents in the workspace"""
-    channel_id_to_threads = get_all_threads(client=client, oldest=oldest, latest=latest)
+    channels = get_channels(client)
+    channel_id_to_channel_info = {channel["id"]: channel for channel in channels}
+
+    channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
+    for channel in channels:
+        channel_id_to_messages[channel["id"]] = get_channel_messages(
+            client=client, channel=channel, oldest=oldest, latest=latest
+        )
+
+    channel_id_to_threads: dict[str, list[ThreadType]] = {}
+    for channel_id, messages in channel_id_to_messages.items():
+        final_threads: list[ThreadType] = []
+        for message in messages:
+            thread_ts = message.get("thread_ts")
+            if thread_ts:
+                thread = get_thread(
+                    client=client, channel_id=channel_id, thread_id=thread_ts
+                )
+                filtered_thread = [
+                    message for message in thread if not msg_filter_func(message)
+                ]
+                if filtered_thread:
+                    final_threads.append(filtered_thread)
+            else:
+                final_threads.append([message])
+        channel_id_to_threads[channel_id] = final_threads
+
     docs: list[Document] = []
     for channel_id, threads in channel_id_to_threads.items():
-        docs.extend(thread_to_doc(channel_id, thread) for thread in threads)
+        docs.extend(
+            thread_to_doc(channel=channel_id_to_channel_info[channel_id], thread=thread)
+            for thread in threads
+        )
     logger.info(f"Pulled {len(docs)} documents from slack")
     return docs

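From a caller's perspective the entry point is unchanged; as an illustration (not part of the commit), assuming a configured client from get_client:

# Sketch only: pull all Slack documents and show their new identifiers.
from danswer.connectors.slack.pull import get_all_docs
from danswer.connectors.slack.utils import get_client

docs = get_all_docs(client=get_client())
for doc in docs:
    print(doc.semantic_identifier)  # the channel name, per thread_to_doc above
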
@@ -80,6 +80,11 @@ class BatchWebLoader(BatchLoader):
                 content = await page.content()
                 soup = BeautifulSoup(content, "html.parser")

+                title_tag = soup.find("title")
+                title = None
+                if title_tag and title_tag.text:
+                    title = title_tag.text
+
                 # Heuristics based cleaning
                 for undesired_tag in ["nav", "header", "footer", "meta"]:
                     [tag.extract() for tag in soup.find_all(undesired_tag)]

@@ -96,6 +101,7 @@ class BatchWebLoader(BatchLoader):
                         id=current_url,
                         sections=[Section(link=current_url, text=page_text)],
                         source=DocumentSource.WEB,
+                        semantic_identifier=title,
                         metadata={},
                     )
                 )

@@ -142,6 +148,11 @@ class BatchWebLoader(BatchLoader):
                 content = page.content()
                 soup = BeautifulSoup(content, "html.parser")

+                title_tag = soup.find("title")
+                title = None
+                if title_tag and title_tag.text:
+                    title = title_tag.text
+
                 # Heuristics based cleaning
                 for undesired_tag in ["nav", "header", "footer", "meta"]:
                     [tag.extract() for tag in soup.find_all(undesired_tag)]

@@ -158,6 +169,7 @@ class BatchWebLoader(BatchLoader):
                         id=current_url,
                         sections=[Section(link=current_url, text=page_text)],
                         source=DocumentSource.WEB,
+                        semantic_identifier=title,
                         metadata={},
                     )
                 )

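The title extraction added in both web loader hunks is self-contained BeautifulSoup; as an illustration (not part of the commit), the same heuristic on a hypothetical HTML snippet:

# Sketch only: the title heuristic from the hunks above, standalone.
from bs4 import BeautifulSoup

content = "<html><head><title>Example Page</title></head><body>...</body></html>"
soup = BeautifulSoup(content, "html.parser")

title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
    title = title_tag.text

print(title)  # "Example Page", used as the page's semantic identifier
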
@@ -7,6 +7,7 @@ from danswer.configs.constants import CHUNK_ID
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import DOCUMENT_ID
 from danswer.configs.constants import SECTION_CONTINUATION
+from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM

@@ -60,6 +61,7 @@ def index_chunks(
                     CONTENT: chunk.content,
                     SOURCE_TYPE: str(document.source.value),
                     SOURCE_LINKS: chunk.source_links,
+                    SEMANTIC_IDENTIFIER: document.semantic_identifier,
                     SECTION_CONTINUATION: chunk.section_continuation,
                     ALLOWED_USERS: [],  # TODO
                     ALLOWED_GROUPS: [],  # TODO

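With this hunk, the payload written for each chunk at index time carries the document's identifier. As an illustration (not part of the commit), the payload shape with hypothetical values; the keys are the string constants defined near the top of this diff, and keys outside the shown context are omitted:

# Sketch only: example payload for one indexed chunk (values made up).
payload = {
    "content": "...chunk text...",
    "source_type": "slack",
    "source_links": ...,  # passed through unchanged from chunk.source_links
    "semantic_identifier": "general",  # e.g. a Slack channel name
    "section_continuation": False,
    "allowed_users": [],  # TODO in the source
    "allowed_groups": [],  # TODO in the source
}
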
@@ -13,6 +13,7 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import OPENAI_API_KEY
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
 from danswer.configs.constants import DOCUMENT_ID
+from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS

@@ -138,12 +139,14 @@ def match_quotes_to_docs(
                     DOCUMENT_ID: chunk.document_id,
                     SOURCE_LINK: curr_link,
                     SOURCE_TYPE: chunk.source_type,
+                    SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
                 }
                 break
             quotes_dict[quote] = {
                 DOCUMENT_ID: chunk.document_id,
                 SOURCE_LINK: curr_link,
                 SOURCE_TYPE: chunk.source_type,
+                SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
             }
             break
     return quotes_dict

@@ -1,5 +1,6 @@
 from typing import Any

+from danswer.connectors.slack.pull import get_channel_info
 from danswer.connectors.slack.pull import get_thread
 from danswer.connectors.slack.pull import thread_to_doc
 from danswer.connectors.slack.utils import get_client

@@ -42,9 +43,12 @@ def process_slack_event(event: SlackEvent):

     channel_id = event.event["channel"]
     thread_ts = message.get("thread_ts")
+    slack_client = get_client()
     doc = thread_to_doc(
-        channel_id,
-        get_thread(get_client(), channel_id, thread_ts)
+        channel=get_channel_info(client=slack_client, channel_id=channel_id),
+        thread=get_thread(
+            client=slack_client, channel_id=channel_id, thread_id=thread_ts
+        )
         if thread_ts
         else [message],
     )

@@ -4,6 +4,8 @@ import json
 import requests
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
+from danswer.configs.constants import SEMANTIC_IDENTIFIER
+from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE


@@ -81,7 +83,8 @@ if __name__ == "__main__":
             contents["quotes"].items()
         ):
             print(f"Quote {str(ind + 1)}:\n{quote}")
-            print(f"Link: {quote_info['link']}")
+            print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
+            print(f"Link: {quote_info[SOURCE_LINK]}")
             print(f"Source: {quote_info[SOURCE_TYPE]}")
     else:
         print("No quotes found")

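With the new keys, each quote printed by the script should look roughly like the following (values hypothetical):

Quote 1:
Some quoted passage from a matched document.
Semantic Identifier: Example Page
Link: https://example.com/page
Source: web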