Initial Commit

Yuhong Sun
2023-04-28 22:40:46 -07:00
parent 751a8ea69e
commit 402b89b6ec
88 changed files with 7447 additions and 0 deletions

backend/.gitignore
@@ -0,0 +1,8 @@
__pycache__/
.idea/
site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
qdrant-data/
typesense-data/

@@ -0,0 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: 23.3.0
hooks:
- id: black
language_version: python3.11
- repo: https://github.com/asottile/reorder_python_imports
rev: v3.9.0
hooks:
- id: reorder-python-imports
args: ['--py311-plus']

backend/Dockerfile
@@ -0,0 +1,14 @@
FROM python:3.11-slim-bullseye
RUN apt-get update \
&& apt-get install -y git cmake pkg-config libprotobuf-c-dev protobuf-compiler libprotobuf-dev libgoogle-perftools-dev build-essential \
&& rm -rf /var/lib/apt/lists/*
COPY ./requirements/default.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
WORKDIR /app
COPY ./qa_service /app/qa_service
ENV PYTHONPATH .
CMD ["uvicorn", "danswer.main:app", "--host", "0.0.0.0", "--port", "8080"]

@@ -0,0 +1,14 @@
FROM python:3.11-slim-bullseye
RUN apt-get update \
&& apt-get install -y git cmake pkg-config libprotobuf-c-dev protobuf-compiler libprotobuf-dev libgoogle-perftools-dev build-essential cron \
&& rm -rf /var/lib/apt/lists/*
COPY ./requirements/default.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
WORKDIR /app
COPY ./qa_service /app/qa_service
ENV PYTHONPATH .
CMD ["python3", "qa_service/background/update.py"]

@@ -0,0 +1,58 @@
import time
from typing import cast
from danswer.connectors.slack.config import get_pull_frequency
from danswer.connectors.slack.pull import SlackPullLoader
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logging import setup_logger
logger = setup_logger()
LAST_PULL_KEY_TEMPLATE = "last_pull_{}"
def _check_should_run(current_time: int, last_pull: int, pull_frequency: int) -> bool:
return current_time - last_pull > pull_frequency * 60
def run_update():
logger.info("Running update")
# TODO (chris): implement a more generic way to run updates
# so we don't need to edit this file for future connectors
dynamic_config_store = get_dynamic_config_store()
current_time = int(time.time())
# Slack
try:
pull_frequency = get_pull_frequency()
except ConfigNotFoundError:
pull_frequency = 0
if pull_frequency:
last_slack_pull_key = LAST_PULL_KEY_TEMPLATE.format(SlackPullLoader.__name__)
try:
last_pull = cast(int, dynamic_config_store.load(last_slack_pull_key))
except ConfigNotFoundError:
last_pull = None
if last_pull is None or _check_should_run(
current_time, last_pull, pull_frequency
):
logger.info(f"Running slack pull from {last_pull or 0} to {current_time}")
for doc_batch in SlackPullLoader().load(last_pull or 0, current_time):
print(len(doc_batch))
dynamic_config_store.store(last_slack_pull_key, current_time)
if __name__ == "__main__":
DELAY = 60 # 60 seconds
while True:
start = time.time()
try:
run_update()
except Exception:
logger.exception("Failed to run update")
sleep_time = DELAY - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
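
For reference, the update loop above stores its state under the key "last_pull_SlackPullLoader" and treats pull_frequency as minutes. A tiny illustration of the scheduling predicate defined in this file (values are made up):

# Illustrative values only; _check_should_run is the module-private helper defined above.
assert _check_should_run(current_time=1000, last_pull=0, pull_frequency=10) is True   # 1000 s elapsed > 10 min
assert _check_should_run(current_time=300, last_pull=0, pull_frequency=10) is False   # 300 s elapsed < 10 min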

@@ -0,0 +1,152 @@
import abc
from collections.abc import Callable
from danswer.chunking.models import IndexChunk
from danswer.configs.app_configs import CHUNK_OVERLAP
from danswer.configs.app_configs import CHUNK_SIZE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.text_processing import shared_precompare_cleanup
SECTION_SEPARATOR = "\n\n"
ChunkFunc = Callable[[Document], list[IndexChunk]]
def chunk_large_section(
section: Section,
document: Document,
start_chunk_id: int,
chunk_size: int = CHUNK_SIZE,
word_overlap: int = CHUNK_OVERLAP,
) -> list[IndexChunk]:
section_text = section.text
char_count = len(section_text)
chunk_strs: list[str] = []
start_pos = segment_start_pos = 0
while start_pos < char_count:
back_count_words = 0
end_pos = segment_end_pos = min(start_pos + chunk_size, char_count)
while not section_text[segment_end_pos - 1].isspace():
if segment_end_pos >= char_count:
break
segment_end_pos += 1
while back_count_words <= word_overlap:
if segment_start_pos == 0:
break
if section_text[segment_start_pos].isspace():
back_count_words += 1
segment_start_pos -= 1
if segment_start_pos != 0:
segment_start_pos += 2
chunk_str = section_text[segment_start_pos:segment_end_pos]
if chunk_str[-1].isspace():
chunk_str = chunk_str[:-1]
chunk_strs.append(chunk_str)
start_pos = segment_start_pos = end_pos
# Last chunk should be as long as possible, overlap favored over tiny chunk with no context
if len(chunk_strs) > 1:
chunk_strs.pop()
back_count_words = 0
start_pos = char_count - chunk_size
while back_count_words <= word_overlap:
if section_text[start_pos].isspace():
back_count_words += 1
start_pos -= 1
chunk_strs.append(section_text[start_pos + 2 :])
chunks = []
for chunk_ind, chunk_str in enumerate(chunk_strs):
chunks.append(
IndexChunk(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
content=chunk_str,
source_links={0: section.link},
section_continuation=(chunk_ind != 0),
)
)
return chunks
def chunk_document(
document: Document,
chunk_size: int = CHUNK_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
) -> list[IndexChunk]:
chunks: list[IndexChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for section in document.sections:
current_length = len(chunk_text)
curr_offset_len = len(shared_precompare_cleanup(chunk_text))
section_length = len(section.text)
# Large sections are considered self-contained/unique, so they start a new chunk and are not
# concatenated with other sections at the end
if section_length > chunk_size:
if chunk_text:
chunks.append(
IndexChunk(
source_document=document,
chunk_id=len(chunks),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
)
)
link_offsets = {}
chunk_text = ""
large_section_chunks = chunk_large_section(
section=section,
document=document,
start_chunk_id=len(chunks),
chunk_size=chunk_size,
word_overlap=subsection_overlap,
)
chunks.extend(large_section_chunks)
continue
# If the whole section is shorter than a chunk, either add it to the current chunk or start a new one
if current_length + len(SECTION_SEPARATOR) + section_length <= chunk_size:
chunk_text += (
SECTION_SEPARATOR + section.text if chunk_text else section.text
)
link_offsets[curr_offset_len] = section.link
else:
chunks.append(
IndexChunk(
source_document=document,
chunk_id=len(chunks),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
)
)
link_offsets = {0: section.link}
chunk_text = section.text
# Once we hit the end, if we're still in the process of building a chunk, add what we have
if chunk_text:
chunks.append(
IndexChunk(
source_document=document,
chunk_id=len(chunks),
content=chunk_text,
source_links=link_offsets,
section_continuation=False,
)
)
return chunks
class Chunker:
@abc.abstractmethod
def chunk(self, document: Document) -> list[IndexChunk]:
raise NotImplementedError
class DefaultChunker(Chunker):
def chunk(self, document: Document) -> list[IndexChunk]:
return chunk_document(document)
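
A minimal usage sketch of the chunker above, assuming the danswer package from this commit is importable; the document id and text are placeholders:

from danswer.chunking.chunk import DefaultChunker
from danswer.connectors.models import Document
from danswer.connectors.models import Section

doc = Document(
    id="https://example.com/page",  # placeholder id
    sections=[Section(link="https://example.com/page", text="Some page text " * 500)],
    metadata={},
)
chunks = DefaultChunker().chunk(doc)  # sections longer than CHUNK_SIZE are split with word overlap
for chunk in chunks:
    print(chunk.chunk_id, len(chunk.content), chunk.section_continuation)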

@@ -0,0 +1,41 @@
import inspect
from dataclasses import dataclass
from typing import Optional
from danswer.connectors.models import Document
@dataclass
class BaseChunk:
chunk_id: int
content: str
source_links: Optional[
dict[int, str]
] # Holds the link and the offsets into the raw Chunk text
section_continuation: bool # True if this Chunk's start is not at the start of a Section
@dataclass
class IndexChunk(BaseChunk):
source_document: Document
@dataclass
class EmbeddedIndexChunk(IndexChunk):
embedding: list[float]
@dataclass
class InferenceChunk(BaseChunk):
document_id: str
source_type: str
@classmethod
def from_dict(cls, init_dict):
return cls(
**{
k: v
for k, v in init_dict.items()
if k in inspect.signature(cls).parameters
}
)
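
A quick sketch of how InferenceChunk.from_dict is used downstream: it tolerates extra keys (e.g. permission fields) in a datastore payload by filtering on the dataclass signature. The payload values below are placeholders:

from danswer.chunking.models import InferenceChunk

payload = {
    "document_id": "doc-1",
    "chunk_id": 0,
    "content": "some chunk text",
    "source_links": {0: "https://example.com"},
    "section_continuation": False,
    "source_type": "Web",
    "allowed_users": [],  # extra key, silently dropped by from_dict
}
chunk = InferenceChunk.from_dict(payload)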

@@ -0,0 +1,73 @@
import os
#####
# App Configs
#####
APP_HOST = "0.0.0.0"
APP_PORT = 8080
#####
# Vector DB Configs
#####
# Url / Key are used to connect to a remote Qdrant instance
QDRANT_URL = os.environ.get("QDRANT_URL", "")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
# Host / Port are used for connecting to local Qdrant instance
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = 6333
QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "semantic_search")
DB_CONN_TIMEOUT = 2 # Timeout seconds connecting to DBs
INDEX_BATCH_SIZE = 16 # Number of files per indexing batch (not accounting for chunking within files)
#####
# Connector Configs
#####
GOOGLE_DRIVE_CREDENTIAL_JSON = os.environ.get("GOOGLE_DRIVE_CREDENTIAL_JSON", "")
GOOGLE_DRIVE_TOKENS_JSON = os.environ.get("GOOGLE_DRIVE_TOKENS_JSON", "")
GOOGLE_DRIVE_INCLUDE_SHARED = False
#####
# Query Configs
#####
DEFAULT_PROMPT = "generic-qa"
NUM_RETURNED_HITS = 15
NUM_RERANKED_RESULTS = 4
KEYWORD_MAX_HITS = 5
QUOTE_ALLOWED_ERROR_PERCENT = 0.05 # i.e. 1 edit allowed per 20 characters
#####
# Text Processing Configs
#####
# Chunk documents to roughly this many characters, not counting completion of the last word or the overlap words below
# Calculated from a max of ~500 to 512 tokens * an average of ~4 chars per token
CHUNK_SIZE = 2000
# Each chunk includes an additional 5 words from the previous chunk
# In extreme cases, this may cause some words at the end to be truncated by the embedding model
CHUNK_OVERLAP = 5
#####
# Other API Keys
#####
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
#####
BI_ENCODER_HOST = "localhost"
BI_ENCODER_PORT = 9000
CROSS_ENCODER_HOST = "localhost"
CROSS_ENCODER_PORT = 9000
#####
# Miscellaneous
#####
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
TYPESENSE_HOST = "localhost"
TYPESENSE_PORT = 8108
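
These settings are resolved once at import time (OPENAI_API_KEY is read eagerly and must be set). A small sketch, with a placeholder key, of how other modules consume them:

import os

os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # required before importing app_configs

from danswer.configs.app_configs import CHUNK_SIZE
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT

print(f"Qdrant at {QDRANT_HOST}:{QDRANT_PORT}, chunking to ~{CHUNK_SIZE} chars per chunk")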

@@ -0,0 +1,27 @@
from enum import Enum
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
CONTENT = "content"
SOURCE_TYPE = "source_type"
SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
class DocumentSource(Enum):
Slack = 1
Web = 2
GoogleDrive = 3
Unknown = 4
def __str__(self):
return self.name
def __int__(self):
return self.value
WEB_SOURCE = "Web"

@@ -0,0 +1,23 @@
import os
# Bi/Cross-Encoder Model Configs
# TODO: try 'all-distilroberta-v1' - its larger training set may carry more technical knowledge (768 dim)
# Downside: slower by a factor of 3 (model size)
# Important consideration: max tokens must be 512
DOCUMENT_ENCODER_MODEL = "multi-qa-MiniLM-L6-cos-v1"
DOC_EMBEDDING_DIM = 384 # Depends on the document encoder model
# L-12-v2 might be worth a try; reported stats are very similar, but L-12 is slower by a factor of 2
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
QUERY_EMBEDDING_CONTEXT_SIZE = 256
DOC_EMBEDDING_CONTEXT_SIZE = 512
CROSS_EMBED_CONTEXT_SIZE = 512
MODEL_CACHE_FOLDER = os.environ.get("TRANSFORMERS_CACHE")
# Purely an optimization to stay within memory limits
BATCH_SIZE_ENCODE_CHUNKS = 8
# OpenAI Model API Configs
OPENAPI_MODEL_VERSION = "text-davinci-003"
OPENAI_MAX_OUTPUT_TOKENS = 200

@@ -0,0 +1,135 @@
import io
import os
from collections.abc import Generator
from danswer.configs.app_configs import GOOGLE_DRIVE_CREDENTIAL_JSON
from danswer.configs.app_configs import GOOGLE_DRIVE_INCLUDE_SHARED
from danswer.configs.app_configs import GOOGLE_DRIVE_TOKENS_JSON
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
from danswer.utils.logging import setup_logger
from google.auth.transport.requests import Request # type: ignore
from google.oauth2.credentials import Credentials # type: ignore
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from googleapiclient import discovery # type: ignore
from PyPDF2 import PdfReader
logger = setup_logger()
SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
SUPPORTED_DRIVE_DOC_TYPES = [
"application/vnd.google-apps.document",
"application/pdf",
"application/vnd.google-apps.spreadsheet",
]
ID_KEY = "id"
LINK_KEY = "link"
TYPE_KEY = "type"
def get_credentials() -> Credentials:
creds = None
if os.path.exists(GOOGLE_DRIVE_TOKENS_JSON):
creds = Credentials.from_authorized_user_file(GOOGLE_DRIVE_TOKENS_JSON, SCOPES)
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
creds.refresh(Request())
else:
flow = InstalledAppFlow.from_client_secrets_file(
GOOGLE_DRIVE_CREDENTIAL_JSON, SCOPES
)
creds = flow.run_local_server()
with open(GOOGLE_DRIVE_TOKENS_JSON, "w") as token_file:
token_file.write(creds.to_json())
return creds
def get_file_batches(
service: discovery.Resource,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
batch_size: int = INDEX_BATCH_SIZE,
):
next_page_token = ""
while next_page_token is not None:
results = (
service.files()
.list(
pageSize=batch_size,
supportsAllDrives=include_shared,
fields="nextPageToken, files(mimeType, id, name, webViewLink)",
pageToken=next_page_token,
)
.execute()
)
next_page_token = results.get("nextPageToken")
files = results["files"]
valid_files = []
for file in files:
if file["mimeType"] in SUPPORTED_DRIVE_DOC_TYPES:
valid_files.append(file)
logger.info(
f"Parseable Documents in batch: {[file['name'] for file in valid_files]}"
)
yield valid_files
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
mime_type = file["mimeType"]
if mime_type == "application/vnd.google-apps.document":
return (
service.files()
.export(fileId=file["id"], mimeType="text/plain")
.execute()
.decode("utf-8")
)
elif mime_type == "application/vnd.google-apps.spreadsheet":
return (
service.files()
.export(fileId=file["id"], mimeType="text/csv")
.execute()
.decode("utf-8")
)
# Default: download the raw file and parse it as a PDF, since most remaining types can be exported as a PDF
else:
response = service.files().get_media(fileId=file["id"]).execute()
pdf_stream = io.BytesIO(response)
pdf_reader = PdfReader(pdf_stream)
return "\n".join(page.extract_text() for page in pdf_reader.pages)
class BatchGoogleDriveLoader(BatchLoader):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
include_shared: bool = GOOGLE_DRIVE_INCLUDE_SHARED,
) -> None:
self.batch_size = batch_size
self.include_shared = include_shared
self.creds = get_credentials()
def load(self) -> Generator[list[Document], None, None]:
service = discovery.build("drive", "v3", credentials=self.creds)
for files_batch in get_file_batches(
service, self.include_shared, self.batch_size
):
doc_batch = []
for file in files_batch:
text_contents = extract_text(file, service)
full_context = file["name"] + " " + text_contents
doc_batch.append(
Document(
id=file["webViewLink"],
sections=[Section(link=file["webViewLink"], text=full_context)],
metadata={SOURCE_TYPE: DocumentSource.GoogleDrive},
)
)
yield doc_batch
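
A hedged usage sketch for the loader above; its import path is not shown in this diff, so the class is referenced directly, and the credential/token JSON paths from app_configs are assumed to be configured:

# BatchGoogleDriveLoader is the class defined above; OAuth files must exist at
# GOOGLE_DRIVE_CREDENTIAL_JSON / GOOGLE_DRIVE_TOKENS_JSON for get_credentials() to succeed.
loader = BatchGoogleDriveLoader(batch_size=16, include_shared=False)
for doc_batch in loader.load():
    for doc in doc_batch:
        print(doc.id, len(doc.sections[0].text))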

@@ -0,0 +1,19 @@
from dataclasses import dataclass
from typing import Any
@dataclass
class Section:
link: str
text: str
@dataclass
class Document:
id: str
sections: list[Section]
metadata: dict[str, Any]
def get_raw_document_text(document: Document) -> str:
return "\n\n".join([section.text for section in document.sections])

@@ -0,0 +1,87 @@
import json
import os
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.type_aliases import BatchLoader
def _process_batch_event(
event: dict[str, Any],
matching_doc: Document | None,
workspace: str | None = None,
channel_id: str | None = None,
) -> Document | None:
if event["type"] == "message" and event.get("subtype") != "channel_join":
if matching_doc:
return Document(
id=matching_doc.id,
sections=matching_doc.sections
+ [
Section(
link=get_message_link(
event, workspace=workspace, channel_id=channel_id
),
text=event["text"],
)
],
metadata=matching_doc.metadata,
)
return Document(
id=event["ts"],
sections=[
Section(
link=get_message_link(
event, workspace=workspace, channel_id=channel_id
),
text=event["text"],
)
],
metadata={},
)
return None
class BatchSlackLoader(BatchLoader):
def __init__(
self, export_path_str: str, batch_size: int = INDEX_BATCH_SIZE
) -> None:
self.export_path_str = export_path_str
self.batch_size = batch_size
def load(self) -> Generator[list[Document], None, None]:
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
channels = json.load(f)
document_batch: dict[str, Document] = {}
for channel_info in channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
for file_name in os.listdir(channel_dir_path)
]
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for event in events:
doc = _process_batch_event(
event,
document_batch.get(event.get("thread_ts", "")),
channel_id=channel_info["id"],
)
if doc:
document_batch[doc.id] = doc
if len(document_batch) >= self.batch_size:
yield list(document_batch.values())
yield list(document_batch.values())
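
A usage sketch for the export-based loader above (import path not shown in this diff; the export directory is a placeholder pointing at a standard Slack workspace export):

# BatchSlackLoader is the class defined above.
loader = BatchSlackLoader(export_path_str="/path/to/slack-export", batch_size=16)
for doc_batch in loader.load():
    print(f"Loaded {len(doc_batch)} Slack documents from the export")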

@@ -0,0 +1,33 @@
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from pydantic import BaseModel
SLACK_CONFIG_KEY = "slack_connector_config"
class SlackConfig(BaseModel):
slack_bot_token: str
workspace_id: str
pull_frequency: int = 0 # in minutes, 0 => no pulling
def get_slack_config() -> SlackConfig:
slack_config = get_dynamic_config_store().load(SLACK_CONFIG_KEY)
return SlackConfig.parse_obj(slack_config)
def get_slack_bot_token() -> str:
return get_slack_config().slack_bot_token
def get_workspace_id() -> str:
return get_slack_config().workspace_id
def get_pull_frequency() -> int:
return get_slack_config().pull_frequency
def update_slack_config(slack_config: SlackConfig):
get_dynamic_config_store().store(SLACK_CONFIG_KEY, slack_config.dict())
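
A one-time setup sketch using the helpers above; the token and workspace id are placeholders, and the dynamic config store must be configured (see danswer.dynamic_configs):

from danswer.connectors.slack.config import SlackConfig
from danswer.connectors.slack.config import update_slack_config

update_slack_config(
    SlackConfig(
        slack_bot_token="xoxb-placeholder",  # bot token, placeholder value
        workspace_id="my-workspace",         # the <workspace>.slack.com subdomain
        pull_frequency=10,                   # minutes; 0 disables background pulling
    )
)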

@@ -0,0 +1,191 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import List
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.utils import get_client
from danswer.connectors.slack.utils import get_message_link
from danswer.connectors.type_aliases import PullLoader
from danswer.connectors.type_aliases import SecondsSinceUnixEpoch
from danswer.utils.logging import setup_logger
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.web import SlackResponse
logger = setup_logger()
SLACK_LIMIT = 900
MessageType = dict[str, Any]
# list of messages in a thread
ThreadType = list[MessageType]
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., list[dict[str, Any]]]:
"""Wraps calls to slack API so that they automatically handle pagination"""
def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
cursor: str | None = None
has_more = True
while has_more:
for result in call(cursor=cursor, limit=SLACK_LIMIT, **kwargs):
has_more = result.get("has_more", False)
cursor = result.get("response_metadata", {}).get("next_cursor", "")
results.append(cast(dict[str, Any], result))
return results
return paginated_call
def _make_slack_api_rate_limited(
call: Callable[..., SlackResponse], max_retries: int = 3
) -> Callable[..., SlackResponse]:
"""Wraps calls to slack API so that they automatically handle rate limiting"""
def rate_limited_call(**kwargs: Any) -> SlackResponse:
for _ in range(max_retries):
try:
# Make the API call
response = call(**kwargs)
# Check for errors in the response
if response.get("ok"):
return response
else:
raise SlackApiError("", response)
except SlackApiError as e:
if e.response["error"] == "ratelimited":
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
retry_after = int(e.response.headers.get("Retry-After", 1))
time.sleep(retry_after)
else:
# Raise the error for non-transient errors
raise
# If the code reaches this point, all retries have been exhausted
raise Exception(f"Max retries ({max_retries}) exceeded")
return rate_limited_call
def _make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> list[dict[str, Any]]:
return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)
def get_channels(client: WebClient) -> list[dict[str, Any]]:
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
for result in _make_slack_api_call(client.conversations_list):
channels.extend(result["channels"])
return channels
def get_channel_messages(
client: WebClient, channel: dict[str, Any]
) -> list[MessageType]:
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
client.conversations_join(
channel=channel["id"], is_private=channel["is_private"]
)
messages: list[MessageType] = []
for result in _make_slack_api_call(
client.conversations_history, channel=channel["id"]
):
messages.extend(result["messages"])
return messages
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in _make_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
return threads
def _default_msg_filter(message: MessageType) -> bool:
return message.get("subtype", "") == "channel_join"
def get_all_threads(
client: WebClient,
msg_filter_func: Callable[[MessageType], bool] = _default_msg_filter,
) -> dict[str, list[ThreadType]]:
"""Get all threads in the workspace"""
channels = get_channels(client)
channel_id_to_messages: dict[str, list[dict[str, Any]]] = {}
for channel in channels:
channel_id_to_messages[channel["id"]] = get_channel_messages(client, channel)
channel_to_threads: dict[str, list[ThreadType]] = {}
for channel_id, messages in channel_id_to_messages.items():
final_threads: list[ThreadType] = []
for message in messages:
thread_ts = message.get("thread_ts")
if thread_ts:
thread = get_thread(client, channel_id, thread_ts)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
if filtered_thread:
final_threads.append(filtered_thread)
else:
final_threads.append([message])
channel_to_threads[channel_id] = final_threads
return channel_to_threads
def thread_to_doc(channel_id: str, thread: ThreadType) -> Document:
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
sections=[
Section(
link=get_message_link(m, channel_id=channel_id),
text=cast(str, m["text"]),
)
for m in thread
],
metadata={},
)
def get_all_docs(client: WebClient) -> list[Document]:
"""Get all documents in the workspace"""
channel_id_to_threads = get_all_threads(client)
docs: list[Document] = []
for channel_id, threads in channel_id_to_threads.items():
docs.extend(thread_to_doc(channel_id, thread) for thread in threads)
logger.info(f"Pulled {len(docs)} documents from slack")
return docs
class SlackPullLoader(PullLoader):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.client = get_client()
self.batch_size = batch_size
def load(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
# TODO: make this respect start and end
all_docs = get_all_docs(self.client)
for i in range(0, len(all_docs), self.batch_size):
yield all_docs[i : i + self.batch_size]
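
A usage sketch for SlackPullLoader; note that the start/end bounds are currently ignored (see the TODO above), so this pulls everything reachable by the bot:

import time

from danswer.connectors.slack.pull import SlackPullLoader

for doc_batch in SlackPullLoader(batch_size=16).load(0, int(time.time())):
    print(f"Pulled {len(doc_batch)} Slack documents")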

@@ -0,0 +1,22 @@
from typing import Any
from typing import cast
from danswer.connectors.slack.config import get_slack_bot_token
from danswer.connectors.slack.config import get_workspace_id
from slack_sdk import WebClient
def get_client() -> WebClient:
"""Returns a WebClient using the bot token stored in the Slack connector config (dynamic config store)"""
return WebClient(token=get_slack_bot_token())
def get_message_link(
event: dict[str, Any], workspace: str | None = None, channel_id: str | None = None
) -> str:
channel_id = channel_id or cast(
str, event["channel"]
) # channel must either be present in the event or passed in
message_ts = cast(str, event["ts"])
message_ts_without_dot = message_ts.replace(".", "")
return f"https://{workspace or get_workspace_id()}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"

@@ -0,0 +1,41 @@
import abc
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from typing import Any
from typing import List
from typing import Optional
from danswer.connectors.models import Document
ConnectorConfig = dict[str, Any]
# takes in the raw representation of a document from a source and returns a
# Document object
ProcessDocumentFunc = Callable[..., Document]
BuildListenerFunc = Callable[[ConnectorConfig], ProcessDocumentFunc]
class BatchLoader:
@abc.abstractmethod
def load(self) -> Generator[List[Document], None, None]:
raise NotImplementedError
SecondsSinceUnixEpoch = int
class PullLoader:
@abc.abstractmethod
def load(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> Generator[List[Document], None, None]:
raise NotImplementedError
# Fetches raw representations from a specific source for the specified time
# range. Is used when the source does not support subscriptions to some sort
# of event stream
# TODO: use Protocol instead of Callable
TimeRangeBasedLoad = Callable[[datetime, datetime], list[Any]]

@@ -0,0 +1,117 @@
from collections.abc import Generator
from typing import Any
from typing import cast
from urllib.parse import urljoin
from urllib.parse import urlparse
from bs4 import BeautifulSoup
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.type_aliases import BatchLoader
from danswer.utils.logging import setup_logger
from playwright.sync_api import sync_playwright
logger = setup_logger()
TAG_SEPARATOR = "\n"
def is_valid_url(url: str) -> bool:
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False
def get_internal_links(
base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True
) -> list[str]:
internal_links = []
for link in cast(list[dict[str, Any]], soup.find_all("a")):
href = cast(str | None, link.get("href"))
if not href:
continue
if should_ignore_pound and "#" in href:
href = href.split("#")[0]
if not is_valid_url(href):
href = urljoin(url, href)
if urlparse(href).netloc == urlparse(url).netloc and base_url in href:
internal_links.append(href)
return internal_links
class BatchWebLoader(BatchLoader):
def __init__(
self,
base_url: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_url = base_url
self.batch_size = batch_size
def load(self) -> Generator[list[Document], None, None]:
"""Traverses through all pages found on the website
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = [self.base_url]
doc_batch: list[Document] = []
with sync_playwright() as playwright:
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
continue
visited_links.add(current_url)
try:
page = context.new_page()
page.goto(current_url)
content = page.content()
soup = BeautifulSoup(content, "html.parser")
# Heuristics-based cleaning
for undesired_tag in ["nav", "header", "footer", "meta"]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
for undesired_div in ["sidebar", "header", "footer"]:
[
tag.extract()
for tag in soup.find_all("div", {"class": undesired_div})
]
page_text = soup.get_text(TAG_SEPARATOR)
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
metadata={SOURCE_TYPE: DocumentSource.Web},
)
)
internal_links = get_internal_links(
self.base_url, current_url, soup
)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
page.close()
except Exception as e:
logger.error(f"Failed to fetch '{current_url}': {e}")
continue
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
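
A usage sketch for the crawler above (import path not shown in this diff, so the class is referenced directly; the base URL is a placeholder). Playwright's chromium browser must be installed for the launch to succeed:

# BatchWebLoader is the class defined above.
loader = BatchWebLoader(base_url="https://docs.example.com", batch_size=16)
for doc_batch in loader.load():
    print(f"Scraped {len(doc_batch)} pages")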

@@ -0,0 +1,12 @@
from typing import Type
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.store import QdrantDatastore
def get_selected_datastore_cls() -> Type[Datastore]:
"""Returns the selected Datastore cls. Only one datastore
should be selected for a specific deployment."""
# TODO: when more datastores are added, look at env variable to
# figure out which one should be returned
return QdrantDatastore

@@ -0,0 +1,14 @@
import abc
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import InferenceChunk
class Datastore:
@abc.abstractmethod
def index(self, chunks: list[EmbeddedIndexChunk]) -> bool:
raise NotImplementedError
@abc.abstractmethod
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
raise NotImplementedError

@@ -0,0 +1,85 @@
import uuid
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from qdrant_client import QdrantClient
from qdrant_client.http.models.models import UpdateStatus
from qdrant_client.models import Distance
from qdrant_client.models import PointStruct
from qdrant_client.models import VectorParams
logger = setup_logger()
DEFAULT_BATCH_SIZE = 30
def recreate_collection(collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM):
logger.info(f"Attempting to recreate collection {collection_name}")
result = get_qdrant_client().recreate_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
if not result:
raise RuntimeError("Could not create Qdrant collection")
def index_chunks(
chunks: list[EmbeddedIndexChunk],
collection: str,
client: QdrantClient | None = None,
batch_upsert: bool = False,
) -> bool:
if client is None:
client = get_qdrant_client()
point_structs = []
for chunk in chunks:
document = chunk.source_document
point_structs.append(
PointStruct(
id=str(uuid.uuid4()),
payload={
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
CONTENT: chunk.content,
SOURCE_TYPE: str(
document.metadata.get("source_type", DocumentSource.Unknown)
),
SOURCE_LINKS: chunk.source_links,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: [], # TODO
ALLOWED_GROUPS: [], # TODO
},
vector=chunk.embedding,
)
)
index_results = None
if batch_upsert:
point_struct_batches = [
point_structs[x : x + DEFAULT_BATCH_SIZE]
for x in range(0, len(point_structs), DEFAULT_BATCH_SIZE)
]
for point_struct_batch in point_struct_batches:
index_results = client.upsert(
collection_name=collection, points=point_struct_batch
)
logger.info(
f"Indexing {len(point_struct_batch)} chunks into collection '{collection}', "
f"status: {index_results.status}"
)
else:
index_results = client.upsert(collection_name=collection, points=point_structs)
logger.info(f"Batch indexing status: {index_results.status}")
return index_results is not None and index_results.status == UpdateStatus.COMPLETED

@@ -0,0 +1,55 @@
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.indexing import index_chunks
from danswer.embedding.biencoder import get_default_model
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from qdrant_client.http.models import FieldCondition
from qdrant_client.http.models import Filter
from qdrant_client.http.models import MatchValue
logger = setup_logger()
class QdrantDatastore(Datastore):
def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None:
self.collection = collection
self.client = get_qdrant_client()
def index(self, chunks: list[EmbeddedIndexChunk]) -> bool:
return index_chunks(
chunks=chunks, collection=self.collection, client=self.client
)
def search(self, query: str, num_to_retrieve: int) -> list[InferenceChunk]:
query_embedding = get_default_model().encode(
query
) # TODO: make this part of the embedder interface
hits = self.client.search(
collection_name=self.collection,
query_vector=query_embedding
if isinstance(query_embedding, list)
else query_embedding.tolist(),
query_filter=None,
limit=num_to_retrieve,
)
return [InferenceChunk.from_dict(hit.payload) for hit in hits]
def get_from_id(self, object_id: str) -> InferenceChunk | None:
matches, _ = self.client.scroll(
collection_name=self.collection,
scroll_filter=Filter(
should=[FieldCondition(key="id", match=MatchValue(value=object_id))]
),
)
if not matches:
return None
if len(matches) > 1:
logger.error(f"Found multiple matches for {object_id}: {matches}")
match = matches[0]
return InferenceChunk.from_dict(match.payload)
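
An end-to-end datastore sketch, assuming a Qdrant instance is reachable with the settings from app_configs; the query string is illustrative and indexing is shown commented out since it needs embedded chunks:

from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.datastores.qdrant.indexing import recreate_collection
from danswer.datastores.qdrant.store import QdrantDatastore

recreate_collection(QDRANT_DEFAULT_COLLECTION)  # drops and re-creates the vector collection
datastore = QdrantDatastore(collection=QDRANT_DEFAULT_COLLECTION)
# datastore.index(embedded_chunks)  # embedded_chunks: list[EmbeddedIndexChunk] from the embedding step
results = datastore.search("How do I set up the Slack connector?", num_to_retrieve=4)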

@@ -0,0 +1,28 @@
DOC_SEP_PAT = "---NEW DOCUMENT---"
QUESTION_PAT = "Query:"
ANSWER_PAT = "Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
def generic_prompt_processor(question: str, documents: list[str]) -> str:
prompt = (
f"Answer the query based on the documents below and quote the documents sections containing "
f'the answer. Respond with one "{ANSWER_PAT}" section and one or more "{QUOTE_PAT}" sections. '
f"For each quote, only include text exactly from the documents, don't include the source. "
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for document in documents:
prompt += f"\n{DOC_SEP_PAT}\n{document}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += f"{ANSWER_PAT}\n"
return prompt
BASIC_QA_PROMPTS = {
"generic-qa": generic_prompt_processor,
}
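
A quick illustration of the prompt layout produced by the processor above; the question and documents are made-up examples:

from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS

prompt = BASIC_QA_PROMPTS["generic-qa"](
    "What is Danswer?",
    ["Danswer is an open source question-answering tool.", "It indexes workplace documents."],
)
print(prompt)  # ends with "Query:\nWhat is Danswer?\nAnswer:\n"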

@@ -0,0 +1,166 @@
import math
import re
from collections.abc import Callable
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
import openai
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAPI_MODEL_VERSION
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
openai.api_key = OPENAI_API_KEY
def ask_openai(
complete_qa_prompt: str,
model: str = OPENAPI_MODEL_VERSION,
max_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
) -> str:
try:
response = openai.Completion.create(
prompt=complete_qa_prompt,
temperature=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
model=model,
max_tokens=max_tokens,
)
model_answer = response["choices"][0]["text"].strip()
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
return model_answer
except Exception as e:
logger.exception(e)
return "Model Failure"
def answer_question(
query: str,
context_docs: list[str],
prompt_processor: Callable[[str, list[str]], str],
) -> str:
formatted_prompt = prompt_processor(query, context_docs)
logger.debug(formatted_prompt)
return ask_openai(formatted_prompt)
def separate_answer_quotes(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
"""Gives back the answer and quote sections"""
null_answer_check = (
answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
)
# If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all
if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
return None, None
# If no answer section, don't care about the quote
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
return None, None
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
answer_raw = answer_raw[len(ANSWER_PAT) :]
# Accept quote sections starting with the lower case version
answer_raw = answer_raw.replace(
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
) # Just in case model unreliable
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
sections_clean = [
str(section).strip() for section in sections if str(section).strip()
]
if not sections_clean:
return None, None
answer = str(sections_clean[0])
if len(sections) == 1:
return answer, None
return answer, sections_clean[1:]
def match_quotes_to_docs(
quotes: list[str],
chunks: list[InferenceChunk],
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> Dict[str, Dict[str, Union[str, int, None]]]:
quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
for chunk in chunks:
if not chunk.source_links:
continue
quote_clean = shared_precompare_cleanup(
clean_model_quote(quote, trim_length=prefix_only_length)
)
chunk_clean = shared_precompare_cleanup(chunk.content)
# Finding the offset of the quote in the plain text
if fuzzy_search:
re_search_str = (
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
)
found = regex.search(re_search_str, chunk_clean)
if not found:
continue
offset = found.span()[0]
else:
if quote_clean not in chunk_clean:
continue
offset = chunk_clean.index(quote_clean)
# Extracting the link from the offset
curr_link = None
for link_offset, link in chunk.source_links.items():
# Should always find one because offset is at least 0 and there must be a 0 link_offset
if int(link_offset) <= offset:
curr_link = link
else:
quotes_dict[quote] = {
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
}
break
quotes_dict[quote] = {
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
}
break
return quotes_dict
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
) -> Tuple[Optional[str], Optional[Dict[str, Dict[str, Union[str, int, None]]]]]:
answer, quote_strings = separate_answer_quotes(answer_raw)
if not answer or not quote_strings:
return None, None
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
return answer, quotes_dict
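
A small example of the parsing helper above on a typical model completion (strings are illustrative):

from danswer.direct_qa.question_answer import separate_answer_quotes

raw = "Danswer indexes your docs.\nQuote: indexes workplace documents"
answer, quotes = separate_answer_quotes(raw)
# answer == "Danswer indexes your docs."
# quotes == ["indexes workplace documents"]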

@@ -0,0 +1,89 @@
from typing import List
import openai
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import MODEL_CACHE_FOLDER
from danswer.configs.model_configs import QUERY_EMBEDDING_CONTEXT_SIZE
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from danswer.utils.timing import build_timing_wrapper
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
openai.api_key = OPENAI_API_KEY
embedding_model = SentenceTransformer(
DOCUMENT_ENCODER_MODEL, cache_folder=MODEL_CACHE_FOLDER
)
embedding_model.max_seq_length = QUERY_EMBEDDING_CONTEXT_SIZE
cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL)
cross_encoder.max_length = CROSS_EMBED_CONTEXT_SIZE
@build_timing_wrapper()
def semantic_retrival(
qdrant_collection: str,
query: str,
num_hits: int = NUM_RETURNED_HITS,
use_openai: bool = False,
) -> List[InferenceChunk]:
if use_openai:
query_embedding = openai.Embedding.create(
input=query, model="text-embedding-ada-002"
)["data"][0]["embedding"]
else:
query_embedding = embedding_model.encode(query)
hits = get_qdrant_client().search(
collection_name=qdrant_collection,
query_vector=query_embedding
if isinstance(query_embedding, list)
else query_embedding.tolist(),
query_filter=None,
limit=num_hits,
)
retrieved_chunks = []
for hit in hits:
payload = hit.payload
retrieved_chunks.append(InferenceChunk.from_dict(payload))
return retrieved_chunks
@build_timing_wrapper()
def semantic_reranking(
query: str,
chunks: List[InferenceChunk],
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
) -> List[InferenceChunk]:
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])
scored_results = list(zip(sim_scores, chunks))
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_chunks = zip(*scored_results)
logger.debug(
f"Reranked similarity scores: {str(ranked_sim_scores[:filtered_result_set_size])}"
)
return ranked_chunks[:filtered_result_set_size]
def semantic_search(
qdrant_collection: str,
query: str,
num_hits: int = NUM_RETURNED_HITS,
filtered_result_set_size: int = NUM_RERANKED_RESULTS,
) -> List[InferenceChunk]:
top_chunks = semantic_retrival(qdrant_collection, query, num_hits)
ranked_chunks = semantic_reranking(query, top_chunks, filtered_result_set_size)
return ranked_chunks
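
A minimal retrieval sketch against the default collection; importing this module loads the bi-encoder and cross-encoder into memory, so the first call is slow, and the query is illustrative:

from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.direct_qa.semantic_search import semantic_search

chunks = semantic_search(QDRANT_DEFAULT_COLLECTION, "How do I add a Slack connector?")
for chunk in chunks:
    print(chunk.document_id, chunk.content[:80])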

@@ -0,0 +1,15 @@
import os
from danswer.dynamic_configs.file_system.store import (
FileSystemBackedDynamicConfigStore,
)
from danswer.dynamic_configs.interface import DynamicConfigStore
def get_dynamic_config_store() -> DynamicConfigStore:
dynamic_config_store_type = os.environ.get("DYNAMIC_CONFIG_STORE")
if dynamic_config_store_type == FileSystemBackedDynamicConfigStore.__name__:
return FileSystemBackedDynamicConfigStore(os.environ["DYNAMIC_CONFIG_DIR_PATH"])
# TODO: change exception type
raise Exception("Unknown dynamic config store type")

@@ -0,0 +1,38 @@
import json
from pathlib import Path
from typing import cast
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
from danswer.dynamic_configs.interface import JSON_ro
from filelock import FileLock
FILE_LOCK_TIMEOUT = 10
def _get_file_lock(file_name: Path) -> FileLock:
return FileLock(file_name.with_suffix(".lock"))
class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
def __init__(self, dir_path: str) -> None:
# TODO (chris): maybe require all possible keys to be passed in
# at app start somehow to prevent key overlaps
self.dir_path = Path(dir_path)
def store(self, key: str, val: JSON_ro) -> None:
file_path = self.dir_path / key
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(file_path, "w+") as f:
json.dump(val, f)
def load(self, key: str) -> JSON_ro:
file_path = self.dir_path / key
if not file_path.exists():
raise ConfigNotFoundError
lock = _get_file_lock(file_path)
with lock.acquire(timeout=FILE_LOCK_TIMEOUT):
with open(self.dir_path / key) as f:
return cast(JSON_ro, json.load(f))

@@ -0,0 +1,23 @@
import abc
from collections.abc import Mapping
from collections.abc import Sequence
from typing import TypeAlias
JSON_ro: TypeAlias = (
Mapping[str, "JSON_ro"] | Sequence["JSON_ro"] | str | int | float | bool | None
)
class ConfigNotFoundError(Exception):
pass
class DynamicConfigStore:
@abc.abstractmethod
def store(self, key: str, val: JSON_ro) -> None:
raise NotImplementedError
@abc.abstractmethod
def load(self, key: str) -> JSON_ro:
raise NotImplementedError
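
A usage sketch for the dynamic config store: the file-system implementation is selected via environment variables, and the directory path below is a placeholder:

import os

os.environ.setdefault("DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore")
os.environ.setdefault("DYNAMIC_CONFIG_DIR_PATH", "/tmp/danswer_dynamic_configs")  # placeholder dir
os.makedirs(os.environ["DYNAMIC_CONFIG_DIR_PATH"], exist_ok=True)

from danswer.dynamic_configs import get_dynamic_config_store

store = get_dynamic_config_store()
store.store("last_pull_SlackPullLoader", 1682700000)  # same key format the background updater uses
print(store.load("last_pull_SlackPullLoader"))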

@@ -0,0 +1,55 @@
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.embedding.type_aliases import Embedder
from danswer.utils.logging import setup_logger
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
_MODEL: None | SentenceTransformer = None
def get_default_model() -> SentenceTransformer:
global _MODEL
if _MODEL is None:
_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
return _MODEL
def encode_chunks(
chunks: list[IndexChunk],
embedding_model: SentenceTransformer | None = None,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[EmbeddedIndexChunk]:
embedded_chunks = []
if embedding_model is None:
embedding_model = get_default_model()
chunk_batches = [
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
]
for batch_ind, chunk_batch in enumerate(chunk_batches):
embeddings_batch = embedding_model.encode(
[chunk.content for chunk in chunk_batch]
)
embedded_chunks.extend(
[
EmbeddedIndexChunk(
**{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
embedding=embeddings_batch[i].tolist()
)
for i, chunk in enumerate(chunk_batch)
]
)
return embedded_chunks
class DefaultEmbedder(Embedder):
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
return encode_chunks(chunks)

@@ -0,0 +1,7 @@
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
class Embedder:
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
raise NotImplementedError

backend/danswer/main.py
@@ -0,0 +1,35 @@
import uvicorn
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.server.admin import router as admin_router
from danswer.server.event_loading import router as event_processing_router
from danswer.server.search_backend import router as backend_router
from danswer.utils.logging import setup_logger
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
logger = setup_logger()
def get_application() -> FastAPI:
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
application.include_router(backend_router)
application.include_router(event_processing_router)
application.include_router(admin_router)
return application
app = get_application()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
if __name__ == "__main__":
logger.info(f"Running QA Service on http://{APP_HOST}:{str(APP_PORT)}/")
uvicorn.run(app, host=APP_HOST, port=APP_PORT)

@@ -0,0 +1,23 @@
from danswer.connectors.slack.config import get_slack_config
from danswer.connectors.slack.config import SlackConfig
from danswer.connectors.slack.config import update_slack_config
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
router = APIRouter(prefix="/admin")
logger = setup_logger()
@router.get("/slack_connector_config", response_model=SlackConfig)
def fetch_slack_config():
try:
return get_slack_config()
except ConfigNotFoundError:
return SlackConfig(slack_bot_token="", workspace_id="")
@router.post("/slack_connector_config")
def modify_slack_config(slack_config: SlackConfig):
update_slack_config(slack_config)

@@ -0,0 +1,61 @@
from typing import Any
from danswer.connectors.slack.pull import get_thread
from danswer.connectors.slack.pull import thread_to_doc
from danswer.connectors.slack.utils import get_client
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from pydantic import BaseModel
from pydantic import Extra
router = APIRouter()
logger = setup_logger()
class SlackEvent(BaseModel, extra=Extra.allow):
type: str
challenge: str | None
event: dict[str, Any] | None
class EventHandlingResponse(BaseModel):
challenge: str | None
@router.post("/process_slack_event", response_model=EventHandlingResponse)
def process_slack_event(event: SlackEvent):
logger.info("Received slack event: %s", event.dict())
if event.type == "url_verification":
return {"challenge": event.challenge}
if event.type == "event_callback" and event.event:
try:
# TODO: process in the background as per slack guidelines
message_type = event.event.get("subtype")
if message_type == "message_changed":
message = event.event["message"]
else:
message = event.event
channel_id = event.event["channel"]
thread_ts = message.get("thread_ts")
doc = thread_to_doc(
channel_id,
get_thread(get_client(), channel_id, thread_ts)
if thread_ts
else [message],
)
if doc is None:
logger.info("Message was determined to not be indexable")
return {}
build_indexing_pipeline()([doc])
except Exception:
logger.exception("Failed to process slack message")
return {}
logger.error("Unsupported event type: %s", event.type)
return {}

@@ -0,0 +1,109 @@
import time
from http import HTTPStatus
from typing import Dict
from typing import List
from typing import Union
from danswer.configs.app_configs import DEFAULT_PROMPT
from danswer.configs.app_configs import KEYWORD_MAX_HITS
from danswer.configs.constants import CONTENT
from danswer.configs.constants import SOURCE_LINKS
from danswer.direct_qa.qa_prompts import BASIC_QA_PROMPTS
from danswer.direct_qa.question_answer import answer_question
from danswer.direct_qa.question_answer import process_answer
from danswer.direct_qa.semantic_search import semantic_search
from danswer.utils.clients import TSClient
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from pydantic import BaseModel
logger = setup_logger()
router = APIRouter()
class ServerStatus(BaseModel):
status: str
class QAQuestion(BaseModel):
query: str
collection: str
class QAResponse(BaseModel):
answer: Union[str, None]
quotes: Union[Dict[str, Dict[str, str]], None]
class KeywordResponse(BaseModel):
results: Union[List[str], None]
@router.get("/", response_model=ServerStatus)
@router.get("/status", response_model=ServerStatus)
def read_server_status():
return {"status": HTTPStatus.OK.value}
@router.post("/direct-qa", response_model=QAResponse)
def direct_qa(question: QAQuestion):
prompt_processor = BASIC_QA_PROMPTS[DEFAULT_PROMPT]
query = question.query
collection = question.collection
logger.info(f"Received semantic query: {query}")
start_time = time.time()
ranked_chunks = semantic_search(collection, query)
sem_search_time = time.time()
top_docs = [ranked_chunk.document_id for ranked_chunk in ranked_chunks]
top_contents = [ranked_chunk.content for ranked_chunk in ranked_chunks]
logger.info(f"Semantic search took {sem_search_time - start_time} seconds")
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
logger.info(files_log_msg)
qa_answer = answer_question(query, top_contents, prompt_processor)
qa_time = time.time()
logger.debug(qa_answer)
logger.info(f"GPT QA took {qa_time - sem_search_time} seconds")
# Postprocessing: no more models involved, purely rule-based
answer, quotes_dict = process_answer(qa_answer, ranked_chunks)
postprocess_time = time.time()
logger.info(f"Postprocessing took {postprocess_time - qa_time} seconds")
total_time = time.time() - start_time
logger.info(f"Total QA took {total_time} seconds")
qa_response = {"answer": answer, "quotes": quotes_dict}
return qa_response
@router.post("/keyword-search", response_model=KeywordResponse)
def keyword_search(question: QAQuestion):
ts_client = TSClient.get_instance()
query = question.query
collection = question.collection
logger.info(f"Received keyword query: {query}")
start_time = time.time()
search_results = ts_client.collections[collection].documents.search(
{
"q": query,
"query_by": CONTENT,
"per_page": KEYWORD_MAX_HITS,
"limit_hits": KEYWORD_MAX_HITS,
}
)
hits = search_results["hits"]
sources = [hit["document"][SOURCE_LINKS][0] for hit in hits]
total_time = time.time() - start_time
logger.info(f"Total Keyword Search took {total_time} seconds")
return {"results": sources}
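
An illustrative client call against the /direct-qa route above, assuming the default local deployment on port 8080; the third-party requests package is used here only for brevity and is not part of this diff:

import requests

resp = requests.post(
    "http://localhost:8080/direct-qa",
    json={"query": "How do I deploy Danswer?", "collection": "semantic_search"},
)
print(resp.json())  # expected shape: {"answer": ..., "quotes": ...}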

@@ -0,0 +1,62 @@
from typing import Optional
import typesense # type: ignore
from danswer.configs.app_configs import DB_CONN_TIMEOUT
from danswer.configs.app_configs import QDRANT_API_KEY
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT
from danswer.configs.app_configs import QDRANT_URL
from danswer.configs.app_configs import TYPESENSE_API_KEY
from danswer.configs.app_configs import TYPESENSE_HOST
from danswer.configs.app_configs import TYPESENSE_PORT
from qdrant_client import QdrantClient
_qdrant_client: Optional[QdrantClient] = None
def get_qdrant_client() -> QdrantClient:
global _qdrant_client
if _qdrant_client is None:
if QDRANT_URL and QDRANT_API_KEY:
_qdrant_client = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API_KEY)
elif QDRANT_HOST and QDRANT_PORT:
_qdrant_client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT)
else:
raise Exception("Unable to instantiate QdrantClient")
return _qdrant_client
class TSClient:
__instance: Optional["TSClient"] = None
@staticmethod
def get_instance(
host=TYPESENSE_HOST,
port=TYPESENSE_PORT,
api_key=TYPESENSE_API_KEY,
timeout=DB_CONN_TIMEOUT,
) -> "TSClient":
if TSClient.__instance is None:
TSClient(host, port, api_key, timeout)
return TSClient.__instance # type: ignore
def __init__(self, host, port, api_key, timeout):
if TSClient.__instance is not None:
raise Exception(
"Singleton instance already exists. Use TSClient.get_instance() to get the instance."
)
else:
TSClient.__instance = self
self.client = typesense.Client(
{
"api_key": api_key,
"nodes": [{"host": host, "port": str(port), "protocol": "http"}],
"connection_timeout_seconds": timeout,
}
)
# delegate all client operations to the third party client
def __getattr__(self, name):
return getattr(self.client, name)

@@ -0,0 +1,43 @@
from collections.abc import Callable
from functools import partial
from itertools import chain
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.connectors.models import Document
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.store import QdrantDatastore
from danswer.embedding.biencoder import DefaultEmbedder
from danswer.embedding.type_aliases import Embedder
def _indexing_pipeline(
chunker: Chunker,
embedder: Embedder,
datastore: Datastore,
documents: list[Document],
) -> None:
chunks = list(chain(*[chunker.chunk(document) for document in documents]))
chunks_with_embeddings = embedder.embed(chunks)
datastore.index(chunks_with_embeddings)
def build_indexing_pipeline(
*,
chunker: Chunker | None = None,
embedder: Embedder | None = None,
datastore: Datastore | None = None,
) -> Callable[[list[Document]], None]:
"""Builds a pipeline which takes in a list of docs and indexes them.
Default uses _ chunker, _ embedder, and qdrant for the datastore"""
if chunker is None:
chunker = DefaultChunker()
if embedder is None:
embedder = DefaultEmbedder()
if datastore is None:
datastore = QdrantDatastore()
return partial(_indexing_pipeline, chunker, embedder, datastore)
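
Putting the pieces together: a sketch that chunks, embeds, and indexes one placeholder document with the default pipeline components (Qdrant must be reachable and the target collection must already exist):

from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.indexing_pipeline import build_indexing_pipeline

doc = Document(
    id="https://example.com/faq",  # placeholder id
    sections=[Section(link="https://example.com/faq", text="Frequently asked questions about Danswer.")],
    metadata={},
)
build_indexing_pipeline()([doc])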

@@ -0,0 +1,24 @@
import logging
def setup_logger(name=__name__, log_level=logging.INFO):
logger = logging.getLogger(name)
# If the logger already has handlers, assume it was already configured and return it.
if logger.handlers:
return logger
logger.setLevel(log_level)
formatter = logging.Formatter(
"%(asctime)s %(filename)20s%(lineno)4s : %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
handler = logging.StreamHandler()
handler.setLevel(log_level)
handler.setFormatter(formatter)
logger.addHandler(handler)
return logger

View File

@ -0,0 +1,19 @@
def clean_model_quote(quote: str, trim_length: int) -> str:
quote_clean = quote.strip()
if quote_clean[0] == '"':
quote_clean = quote_clean[1:]
if quote_clean[-1] == '"':
quote_clean = quote_clean[:-1]
if trim_length > 0:
quote_clean = quote_clean[:trim_length]
return quote_clean
def shared_precompare_cleanup(text: str) -> str:
text = text.lower()
text = "".join(
text.split()
) # GPT models like to return cleaner spacing, not good for quote matching
return text.replace(
"*", ""
) # GPT models sometimes like to cleanup bulletpoints represented by *
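
A couple of illustrative calls for the helpers above (the outputs follow directly from the implementations):

```python
print(clean_model_quote('  "A very long quoted snippet"  ', trim_length=6))
# -> 'A very'
print(shared_precompare_cleanup("* A Dog is man's best friend"))
# -> "adogisman'sbestfriend"
```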

View File

@ -0,0 +1,31 @@
import time
from collections.abc import Callable
from danswer.utils.logging import setup_logger
logger = setup_logger()
def build_timing_wrapper(
func_name: str | None = None,
) -> Callable[[Callable], Callable]:
"""Build a timing wrapper for a function. Logs how long the function took to run.
Use like:
@build_timing_wrapper()
def my_func():
...
"""
def timing_wrapper(func: Callable) -> Callable:
def wrapped_func(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
logger.info(
f"{func_name or func.__name__} took {time.time() - start_time} seconds"
)
return result
return wrapped_func
return timing_wrapper
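
For instance, the wrapper can be applied with an explicit name that overrides the wrapped function's own (a sketch; `keyword_search` is a hypothetical function used only for illustration):

```python
@build_timing_wrapper("Keyword Search")
def keyword_search(query: str) -> list[str]:
    # ... run the search ...
    return []

keyword_search("how do I deploy this?")
# logs e.g.: "Keyword Search took 0.0123 seconds"
```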

View File

@ -0,0 +1,5 @@
This serves as an example of how to deploy everything on a single box. This setup is far
from optimal, but it can get you started easily and cheaply. To run:
1. Set up a `.env` file in this directory with relevant environment variables (TODO: document)
2. `docker compose up -d --build`

View File

@ -0,0 +1,43 @@
upstream app_server {
# fail_timeout=0 means we always retry an upstream even if it failed
# to return a good HTTP response
# for UNIX domain socket setups
#server unix:/tmp/gunicorn.sock fail_timeout=0;
# for a TCP configuration
server api:8080 fail_timeout=0;
}
server {
listen 80;
server_name api.danswer.dev;
location / {
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header Host $http_host;
# we don't want nginx trying to do something clever with
# redirects, we set the Host: header above already.
proxy_redirect off;
proxy_pass http://app_server;
}
location /.well-known/acme-challenge/ {
root /var/www/certbot;
}
}
server {
listen 443 ssl;
server_name api.danswer.dev;
location / {
proxy_pass http://api.danswer.dev;
}
ssl_certificate /etc/letsencrypt/live/api.danswer.dev/fullchain.pem;
ssl_certificate_key /etc/letsencrypt/live/api.danswer.dev/privkey.pem;
include /etc/letsencrypt/options-ssl-nginx.conf;
ssl_dhparam /etc/letsencrypt/ssl-dhparams.pem;
}

View File

@ -0,0 +1,42 @@
# follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
version: '3'
services:
api:
build:
context: ..
dockerfile: Dockerfile
# just for local testing
ports:
- "8080:8080"
env_file:
- .env
volumes:
- local_dynamic_storage:/home/storage
background:
build:
context: ..
dockerfile: Dockerfile.background
env_file:
- .env
volumes:
- local_dynamic_storage:/home/storage
nginx:
image: nginx:1.23.4-alpine
ports:
- "80:80"
- "443:443"
volumes:
- ./data/nginx:/etc/nginx/conf.d
- ./data/certbot/conf:/etc/letsencrypt
- ./data/certbot/www:/var/www/certbot
command: "/bin/sh -c 'while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\"'"
depends_on:
- api
certbot:
image: certbot/certbot
volumes:
- ./data/certbot/conf:/etc/letsencrypt
- ./data/certbot/www:/var/www/certbot
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
volumes:
local_dynamic_storage:

View File

@ -0,0 +1,80 @@
#!/bin/bash
if ! [ -x "$(command -v docker compose)" ]; then
echo 'Error: docker compose is not installed.' >&2
exit 1
fi
domains=(api.danswer.dev www.api.danswer.dev)
rsa_key_size=4096
data_path="./data/certbot"
email="" # Adding a valid address is strongly recommended
staging=0 # Set to 1 if you're testing your setup to avoid hitting request limits
if [ -d "$data_path" ]; then
read -p "Existing data found for $domains. Continue and replace existing certificate? (y/N) " decision
if [ "$decision" != "Y" ] && [ "$decision" != "y" ]; then
exit
fi
fi
if [ ! -e "$data_path/conf/options-ssl-nginx.conf" ] || [ ! -e "$data_path/conf/ssl-dhparams.pem" ]; then
echo "### Downloading recommended TLS parameters ..."
mkdir -p "$data_path/conf"
curl -s https://raw.githubusercontent.com/certbot/certbot/master/certbot-nginx/certbot_nginx/_internal/tls_configs/options-ssl-nginx.conf > "$data_path/conf/options-ssl-nginx.conf"
curl -s https://raw.githubusercontent.com/certbot/certbot/master/certbot/certbot/ssl-dhparams.pem > "$data_path/conf/ssl-dhparams.pem"
echo
fi
echo "### Creating dummy certificate for $domains ..."
path="/etc/letsencrypt/live/$domains"
mkdir -p "$data_path/conf/live/$domains"
docker compose run --rm --entrypoint "\
openssl req -x509 -nodes -newkey rsa:$rsa_key_size -days 1\
-keyout '$path/privkey.pem' \
-out '$path/fullchain.pem' \
-subj '/CN=localhost'" certbot
echo
echo "### Starting nginx ..."
docker compose up --force-recreate -d nginx
echo
echo "### Deleting dummy certificate for $domains ..."
docker compose run --rm --entrypoint "\
rm -Rf /etc/letsencrypt/live/$domains && \
rm -Rf /etc/letsencrypt/archive/$domains && \
rm -Rf /etc/letsencrypt/renewal/$domains.conf" certbot
echo
echo "### Requesting Let's Encrypt certificate for $domains ..."
#Join $domains to -d args
domain_args=""
for domain in "${domains[@]}"; do
domain_args="$domain_args -d $domain"
done
# Select appropriate email arg
case "$email" in
"") email_arg="--register-unsafely-without-email" ;;
*) email_arg="--email $email" ;;
esac
# Enable staging mode if needed
if [ $staging != "0" ]; then staging_arg="--staging"; fi
docker compose run --rm --entrypoint "\
certbot certonly --webroot -w /var/www/certbot \
$staging_arg \
$email_arg \
$domain_args \
--rsa-key-size $rsa_key_size \
--agree-tos \
--force-renewal" certbot
echo
echo "### Reloading nginx ..."
docker compose exec nginx nginx -s reload

View File

@ -0,0 +1,2 @@
aws-cdk-lib>=2.0.0
constructs>=10.0.0

View File

@ -0,0 +1,23 @@
beautifulsoup4==4.12.0
fastapi==0.95.0
filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
openai==0.27.2
playwright==1.32.1
pydantic==1.10.7
PyPDF2==3.0.1
pytest-playwright==0.3.2
qdrant-client==1.1.0
requests==2.28.2
sentence-transformers==2.2.2
slack-sdk==3.20.2
transformers==4.27.3
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-urllib3==1.26.25.11
typesense==0.15.1
uvicorn==0.21.1

View File

@ -0,0 +1,10 @@
mypy==1.1.1
mypy-extensions==1.0.0
black==23.3.0
reorder-python-imports==3.9.0
pre-commit==3.2.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-requests==2.28.11.17
types-urllib3==1.26.25.11
types-regex==2023.3.23.1

View File

@ -0,0 +1,114 @@
import argparse
from itertools import chain
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.connectors.google_drive.batch import BatchGoogleDriveLoader
from danswer.connectors.slack.batch import BatchSlackLoader
from danswer.connectors.type_aliases import BatchLoader
from danswer.connectors.web.batch import BatchWebLoader
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.indexing import recreate_collection
from danswer.datastores.qdrant.store import QdrantDatastore
from danswer.embedding.biencoder import DefaultEmbedder
from danswer.embedding.type_aliases import Embedder
from danswer.utils.logging import setup_logger
logger = setup_logger()
def load_batch(
doc_loader: BatchLoader,
chunker: Chunker,
embedder: Embedder,
datastore: Datastore,
):
num_processed = 0
total_chunks = 0
for document_batch in doc_loader.load():
if not document_batch:
logger.warning("No parseable documents found in batch")
continue
logger.info(f"Indexed {num_processed} documents")
document_chunks = list(
chain(*[chunker.chunk(document) for document in document_batch])
)
num_chunks = len(document_chunks)
total_chunks += num_chunks
logger.info(
f"Document batch yielded {num_chunks} chunks for a total of {total_chunks}"
)
chunks_with_embeddings = embedder.embed(document_chunks)
datastore.index(chunks_with_embeddings)
num_processed += len(document_batch)
logger.info(f"Finished, indexed a total of {num_processed} documents")
def load_slack_batch(file_path: str, qdrant_collection: str):
logger.info("Loading documents from Slack.")
load_batch(
BatchSlackLoader(export_path_str=file_path, batch_size=INDEX_BATCH_SIZE),
DefaultChunker(),
DefaultEmbedder(),
QdrantDatastore(collection=qdrant_collection),
)
def load_web_batch(url: str, qdrant_collection: str):
logger.info("Loading documents from web.")
load_batch(
BatchWebLoader(base_url=url, batch_size=INDEX_BATCH_SIZE),
DefaultChunker(),
DefaultEmbedder(),
QdrantDatastore(collection=qdrant_collection),
)
def load_google_drive_batch(qdrant_collection: str):
logger.info("Loading documents from Google Drive.")
load_batch(
BatchGoogleDriveLoader(batch_size=INDEX_BATCH_SIZE),
DefaultChunker(),
DefaultEmbedder(),
QdrantDatastore(collection=qdrant_collection),
)
class BatchLoadingArgs(argparse.Namespace):
slack_export_dir: str
website_url: str
qdrant_collection: str
rebuild_index: bool
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--website-url",
default="https://docs.github.com/en/actions",
)
parser.add_argument(
"--slack-export-dir",
default="/Users/chrisweaver/Downloads/test-slack-export",
)
parser.add_argument(
"--qdrant-collection",
default=QDRANT_DEFAULT_COLLECTION,
)
parser.add_argument(
"--rebuild-index",
action="store_true",
help="Deletes and repopulates the semantic search index",
)
args = parser.parse_args(namespace=BatchLoadingArgs)
if args.rebuild_index:
recreate_collection(args.qdrant_collection)
#load_slack_batch(args.slack_export_dir, args.qdrant_collection)
load_web_batch(args.website_url, args.qdrant_collection)
#load_google_drive_batch(args.qdrant_collection)

View File

@ -0,0 +1,58 @@
import json
import requests
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import SOURCE_TYPE
if __name__ == "__main__":
previous_query = None
while True:
keyword_search = False
query = input(
"\n\nAsk any question:\n - prefix with -k for keyword search\n - input an empty string to "
"rerun last query\n\t"
)
if query.lower() in ["q", "quit", "exit", "exit()"]:
break
if query:
previous_query = query
else:
if not previous_query:
print("No previous query")
continue
print(f"Re-executing previous question:\n\t{previous_query}")
query = previous_query
endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
if query.startswith("-k "):
keyword_search = True
            query = query[3:]  # strip the "-k " prefix
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
response = requests.post(
endpoint, json={"query": query, "collection": QDRANT_DEFAULT_COLLECTION}
)
contents = json.loads(response.content)
if keyword_search:
if contents["results"]:
for link in contents["results"]:
print(link)
else:
print("No matches found")
else:
answer = contents.get("answer")
if answer:
print("Answer: " + answer)
else:
print("Answer: ?")
if contents.get("quotes"):
for ind, (quote, quote_info) in enumerate(contents["quotes"].items()):
print(f"Quote {str(ind)}:\n{quote}")
print(f"Link: {quote_info['link']}")
print(f"Source: {quote_info[SOURCE_TYPE]}")
else:
print("No quotes found")

View File

@ -0,0 +1,84 @@
import unittest
from danswer.chunking.chunk import chunk_document
from danswer.chunking.chunk import chunk_large_section
from danswer.connectors.models import Document
from danswer.connectors.models import Section
WAR_AND_PEACE = (
"Well, Prince, so Genoa and Lucca are now just family estates of the Buonapartes. But I warn you, "
"if you dont tell me that this means war, if you still try to defend the infamies and horrors perpetrated by "
"that Antichrist—I really believe he is Antichrist—I will have nothing more to do with you and you are no longer "
"my friend, no longer my faithful slave, as you call yourself! But how do you do? I see I have frightened "
"you—sit down and tell me all the news."
)
class TestDocumentChunking(unittest.TestCase):
def setUp(self):
self.large_section = Section(text=WAR_AND_PEACE, link="https://www.test.com/")
self.document = Document(
id="test_document",
metadata={"source_type": "testing"},
sections=[
Section(
text="Here is some testing text", link="https://www.test.com/0"
),
Section(
text="Some more text, still under 100 chars",
link="https://www.test.com/1",
),
Section(
text="Now with this section it's longer than the chunk size",
link="https://www.test.com/2",
),
self.large_section,
Section(text="These last 2 sections", link="https://www.test.com/4"),
Section(
text="should be combined into one", link="https://www.test.com/5"
),
],
)
def test_chunk_large_section(self):
chunks = chunk_large_section(
section=self.large_section,
document=self.document,
start_chunk_id=5,
chunk_size=100,
word_overlap=3,
)
self.assertEqual(len(chunks), 5)
self.assertEqual(chunks[0].content, WAR_AND_PEACE[:99])
self.assertEqual(
chunks[-2].content, WAR_AND_PEACE[-176:-63]
) # slightly longer than 100 due to overlap
self.assertEqual(
chunks[-1].content, WAR_AND_PEACE[-121:]
) # large overlap with second to last segment
self.assertFalse(chunks[0].section_continuation)
self.assertTrue(chunks[1].section_continuation)
self.assertTrue(chunks[-1].section_continuation)
def test_chunk_document(self):
chunks = chunk_document(self.document, chunk_size=100, subsection_overlap=3)
self.assertEqual(len(chunks), 8)
self.assertEqual(
chunks[0].content,
self.document.sections[0].text + "\n\n" + self.document.sections[1].text,
)
self.assertEqual(
chunks[0].source_links,
{0: "https://www.test.com/0", 21: "https://www.test.com/1"},
)
self.assertEqual(
chunks[-1].source_links,
{0: "https://www.test.com/4", 18: "https://www.test.com/5"},
)
self.assertEqual(chunks[5].chunk_id, 5)
self.assertEqual(chunks[6].source_document, self.document)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,178 @@
import textwrap
import unittest
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.question_answer import match_quotes_to_docs
from danswer.direct_qa.question_answer import separate_answer_quotes
class TestQAPostprocessing(unittest.TestCase):
def test_separate_answer_quotes(self):
test_answer = textwrap.dedent(
"""
It seems many people love dogs
Quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(quotes[0], "A dog is a man's best friend")
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"
)
# Lowercase should be allowed
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(quotes[0], "A dog is a man's best friend")
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"
)
# No Answer
test_answer = textwrap.dedent(
"""
Quote: This one has no answer
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertIsNone(answer)
self.assertIsNone(quotes)
# Multiline Quote
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A well known saying is:
A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(
quotes[0], "A well known saying is:\nA dog is a man's best friend"
)
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"
)
# Random patterns not picked up
test_answer = textwrap.dedent(
"""
It seems many people love quote: dogs
quote: Quote: A well known saying is:
A dog is a man's best friend
Quote: Answer: Air Bud was a movie about dogs and quote: people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love quote: dogs")
self.assertEqual(
quotes[0], "Quote: A well known saying is:\nA dog is a man's best friend"
)
self.assertEqual(
quotes[1],
"Answer: Air Bud was a movie about dogs and quote: people loved it",
)
def test_fuzzy_match_quotes_to_docs(self):
chunk_0_text = textwrap.dedent(
"""
Here's a doc with some LINK embedded in the text
THIS SECTION IS A LINK
Some more text
"""
).strip()
chunk_1_text = textwrap.dedent(
"""
Some completely different text here
ANOTHER LINK embedded in this text
ending in a DIFFERENT-LINK
"""
).strip()
test_chunk_0 = InferenceChunk(
document_id="test doc 0",
source_type="testing",
chunk_id=0,
content=chunk_0_text,
source_links={
0: "doc 0 base",
23: "first line link",
49: "second line link",
},
section_continuation=False,
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
source_type="testing",
chunk_id=0,
content=chunk_1_text,
source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
section_continuation=False,
)
test_quotes = [
"a doc with some", # Basic case
"a doc with some LINK", # Should take the start of quote, even if a link is in it
"a doc with some \nLINK", # Requires a newline deletion fuzzy match
"a doc with some link", # Capitalization insensitive
"embedded in this text", # Fuzzy match to first doc
"SECTION IS A LINK", # Match exact link
"some more text", # Match the end, after every link offset
"different taxt", # Substitution
"embedded in this texts", # Cannot fuzzy match to first doc, fuzzy match to second doc
"DIFFERENT-LINK", # Exact link match at the end
"Some complitali", # Too many edits, shouldn't match anything
]
results = match_quotes_to_docs(
test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True
)
self.assertEqual(
results,
{
"a doc with some": {"document": "test doc 0", "link": "doc 0 base"},
"a doc with some LINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some \nLINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some link": {
"document": "test doc 0",
"link": "doc 0 base",
},
"embedded in this text": {
"document": "test doc 0",
"link": "first line link",
},
"SECTION IS A LINK": {
"document": "test doc 0",
"link": "second line link",
},
"some more text": {
"document": "test doc 0",
"link": "second line link",
},
"different taxt": {"document": "test doc 1", "link": "doc 1 base"},
"embedded in this texts": {
"document": "test doc 1",
"link": "2nd line link",
},
"DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"},
},
)
if __name__ == "__main__":
unittest.main()

3
web/.eslintrc.json Normal file
View File

@ -0,0 +1,3 @@
{
"extends": "next/core-web-vitals"
}

36
web/.gitignore vendored Normal file
View File

@ -0,0 +1,36 @@
# See https://help.github.com/articles/ignoring-files/ for more about ignoring files.
# dependencies
/node_modules
/.pnp
.pnp.js
# testing
/coverage
# next.js
/.next/
/out/
# production
/build
# misc
.DS_Store
*.pem
# debug
npm-debug.log*
yarn-debug.log*
yarn-error.log*
.pnpm-debug.log*
# local env files
.env*.local
# vercel
.vercel
# typescript
*.tsbuildinfo
next-env.d.ts

6
web/.prettierignore Normal file
View File

@ -0,0 +1,6 @@
**/.git
**/.svn
**/.hg
**/node_modules
**/.next
**/.vscode

32
web/README.md Normal file
View File

@ -0,0 +1,32 @@
This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app).
## Getting Started
First, run the development server:
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
This project uses [`next/font`](https://nextjs.org/docs/basic-features/font-optimization) to automatically optimize and load Inter, a custom Google Font.
## Learn More
To learn more about Next.js, take a look at the following resources:
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js/) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/deployment) for more details.

8
web/next.config.js Normal file
View File

@ -0,0 +1,8 @@
/** @type {import('next').NextConfig} */
const nextConfig = {
experimental: {
appDir: true,
},
};
module.exports = nextConfig;

4107
web/package-lock.json generated Normal file

File diff suppressed because it is too large Load Diff

28
web/package.json Normal file
View File

@ -0,0 +1,28 @@
{
"name": "qa",
"version": "0.1.0",
"private": true,
"scripts": {
"dev": "next dev",
"build": "next build",
"start": "next start",
"lint": "next lint"
},
"dependencies": {
"@phosphor-icons/react": "^2.0.8",
"@types/node": "18.15.11",
"@types/react": "18.0.32",
"@types/react-dom": "18.0.11",
"autoprefixer": "^10.4.14",
"eslint": "8.37.0",
"eslint-config-next": "13.2.4",
"formik": "^2.2.9",
"next": "13.2.4",
"postcss": "^8.4.23",
"react": "18.2.0",
"react-dom": "18.2.0",
"tailwindcss": "^3.3.1",
"typescript": "5.0.3",
"yup": "^1.1.1"
}
}

6
web/postcss.config.js Normal file
View File

@ -0,0 +1,6 @@
module.exports = {
plugins: {
tailwindcss: {},
autoprefixer: {},
},
};

1
web/public/next.svg Normal file
View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 394 80"><path fill="#000" d="M262 0h68.5v12.7h-27.2v66.6h-13.6V12.7H262V0ZM149 0v12.7H94v20.4h44.3v12.6H94v21h55v12.6H80.5V0h68.7zm34.3 0h-17.8l63.8 79.4h17.9l-32-39.7 32-39.6h-17.9l-23 28.6-23-28.6zm18.3 56.7-9-11-27.1 33.7h17.8l18.3-22.7z"/><path fill="#000" d="M81 79.3 17 0H0v79.3h13.6V17l50.2 62.3H81Zm252.6-.4c-1 0-1.8-.4-2.5-1s-1.1-1.6-1.1-2.6.3-1.8 1-2.5 1.6-1 2.6-1 1.8.3 2.5 1a3.4 3.4 0 0 1 .6 4.3 3.7 3.7 0 0 1-3 1.8zm23.2-33.5h6v23.3c0 2.1-.4 4-1.3 5.5a9.1 9.1 0 0 1-3.8 3.5c-1.6.8-3.5 1.3-5.7 1.3-2 0-3.7-.4-5.3-1s-2.8-1.8-3.7-3.2c-.9-1.3-1.4-3-1.4-5h6c.1.8.3 1.6.7 2.2s1 1.2 1.6 1.5c.7.4 1.5.5 2.4.5 1 0 1.8-.2 2.4-.6a4 4 0 0 0 1.6-1.8c.3-.8.5-1.8.5-3V45.5zm30.9 9.1a4.4 4.4 0 0 0-2-3.3 7.5 7.5 0 0 0-4.3-1.1c-1.3 0-2.4.2-3.3.5-.9.4-1.6 1-2 1.6a3.5 3.5 0 0 0-.3 4c.3.5.7.9 1.3 1.2l1.8 1 2 .5 3.2.8c1.3.3 2.5.7 3.7 1.2a13 13 0 0 1 3.2 1.8 8.1 8.1 0 0 1 3 6.5c0 2-.5 3.7-1.5 5.1a10 10 0 0 1-4.4 3.5c-1.8.8-4.1 1.2-6.8 1.2-2.6 0-4.9-.4-6.8-1.2-2-.8-3.4-2-4.5-3.5a10 10 0 0 1-1.7-5.6h6a5 5 0 0 0 3.5 4.6c1 .4 2.2.6 3.4.6 1.3 0 2.5-.2 3.5-.6 1-.4 1.8-1 2.4-1.7a4 4 0 0 0 .8-2.4c0-.9-.2-1.6-.7-2.2a11 11 0 0 0-2.1-1.4l-3.2-1-3.8-1c-2.8-.7-5-1.7-6.6-3.2a7.2 7.2 0 0 1-2.4-5.7 8 8 0 0 1 1.7-5 10 10 0 0 1 4.3-3.5c2-.8 4-1.2 6.4-1.2 2.3 0 4.4.4 6.2 1.2 1.8.8 3.2 2 4.3 3.4 1 1.4 1.5 3 1.5 5h-5.8z"/></svg>

After

Width:  |  Height:  |  Size: 1.3 KiB

1
web/public/thirteen.svg Normal file
View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" width="40" height="31" fill="none"><g opacity=".9"><path fill="url(#a)" d="M13 .4v29.3H7V6.3h-.2L0 10.5V5L7.2.4H13Z"/><path fill="url(#b)" d="M28.8 30.1c-2.2 0-4-.3-5.7-1-1.7-.8-3-1.8-4-3.1a7.7 7.7 0 0 1-1.4-4.6h6.2c0 .8.3 1.4.7 2 .4.5 1 .9 1.7 1.2.7.3 1.6.4 2.5.4 1 0 1.7-.2 2.5-.5.7-.3 1.3-.8 1.7-1.4.4-.6.6-1.2.6-2s-.2-1.5-.7-2.1c-.4-.6-1-1-1.8-1.4-.8-.4-1.8-.5-2.9-.5h-2.7v-4.6h2.7a6 6 0 0 0 2.5-.5 4 4 0 0 0 1.7-1.3c.4-.6.6-1.3.6-2a3.5 3.5 0 0 0-2-3.3 5.6 5.6 0 0 0-4.5 0 4 4 0 0 0-1.7 1.2c-.4.6-.6 1.2-.6 2h-6c0-1.7.6-3.2 1.5-4.5 1-1.3 2.2-2.3 3.8-3C25 .4 26.8 0 28.8 0s3.8.4 5.3 1.1c1.5.7 2.7 1.7 3.6 3a7.2 7.2 0 0 1 1.2 4.2c0 1.6-.5 3-1.5 4a7 7 0 0 1-4 2.2v.2c2.2.3 3.8 1 5 2.2a6.4 6.4 0 0 1 1.6 4.6c0 1.7-.5 3.1-1.4 4.4a9.7 9.7 0 0 1-4 3.1c-1.7.8-3.7 1.1-5.8 1.1Z"/></g><defs><linearGradient id="a" x1="20" x2="20" y1="0" y2="30.1" gradientUnits="userSpaceOnUse"><stop/><stop offset="1" stop-color="#3D3D3D"/></linearGradient><linearGradient id="b" x1="20" x2="20" y1="0" y2="30.1" gradientUnits="userSpaceOnUse"><stop/><stop offset="1" stop-color="#3D3D3D"/></linearGradient></defs></svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

1
web/public/vercel.svg Normal file
View File

@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 283 64"><path fill="black" d="M141 16c-11 0-19 7-19 18s9 18 20 18c7 0 13-3 16-7l-7-5c-2 3-6 4-9 4-5 0-9-3-10-7h28v-3c0-11-8-18-19-18zm-9 15c1-4 4-7 9-7s8 3 9 7h-18zm117-15c-11 0-19 7-19 18s9 18 20 18c6 0 12-3 16-7l-8-5c-2 3-5 4-8 4-5 0-9-3-11-7h28l1-3c0-11-8-18-19-18zm-10 15c2-4 5-7 10-7s8 3 9 7h-19zm-39 3c0 6 4 10 10 10 4 0 7-2 9-5l8 5c-3 5-9 8-17 8-11 0-19-7-19-18s8-18 19-18c8 0 14 3 17 8l-8 5c-2-3-5-5-9-5-6 0-10 4-10 10zm83-29v46h-9V5h9zM37 0l37 64H0L37 0zm92 5-27 48L74 5h10l18 30 17-30h10zm59 12v10l-3-1c-6 0-10 4-10 10v15h-9V17h9v9c0-5 6-9 13-9z"/></svg>

After

Width:  |  Height:  |  Size: 629 B

View File

@ -0,0 +1,22 @@
"use client";
import { Inter } from "next/font/google";
import { Header } from "@/components/Header";
import { SlackForm } from "@/components/admin/connectors/SlackForm";
const inter = Inter({ subsets: ["latin"] });
export default function Home() {
return (
<>
<Header />
<div className="p-24 min-h-screen bg-gray-900 text-gray-100">
<div>
<h1 className="text-4xl font-bold mb-4">Slack</h1>
</div>
<h2 className="text-3xl font-bold mb-4 ml-auto mr-auto">Config</h2>
<SlackForm onSubmit={(success) => console.log(success)}/>
</div>
</>
);
}

BIN
web/src/app/favicon.ico Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 25 KiB

3
web/src/app/globals.css Normal file
View File

@ -0,0 +1,3 @@
@tailwind base;
@tailwind components;
@tailwind utilities;

18
web/src/app/layout.tsx Normal file
View File

@ -0,0 +1,18 @@
import "./globals.css";
export const metadata = {
title: "Create Next App",
description: "Generated by create next app",
};
export default function RootLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<html lang="en">
<body>{children}</body>
</html>
);
}

16
web/src/app/page.tsx Normal file
View File

@ -0,0 +1,16 @@
import { Inter } from "next/font/google";
import { SearchSection } from "@/components/SearchBar";
import { Header } from "@/components/Header";
const inter = Inter({ subsets: ["latin"] });
export default function Home() {
return (
<>
<Header />
<div className="p-24 flex flex-col items-center min-h-screen bg-gray-900 text-gray-100">
<SearchSection />
</div>
</>
);
}

View File

@ -0,0 +1,12 @@
import React from "react";
import "tailwindcss/tailwind.css";
export const Header: React.FC = () => {
return (
<header className="bg-gray-800 text-gray-200 py-4">
<div className="container mx-auto ml-8">
<h1 className="text-2xl font-bold">danswer 💃</h1>
</div>
</header>
);
};

View File

@ -0,0 +1,93 @@
"use client";
import React, { useState, KeyboardEvent, ChangeEvent } from "react";
import { MagnifyingGlass } from "@phosphor-icons/react";
import "tailwindcss/tailwind.css";
import { SearchResultsDisplay } from "./SearchResultsDisplay";
import { SearchResponse } from "./types";
const BACKEND_URL =
  process.env.NEXT_PUBLIC_BACKEND_URL || "http://localhost:8080"; // "http://servi-lb8a1-jhqpsz92kbm2-1605938866.us-east-2.elb.amazonaws.com/direct-qa";
const searchRequest = async (query: string): Promise<SearchResponse> => {
const response = await fetch(BACKEND_URL + "/direct-qa", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
query: query,
collection: "semantic_search",
}),
});
return response.json();
};
export const SearchSection: React.FC<{}> = () => {
const [answer, setAnswer] = useState<SearchResponse>();
const [isFetching, setIsFetching] = useState(false);
return (
<>
<SearchBar
onSearch={(query) => {
setIsFetching(true);
searchRequest(query).then((response) => {
setIsFetching(false);
setAnswer(response);
});
}}
/>
<SearchResultsDisplay data={answer} isFetching={isFetching} />
</>
);
};
interface SearchBarProps {
onSearch: (searchTerm: string) => void;
}
const SearchBar: React.FC<SearchBarProps> = ({ onSearch }) => {
const [searchTerm, setSearchTerm] = useState<string>("");
const handleChange = (event: ChangeEvent<HTMLTextAreaElement>) => {
const target = event.target;
setSearchTerm(target.value);
// Reset the textarea height
target.style.height = "24px";
// Calculate the new height based on scrollHeight
const newHeight = target.scrollHeight;
// Apply the new height
target.style.height = `${newHeight}px`;
};
// const handleSubmit = (event: KeyboardEvent<HTMLInputElement>) => {
// onSearch(searchTerm);
// };
const handleKeyDown = (event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === "Enter" && !event.shiftKey) {
onSearch(searchTerm);
event.preventDefault();
}
};
return (
<div className="flex justify-center p-4">
<div className="flex items-center w-[800px] border-2 border-gray-300 rounded px-4 py-2 focus-within:border-blue-500">
<MagnifyingGlass className="text-gray-400" />
<textarea
className="flex-grow ml-2 h-6 bg-transparent outline-none placeholder-gray-400 overflow-hidden whitespace-normal resize-none"
role="textarea"
aria-multiline
placeholder="Search..."
value={searchTerm}
onChange={handleChange}
onKeyDown={handleKeyDown}
suppressContentEditableWarning={true}
/>
</div>
</div>
);
};

View File

@ -0,0 +1,73 @@
import React from "react";
import { Globe, SlackLogo, GoogleDriveLogo } from "@phosphor-icons/react";
import "tailwindcss/tailwind.css";
import { SearchResponse } from "./types";
import { ThinkingAnimation } from "./Thinking";
interface SearchResultsDisplayProps {
data: SearchResponse | undefined;
isFetching: boolean;
}
const getSourceIcon = (sourceType: string) => {
switch (sourceType) {
case "Web":
return <Globe className="text-blue-600" />;
case "Slack":
return <SlackLogo className="text-blue-600" />;
case "Google Drive":
return <GoogleDriveLogo className="text-blue-600" />;
default:
return null;
}
};
export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
data,
isFetching,
}) => {
if (isFetching) {
return <ThinkingAnimation />;
}
if (!data) {
return null;
}
const { answer, quotes } = data;
if (!answer || !quotes) {
return <div>Unable to find an answer</div>;
}
return (
<div className="p-4">
<h2 className="text-2xl font-bold mb-4">Answer</h2>
<p className="mb-6">{answer}</p>
<h2 className="text-2xl font-bold mb-4">Quotes</h2>
<ul>
{Object.entries(quotes).map(([quoteText, quoteInfo]) => (
<li key={quoteInfo.document_id} className="mb-4">
<blockquote className="italic text-lg mb-2">{quoteText}</blockquote>
<p>
<strong>Source:</strong> {getSourceIcon(quoteInfo.source_type)}{" "}
{quoteInfo.source_type}
</p>
<p>
<strong>Link:</strong>{" "}
<a
href={quoteInfo.link}
target="_blank"
rel="noopener noreferrer"
className="text-blue-600"
>
{quoteInfo.link}
</a>
</p>
</li>
))}
</ul>
</div>
);
};

View File

@ -0,0 +1,31 @@
import React, { useState, useEffect } from "react";
import "./thinking.css";
export const ThinkingAnimation: React.FC = () => {
const [dots, setDots] = useState("...");
useEffect(() => {
const interval = setInterval(() => {
setDots((prevDots) => {
switch (prevDots) {
case ".":
return "..";
case "..":
return "...";
case "...":
return ".";
default:
return "...";
}
});
}, 500);
return () => clearInterval(interval);
}, []);
return (
<div className="thinking-animation">
Thinking<span className="dots">{dots}</span>
</div>
);
};

View File

@ -0,0 +1,14 @@
interface PopupProps {
message: string;
type: "success" | "error";
}
export const Popup: React.FC<PopupProps> = ({ message, type }) => (
<div
className={`fixed bottom-4 left-4 p-4 rounded-md shadow-lg text-white ${
type === "success" ? "bg-green-500" : "bg-red-500"
}`}
>
{message}
</div>
);

View File

@ -0,0 +1,163 @@
import React, { useEffect, useState } from "react";
import { Formik, Form, Field, ErrorMessage, FormikHelpers } from "formik";
import * as Yup from "yup";
import { BACKEND_URL } from "@/lib/constants";
import { Popup } from "./Popup";
interface FormData {
  slack_bot_token: string;
  workspace_id: string;
  pull_frequency: number;
}
const validationSchema = Yup.object().shape({
slack_bot_token: Yup.string().required("Please enter your Slack Bot Token"),
workspace_id: Yup.string().required("Please enter your Workspace ID"),
pull_frequency: Yup.number().required("Please enter a pull frequency (in minutes). 0 => no pulling from slack"),
});
const getConfig = async (): Promise<FormData> => {
const response = await fetch(BACKEND_URL + "/admin/slack_connector_config");
return response.json();
};
const handleSubmit = async (
values: FormData,
{ setSubmitting }: FormikHelpers<FormData>,
setPopup: (
popup: { message: string; type: "success" | "error" } | null
) => void
) => {
setSubmitting(true);
try {
    // Post the updated Slack connector config to the backend
const response = await fetch(
BACKEND_URL + "/admin/slack_connector_config",
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify(values),
}
);
if (response.ok) {
setPopup({ message: "Success!", type: "success" });
} else {
const errorData = await response.json();
setPopup({ message: `Error: ${errorData.detail}`, type: "error" });
}
} catch (error) {
setPopup({ message: `Error: ${error}`, type: "error" });
} finally {
setSubmitting(false);
setTimeout(() => {
setPopup(null);
}, 3000);
}
};
interface SlackFormProps {
onSubmit: (isSuccess: boolean) => void;
}
export const SlackForm: React.FC<SlackFormProps> = ({ onSubmit }) => {
const [initialValues, setInitialValues] = React.useState<FormData>();
const [popup, setPopup] = useState<{
message: string;
type: "success" | "error";
} | null>(null);
useEffect(() => {
getConfig().then((response) => {
setInitialValues(response);
});
}, []);
if (!initialValues) {
// TODO (chris): improve
return <div>Loading...</div>;
}
return (
<>
{popup && <Popup message={popup.message} type={popup.type} />}
<Formik
initialValues={initialValues}
validationSchema={validationSchema}
onSubmit={(values, formikHelpers) =>
handleSubmit(values, formikHelpers, setPopup)
}
>
{({ isSubmitting }) => (
<Form className="bg-white p-6 rounded shadow-md w-full max-w-md mx-auto">
<div className="mb-4">
<label
htmlFor="slack_bot_token"
className="block text-gray-700 mb-1"
>
Slack Bot Token:
</label>
<Field
type="text"
name="slack_bot_token"
id="slack_bot_token"
className="border border-gray-300 rounded w-full py-2 px-3 text-gray-700"
/>
<ErrorMessage
name="slack_bot_token"
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div className="mb-4">
<label
htmlFor="workspace_id"
className="block text-gray-700 mb-1"
>
Workspace ID:
</label>
<Field
type="text"
name="workspace_id"
id="workspace_id"
className="border border-gray-300 rounded w-full py-2 px-3 text-gray-700"
/>
<ErrorMessage
name="workspace_id"
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<div className="mb-4">
<label
htmlFor="workspace_id"
className="block text-gray-700 mb-1"
>
Pull Frequency:
</label>
<Field
type="text"
name="pull_frequency"
id="pull_frequency"
className="border border-gray-300 rounded w-full py-2 px-3 text-gray-700"
/>
<ErrorMessage
name="pull_frequency"
component="div"
className="text-red-500 text-sm mt-1"
/>
</div>
<button
type="submit"
disabled={isSubmitting}
className="bg-blue-500 hover:bg-blue-700 text-white font-bold py-2 px-4 rounded focus:outline-none focus:shadow-outline w-full"
>
Submit
</button>
</Form>
)}
</Formik>
</>
);
};

View File

@ -0,0 +1,18 @@
.thinking-animation {
font-size: 1.5rem;
font-weight: bold;
}
.dots {
animation: blink 1s linear infinite;
}
@keyframes blink {
0%,
100% {
opacity: 1;
}
50% {
opacity: 0.5;
}
}

View File

@ -0,0 +1,10 @@
export interface Quote {
document_id: string;
link: string;
source_type: string;
}
export interface SearchResponse {
answer: string;
quotes: Record<string, Quote>;
}

2
web/src/lib/constants.ts Normal file
View File

@ -0,0 +1,2 @@
export const BACKEND_URL =
process.env.NEXT_PUBLIC_BACKEND_URL || "http://localhost:8080"; // "http://servi-lb8a1-jhqpsz92kbm2-1605938866.us-east-2.elb.amazonaws.com/direct-qa";

15
web/tailwind.config.js Normal file
View File

@ -0,0 +1,15 @@
/** @type {import('tailwindcss').Config} */
module.exports = {
content: [
"./app/**/*.{js,ts,jsx,tsx,mdx}",
"./pages/**/*.{js,ts,jsx,tsx,mdx}",
"./components/**/*.{js,ts,jsx,tsx,mdx}",
// Or if using `src` directory:
"./src/**/*.{js,ts,jsx,tsx,mdx}",
],
theme: {
extend: {},
},
plugins: [],
};

28
web/tsconfig.json Normal file
View File

@ -0,0 +1,28 @@
{
"compilerOptions": {
"target": "es5",
"lib": ["dom", "dom.iterable", "esnext"],
"allowJs": true,
"skipLibCheck": true,
"strict": true,
"forceConsistentCasingInFileNames": true,
"noEmit": true,
"esModuleInterop": true,
"module": "esnext",
"moduleResolution": "node",
"resolveJsonModule": true,
"isolatedModules": true,
"jsx": "preserve",
"incremental": true,
"plugins": [
{
"name": "next"
}
],
"paths": {
"@/*": ["./src/*"]
}
},
"include": ["next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts"],
"exclude": ["node_modules"]
}