DAN-81 Improve search round 2 (#82)

Includes:
- Multi-vector indexing/search
- Ensemble model reranking
- Keyword search backend
Yuhong Sun 2023-06-04 20:02:32 -07:00 committed by GitHub
parent 7cc64efc3a
commit c4e8afe4d2
35 changed files with 1223 additions and 863 deletions

View File

@ -45,6 +45,8 @@ def create_indexing_jobs(db_session: Session) -> None:
in_progress_indexing_attempts = get_incomplete_index_attempts(
connector.id, db_session
)
if in_progress_indexing_attempts:
logger.error("Found incomplete indexing attempts")
# Currently single-threaded, so any still in-progress attempts must have errored
for attempt in in_progress_indexing_attempts:
@ -113,7 +115,6 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:
document_ids: list[str] = []
for doc_batch in doc_batch_generator:
# TODO introduce permissioning here
index_user_id = (
None if db_credential.public_doc else db_credential.user_id
)

View File

@ -1,4 +1,5 @@
import inspect
import json
from dataclasses import dataclass
from typing import Any
from typing import cast
@ -19,12 +20,14 @@ class BaseChunk:
@dataclass
class IndexChunk(BaseChunk):
# During indexing flow, we have access to a complete "Document"
# During inference we only have access to the document id and do not reconstruct the Document
source_document: Document
@dataclass
class EmbeddedIndexChunk(IndexChunk):
embedding: list[float]
embeddings: list[list[float]]
@dataclass
@ -39,8 +42,13 @@ class InferenceChunk(BaseChunk):
k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
}
if "source_links" in init_kwargs:
source_links = init_kwargs["source_links"]
source_links_dict = (
json.loads(source_links)
if isinstance(source_links, str)
else source_links
)
init_kwargs["source_links"] = {
int(k): v
for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
int(k): v for k, v in cast(dict[str, str], source_links_dict).items()
}
return cls(**init_kwargs)

View File

@ -52,16 +52,23 @@ MASK_CREDENTIAL_PREFIX = (
#####
# DB Configs
#####
DEFAULT_VECTOR_STORE = os.environ.get("VECTOR_DB", "qdrant")
# Qdrant is the vector DB used for semantic search
# URL / API key are used to connect to a remote Qdrant instance
QDRANT_URL = os.environ.get("QDRANT_URL", "")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
# Host / Port are used for connecting to a local Qdrant instance
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_PORT = 6333
QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "semantic_search")
DB_CONN_TIMEOUT = 2  # Timeout in seconds when connecting to DBs
INDEX_BATCH_SIZE = 16 # File batches (not accounting file chunking)
QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_DEFAULT_COLLECTION", "danswer_index")
# Typesense is the Keyword Search Engine
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST", "localhost")
TYPESENSE_PORT = 8108
TYPESENSE_DEFAULT_COLLECTION = os.environ.get(
"TYPESENSE_DEFAULT_COLLECTION", "danswer_index"
)
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = 16
# below are intended to match the env variable names used by the official postgres docker image
# https://hub.docker.com/_/postgres
@ -81,13 +88,11 @@ GOOGLE_DRIVE_INCLUDE_SHARED = False
#####
# Query Configs
#####
DEFAULT_PROMPT = "generic-qa"
NUM_RETURNED_HITS = 15
NUM_RERANKED_RESULTS = 4
KEYWORD_MAX_HITS = 5
QUOTE_ALLOWED_ERROR_PERCENT = (
0.05 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
)
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
NUM_GENERATIVE_AI_INPUT_DOCS = 5
# Allow ~5% of quote characters to mismatch; currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = 10 # 10 seconds
@ -97,6 +102,11 @@ QA_TIMEOUT = 10 # 10 seconds
# Docs are chunked to this many characters, not counting completion of the last word or the overlap words below
# Derived from ~500 to 512 max tokens * ~4 chars per token on average
CHUNK_SIZE = 2000
# More accurate results at the expense of indexing speed and index size (stores 4 additional MINI_CHUNK vectors per chunk)
ENABLE_MINI_CHUNK = False
# Mini chunks give finer-grained embeddings: 512 chars is ~128 tokens, yielding 4 additional vectors per full chunk above
# Not rounded down, so no context from the full chunk is lost
MINI_CHUNK_SIZE = 512
# Each chunk includes an additional 5 words from the previous chunk;
# in extreme cases this may cause some words at the end to be truncated by the embedding model
CHUNK_OVERLAP = 5
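
A quick, illustrative restatement of the arithmetic behind these constants (a sketch, not part of the commit; the numbers are taken from the comments above):

    # Illustrative arithmetic for the chunking constants above.
    AVG_CHARS_PER_TOKEN = 4     # average assumed in the comments above
    MAX_DOC_TOKENS = 512        # bi-encoder context size (DOC_EMBEDDING_CONTEXT_SIZE)
    CHUNK_SIZE = 2000           # ~500 tokens * 4 chars/token, under the 512-token cap
    MINI_CHUNK_SIZE = 512       # ~128 tokens per mini chunk

    assert CHUNK_SIZE <= MAX_DOC_TOKENS * AVG_CHARS_PER_TOKEN
    # One full chunk splits into 3 whole mini chunks plus remainder text forming
    # a 4th, matching the "4 additional vectors" comment above.
    print(-(-CHUNK_SIZE // MINI_CHUNK_SIZE))  # ceiling division -> 4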
@ -120,10 +130,6 @@ CROSS_ENCODER_PORT = 9000
#####
# Miscellaneous
#####
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
TYPESENSE_HOST = "localhost"
TYPESENSE_PORT = 8108
DYNAMIC_CONFIG_STORE = os.environ.get(
"DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
)

View File

@ -5,13 +5,22 @@ import os
# Models used must be MIT or Apache licensed
# Inference/Indexing speed
# Bi/Cross-Encoder Model Configs
# https://www.sbert.net/docs/pretrained_models.html
# Consider 'multi-qa-MiniLM-L6-cos-v1' if its license is added; it is 3x faster (384-dimensional embedding)
# though its context size is only 256
DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DOC_EMBEDDING_DIM = 768 # Depends on the document encoder model
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
# Previously using "cross-encoder/ms-marco-MiniLM-L-6-v2" alone
CROSS_ENCODER_MODEL_ENSEMBLE = [
"cross-encoder/ms-marco-MiniLM-L-4-v2",
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
]
QUERY_EMBEDDING_CONTEXT_SIZE = 256
# The below correlates with CHUNK_SIZE in app_configs but is not strictly derived from it,
# to avoid the extra overhead of tokenizing for chunking during indexing.
DOC_EMBEDDING_CONTEXT_SIZE = 512
CROSS_EMBED_CONTEXT_SIZE = 512
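
Given these configs, reranking elsewhere in this diff sums the scores from both cross-encoders per passage. A minimal standalone sketch of that ensemble pattern (query and passages are invented):

    # Minimal ensemble-reranking sketch using the two models configured above.
    from sentence_transformers import CrossEncoder  # type: ignore

    encoders = [
        CrossEncoder("cross-encoder/ms-marco-MiniLM-L-4-v2"),
        CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2"),
    ]
    query = "What is Danswer"
    passages = ["Danswer is an open source enterprise QA tool", "Unrelated text"]
    # Sum per-passage relevance scores across both encoders, then sort descending.
    sim_scores = sum(enc.predict([(query, p) for p in passages]) for enc in encoders)
    ranked = [p for _, p in sorted(zip(sim_scores, passages), reverse=True)]
    print(ranked[0])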

View File

@ -79,92 +79,96 @@ class WebConnector(LoadConnector):
if self.base_url[-1] != "/":
visited_links.add(self.base_url + "/")
with sync_playwright() as playwright:
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
restart_playwright = True
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
continue
visited_links.add(current_url)
while to_visit:
current_url = to_visit.pop()
if current_url in visited_links:
continue
visited_links.add(current_url)
try:
if restart_playwright:
playwright = sync_playwright().start()
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
restart_playwright = False
try:
if current_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(current_url)
pdf_reader = PdfReader(io.BytesIO(response.content))
page_text = ""
for pdf_page in pdf_reader.pages:
page_text += pdf_page.extract_text()
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split(".")[-1],
metadata={},
)
)
continue
page = context.new_page()
page.goto(current_url)
content = page.content()
soup = BeautifulSoup(content, "html.parser")
internal_links = get_internal_links(
self.base_url, current_url, soup
)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
# Heuristics-based cleaning
for undesired_div in ["sidebar", "header", "footer"]:
[
tag.extract()
for tag in soup.find_all(
"div", class_=lambda x: x and undesired_div in x.split()
)
]
for undesired_tag in [
"nav",
"header",
"footer",
"meta",
"script",
"style",
]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
page_text = soup.get_text(HTML_SEPARATOR)
if current_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(current_url)
pdf_reader = PdfReader(io.BytesIO(response.content))
page_text = ""
for pdf_page in pdf_reader.pages:
page_text += pdf_page.extract_text()
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=title,
semantic_identifier=current_url.split(".")[-1],
metadata={},
)
)
page.close()
except Exception as e:
logger.error(f"Failed to fetch '{current_url}': {e}")
continue
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
page = context.new_page()
page.goto(current_url)
content = page.content()
soup = BeautifulSoup(content, "html.parser")
if doc_batch:
internal_links = get_internal_links(self.base_url, current_url, soup)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
title_tag = soup.find("title")
title = None
if title_tag and title_tag.text:
title = title_tag.text
# Heuristics-based cleaning
for undesired_div in ["sidebar", "header", "footer"]:
[
tag.extract()
for tag in soup.find_all(
"div", class_=lambda x: x and undesired_div in x.split()
)
]
for undesired_tag in [
"nav",
"header",
"footer",
"meta",
"script",
"style",
]:
[tag.extract() for tag in soup.find_all(undesired_tag)]
page_text = soup.get_text(HTML_SEPARATOR)
doc_batch.append(
Document(
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=title,
metadata={},
)
)
page.close()
except Exception as e:
logger.error(f"Failed to fetch '{current_url}': {e}")
continue
if len(doc_batch) >= self.batch_size:
playwright.stop()
restart_playwright = True
yield doc_batch
doc_batch = []
if doc_batch:
playwright.stop()
yield doc_batch
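
The reshuffled loop above stops and relaunches Playwright around each yielded batch so browser resources are released between batches. A simplified sketch of that generator pattern (hypothetical; the real code above also handles PDFs, parses HTML, and collects internal links):

    # Simplified sketch of the restart-between-batches pattern used above.
    from collections.abc import Iterator
    from playwright.sync_api import sync_playwright

    def crawl(urls: list[str], batch_size: int = 2) -> Iterator[list[str]]:
        restart = True
        batch: list[str] = []
        for url in urls:
            if restart:
                playwright = sync_playwright().start()
                browser = playwright.chromium.launch(headless=True)
                context = browser.new_context()
                restart = False
            page = context.new_page()
            page.goto(url)
            batch.append(page.content())
            page.close()
            if len(batch) >= batch_size:
                playwright.stop()  # release browser resources between batches
                restart = True
                yield batch
                batch = []
        if batch:
            playwright.stop()
            yield batch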

View File

@ -1,25 +0,0 @@
from typing import Type
from danswer.configs.app_configs import DEFAULT_VECTOR_STORE
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.store import QdrantDatastore
def get_selected_datastore_cls(
vector_db_type: str = DEFAULT_VECTOR_STORE,
) -> Type[Datastore]:
"""Returns the selected Datastore cls. Only one datastore
should be selected for a specific deployment."""
if vector_db_type == "qdrant":
return QdrantDatastore
else:
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
def create_datastore(
collection: str, vector_db_type: str = DEFAULT_VECTOR_STORE
) -> Datastore:
if vector_db_type == "qdrant":
return QdrantDatastore(collection=collection)
else:
raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")

View File

@ -0,0 +1,72 @@
import uuid
from collections.abc import Callable
from copy import deepcopy
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
DEFAULT_BATCH_SIZE = 30
def get_uuid_from_chunk(
chunk: IndexChunk | EmbeddedIndexChunk | InferenceChunk, mini_chunk_ind: int = 0
) -> uuid.UUID:
doc_str = (
chunk.document_id
if isinstance(chunk, InferenceChunk)
else chunk.source_document.id
)
# Catch duplicate URLs from web parsing that differ only by a trailing slash
if doc_str and doc_str[-1] == "/":
doc_str = doc_str[:-1]
unique_identifier_string = "_".join(
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
)
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
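
For context, a small stdlib-only sketch of the ID scheme above (the URL is made up; the real function strips a single trailing slash rather than all of them):

    import uuid

    # Deterministic chunk IDs: uuid5 over (document id, chunk_id, mini_chunk_ind),
    # so re-indexing the same chunk upserts the same point instead of duplicating.
    doc_id = "https://docs.example.com/page/"  # hypothetical document id
    key = "_".join([doc_id.rstrip("/"), "0", "0"])  # trailing slash stripped as above
    print(uuid.uuid5(uuid.NAMESPACE_X500, key))  # identical on every run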
# Takes the chunk identifier and returns whether the chunk exists and the user/group whitelists
WhitelistCallable = Callable[[str], tuple[bool, list[str], list[str]]]
def update_doc_user_map(
chunk: IndexChunk | EmbeddedIndexChunk,
doc_whitelist_map: dict[str, dict[str, list[str]]],
doc_store_whitelist_fnc: WhitelistCallable,
user_str: str,
) -> tuple[dict[str, dict[str, list[str]]], bool]:
"""Returns an updated document id to whitelists mapping and if the document's chunks need to be wiped."""
doc_whitelist_map = deepcopy(doc_whitelist_map)
first_chunk_uuid = str(get_uuid_from_chunk(chunk))
document = chunk.source_document
if document.id not in doc_whitelist_map:
first_chunk_found, whitelist_users, whitelist_groups = doc_store_whitelist_fnc(
first_chunk_uuid
)
if not first_chunk_found:
doc_whitelist_map[document.id] = {
ALLOWED_USERS: [user_str],
# TODO introduce groups logic here
ALLOWED_GROUPS: whitelist_groups,
}
# First chunk does not exist so document does not exist, no need for deletion
return doc_whitelist_map, False
else:
if user_str not in whitelist_users:
whitelist_users.append(user_str)
# TODO introduce groups logic here
doc_whitelist_map[document.id] = {
ALLOWED_USERS: whitelist_users,
ALLOWED_GROUPS: whitelist_groups,
}
# First chunk exists, but after the update there may be fewer total chunks
# Must delete the rest of the document's chunks
return doc_whitelist_map, True
# If document is already in the mapping, don't delete again
return doc_whitelist_map, False

View File

@ -1,22 +1,42 @@
import abc
from typing import Generic
from typing import TypeVar
from danswer.chunking.models import BaseChunk
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
DatastoreFilter = dict[str, str | list[str] | None]
T = TypeVar("T", bound=BaseChunk)
IndexFilter = dict[str, str | list[str] | None]
class Datastore:
class DocumentIndex(Generic[T], abc.ABC):
@abc.abstractmethod
def index(self, chunks: list[EmbeddedIndexChunk], user_id: int | None) -> bool:
def index(self, chunks: list[T], user_id: int | None) -> bool:
raise NotImplementedError
class VectorIndex(DocumentIndex[EmbeddedIndexChunk], abc.ABC):
@abc.abstractmethod
def semantic_retrieval(
self,
query: str,
user_id: int | None,
filters: list[DatastoreFilter] | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
class KeywordIndex(DocumentIndex[IndexChunk], abc.ABC):
@abc.abstractmethod
def keyword_search(
self,
query: str,
user_id: int | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
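
A hypothetical composition of the new interfaces (index_everywhere is invented for illustration): keyword indexes consume raw IndexChunks while vector indexes consume EmbeddedIndexChunks:

    # Sketch only: the generic base lets callers wire both stores together.
    from danswer.chunking.models import EmbeddedIndexChunk
    from danswer.chunking.models import IndexChunk
    from danswer.datastores.interfaces import KeywordIndex
    from danswer.datastores.interfaces import VectorIndex

    def index_everywhere(
        keyword_index: KeywordIndex,
        vector_index: VectorIndex,
        chunks: list[IndexChunk],
        embedded_chunks: list[EmbeddedIndexChunk],
        user_id: int | None,
    ) -> bool:
        # Keyword index takes raw chunks; vector index takes embedded ones.
        keyword_ok = keyword_index.index(chunks, user_id)
        vector_ok = vector_index.index(embedded_chunks, user_id)
        return keyword_ok and vector_ok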

View File

@ -1,4 +1,4 @@
import uuid
from functools import partial
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.configs.constants import ALLOWED_GROUPS
@ -13,9 +13,11 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import update_doc_user_map
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import ResponseHandlingException
@ -26,16 +28,15 @@ from qdrant_client.models import Distance
from qdrant_client.models import PointStruct
from qdrant_client.models import VectorParams
logger = setup_logger()
DEFAULT_BATCH_SIZE = 30
def list_collections() -> CollectionsResponse:
def list_qdrant_collections() -> CollectionsResponse:
return get_qdrant_client().get_collections()
def create_collection(
def create_qdrant_collection(
collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
logger.info(f"Attempting to create collection {collection_name}")
@ -47,25 +48,25 @@ def create_collection(
raise RuntimeError("Could not create Qdrant collection")
def get_document_whitelists(
def get_qdrant_document_whitelists(
doc_chunk_id: str, collection_name: str, q_client: QdrantClient
) -> tuple[int, list[str], list[str]]:
) -> tuple[bool, list[str], list[str]]:
results = q_client.retrieve(
collection_name=collection_name,
ids=[doc_chunk_id],
with_payload=[ALLOWED_USERS, ALLOWED_GROUPS],
)
if len(results) == 0:
return 0, [], []
return False, [], []
payload = results[0].payload
if not payload:
raise RuntimeError(
"Qdrant Index is corrupted, Document found with no access lists."
)
return len(results), payload[ALLOWED_USERS], payload[ALLOWED_GROUPS]
return True, payload[ALLOWED_USERS], payload[ALLOWED_GROUPS]
def delete_doc_chunks(
def delete_qdrant_doc_chunks(
document_id: str, collection_name: str, q_client: QdrantClient
) -> None:
q_client.delete(
@ -83,24 +84,7 @@ def delete_doc_chunks(
)
def recreate_collection(
collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
logger.info(f"Attempting to recreate collection {collection_name}")
result = get_qdrant_client().recreate_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
if not result:
raise RuntimeError("Could not create Qdrant collection")
def get_uuid_from_chunk(chunk: EmbeddedIndexChunk) -> uuid.UUID:
unique_identifier_string = "_".join([chunk.source_document.id, str(chunk.chunk_id)])
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
def index_chunks(
def index_qdrant_chunks(
chunks: list[EmbeddedIndexChunk],
user_id: int | None,
collection: str,
@ -112,51 +96,45 @@ def index_chunks(
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
q_client: QdrantClient = client if client else get_qdrant_client()
point_structs = []
point_structs: list[PointStruct] = []
# Maps document id to a dict of user/group whitelists, each a list of strings
doc_user_map: dict[str, dict[str, list[str]]] = {}
for chunk in chunks:
chunk_uuid = str(get_uuid_from_chunk(chunk))
document = chunk.source_document
doc_user_map, delete_doc = update_doc_user_map(
chunk,
doc_user_map,
partial(
get_qdrant_document_whitelists,
collection_name=collection,
q_client=q_client,
),
user_str,
)
if document.id not in doc_user_map:
num_doc_chunks, whitelist_users, whitelist_groups = get_document_whitelists(
chunk_uuid, collection, q_client
)
if num_doc_chunks == 0:
doc_user_map[document.id] = {
ALLOWED_USERS: [user_str],
# TODO introduce groups logic here
ALLOWED_GROUPS: whitelist_groups,
}
else:
if user_str not in whitelist_users:
whitelist_users.append(user_str)
# TODO introduce groups logic here
doc_user_map[document.id] = {
ALLOWED_USERS: whitelist_users,
ALLOWED_GROUPS: whitelist_groups,
}
# Need to delete document chunks because number of chunks may decrease
delete_doc_chunks(document.id, collection, q_client)
if delete_doc:
delete_qdrant_doc_chunks(document.id, collection, q_client)
point_structs.append(
PointStruct(
id=chunk_uuid,
payload={
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
BLURB: chunk.blurb,
CONTENT: chunk.content,
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: chunk.source_links,
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
},
vector=chunk.embedding,
)
point_structs.extend(
[
PointStruct(
id=str(get_uuid_from_chunk(chunk, minichunk_ind)),
payload={
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
BLURB: chunk.blurb,
CONTENT: chunk.content,
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: chunk.source_links,
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
},
vector=embedding,
)
for minichunk_ind, embedding in enumerate(chunk.embeddings)
]
)
index_results = None
@ -182,12 +160,14 @@ def index_chunks(
index_results = upsert()
log_status = index_results.status if index_results else "Failed"
logger.info(
f"Indexed {len(point_struct_batch)} chunks into collection '{collection}', "
f"Indexed {len(point_struct_batch)} chunks into Qdrant collection '{collection}', "
f"status: {log_status}"
)
else:
index_results = q_client.upsert(
collection_name=collection, points=point_structs
)
logger.info(f"Batch indexing status: {index_results.status}")
logger.info(
f"Document batch of size {len(point_structs)} indexing status: {index_results.status}"
)
return index_results is not None and index_results.status == UpdateStatus.COMPLETED

View File

@ -1,12 +1,17 @@
import uuid
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.datastores.interfaces import Datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.datastores.qdrant.indexing import index_chunks
from danswer.semantic_search.semantic_search import get_default_embedding_model
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.indexing import index_qdrant_chunks
from danswer.search.semantic_search import get_default_embedding_model
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
@ -20,13 +25,60 @@ from qdrant_client.http.models import MatchValue
logger = setup_logger()
class QdrantDatastore(Datastore):
def _build_qdrant_filters(
user_id: int | None, filters: list[IndexFilter] | None
) -> list[FieldCondition]:
filter_conditions: list[FieldCondition] = []
# Permissions filter
if user_id:
filter_conditions.append(
FieldCondition(
key=ALLOWED_USERS,
match=MatchAny(any=[str(user_id), PUBLIC_DOC_PAT]),
)
)
else:
filter_conditions.append(
FieldCondition(
key=ALLOWED_USERS,
match=MatchValue(value=PUBLIC_DOC_PAT),
)
)
# Provided query filters
if filters:
for filter_dict in filters:
valid_filters = {
key: value for key, value in filter_dict.items() if value is not None
}
for filter_key, filter_val in valid_filters.items():
if isinstance(filter_val, str):
filter_conditions.append(
FieldCondition(
key=filter_key,
match=MatchValue(value=filter_val),
)
)
elif isinstance(filter_val, list):
filter_conditions.append(
FieldCondition(
key=filter_key,
match=MatchAny(any=filter_val),
)
)
else:
raise ValueError("Invalid filters provided")
return filter_conditions
class QdrantIndex(VectorIndex):
def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None:
self.collection = collection
self.client = get_qdrant_client()
def index(self, chunks: list[EmbeddedIndexChunk], user_id: int | None) -> bool:
return index_chunks(
return index_qdrant_chunks(
chunks=chunks,
user_id=user_id,
collection=self.collection,
@ -38,8 +90,9 @@ class QdrantDatastore(Datastore):
self,
query: str,
user_id: int | None,
filters: list[DatastoreFilter] | None,
num_to_retrieve: int,
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
page_size: int = NUM_RERANKED_RESULTS,
) -> list[InferenceChunk]:
query_embedding = get_default_embedding_model().encode(
query
@ -47,68 +100,47 @@ class QdrantDatastore(Datastore):
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
hits = []
filter_conditions = []
try:
# Permissions filter
if user_id:
filter_conditions.append(
FieldCondition(
key=ALLOWED_USERS,
match=MatchAny(any=[str(user_id), PUBLIC_DOC_PAT]),
)
)
else:
filter_conditions.append(
FieldCondition(
key=ALLOWED_USERS,
match=MatchValue(value=PUBLIC_DOC_PAT),
)
)
filter_conditions = _build_qdrant_filters(user_id, filters)
# Provided query filters
if filters:
for filter_dict in filters:
valid_filters = {
key: value
for key, value in filter_dict.items()
if value is not None
}
for filter_key, filter_val in valid_filters.items():
if isinstance(filter_val, str):
filter_conditions.append(
FieldCondition(
key=filter_key,
match=MatchValue(value=filter_val),
)
)
elif isinstance(filter_val, list):
filter_conditions.append(
FieldCondition(
key=filter_key,
match=MatchAny(any=filter_val),
)
)
else:
raise ValueError("Invalid filters provided")
page_offset = 0
found_inference_chunks: list[InferenceChunk] = []
found_chunk_uuids: set[uuid.UUID] = set()
while len(found_inference_chunks) < num_to_retrieve:
try:
hits = self.client.search(
collection_name=self.collection,
query_vector=query_embedding,
query_filter=Filter(must=list(filter_conditions)),
limit=page_size,
offset=page_offset,
)
page_offset += page_size
if not hits:
break
except ResponseHandlingException as e:
logger.exception(
f'Qdrant querying failed due to: "{e}", is Qdrant set up?'
)
break
except UnexpectedResponse as e:
logger.exception(
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
)
break
hits = self.client.search(
collection_name=self.collection,
query_vector=query_embedding,
query_filter=Filter(must=list(filter_conditions)),
limit=num_to_retrieve,
)
except ResponseHandlingException as e:
logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
except UnexpectedResponse as e:
logger.exception(
f'Qdrant querying failed due to: "{e}", has ingestion been run?'
)
return [
InferenceChunk.from_dict(hit.payload)
for hit in hits
if hit.payload is not None
]
inference_chunks_from_hits = [
InferenceChunk.from_dict(hit.payload)
for hit in hits
if hit.payload is not None
]
for inf_chunk in inference_chunks_from_hits:
# Remove duplicate chunks, which can occur when mini chunks are used
inf_chunk_id = get_uuid_from_chunk(inf_chunk)
if inf_chunk_id not in found_chunk_uuids:
found_inference_chunks.append(inf_chunk)
found_chunk_uuids.add(inf_chunk_id)
return found_inference_chunks
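
When mini chunks are enabled, several Qdrant points resolve to the same chunk UUID, so the loop above pages until enough unique chunks are found. A toy model of that dedup-paging behavior (ids and pages are invented):

    # Toy version of the paged dedup loop above.
    pages = [["a", "a", "b"], ["b", "c", "d"], []]  # chunk ids per Qdrant page
    num_to_retrieve = 3
    seen: set[str] = set()
    unique: list[str] = []
    page_ind = 0
    while len(unique) < num_to_retrieve:
        hits = pages[page_ind]
        page_ind += 1
        if not hits:
            break
        for hit in hits:
            if hit not in seen:
                seen.add(hit)
                unique.append(hit)
    print(unique)  # -> ['a', 'b', 'c', 'd']; like the real loop, a page may overshoot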
def get_from_id(self, object_id: str) -> InferenceChunk | None:
matches, _ = self.client.scroll(

View File

@ -0,0 +1,238 @@
import json
from functools import partial
from typing import Any
import typesense # type: ignore
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import update_doc_user_map
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.utils.clients import get_typesense_client
from danswer.utils.logging import setup_logger
from typesense.exceptions import ObjectNotFound # type: ignore
logger = setup_logger()
def check_typesense_collection_exist(
collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
) -> bool:
client = get_typesense_client()
try:
client.collections[collection_name].retrieve()
except ObjectNotFound:
return False
return True
def create_typesense_collection(
collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
) -> None:
ts_client = get_typesense_client()
collection_schema = {
"name": collection_name,
"fields": [
# Typesense treats the string field "id" as a special field
{"name": "id", "type": "string"},
{"name": DOCUMENT_ID, "type": "string"},
{"name": CHUNK_ID, "type": "int32"},
{"name": BLURB, "type": "string"},
{"name": CONTENT, "type": "string"},
{"name": SOURCE_TYPE, "type": "string"},
{"name": SOURCE_LINKS, "type": "string"},
{"name": SEMANTIC_IDENTIFIER, "type": "string"},
{"name": SECTION_CONTINUATION, "type": "bool"},
{"name": ALLOWED_USERS, "type": "string[]"},
{"name": ALLOWED_GROUPS, "type": "string[]"},
],
}
ts_client.collections.create(collection_schema)
def get_typesense_document_whitelists(
doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
) -> tuple[bool, list[str], list[str]]:
"""Returns whether the document already exists and the users/group whitelists"""
try:
document = (
ts_client.collections[collection_name].documents[doc_chunk_id].retrieve()
)
except ObjectNotFound:
return False, [], []
if document[ALLOWED_USERS] is None or document[ALLOWED_GROUPS] is None:
raise RuntimeError(
"Typesense Index is corrupted, Document found with no access lists."
)
return True, document[ALLOWED_USERS], document[ALLOWED_GROUPS]
def delete_typesense_doc_chunks(
document_id: str, collection_name: str, ts_client: typesense.Client
) -> None:
search_parameters = {
"q": document_id,
"query_by": DOCUMENT_ID,
}
# TODO consider race conditions if running multiple processes/threads
hits = ts_client.collections[collection_name].documents.search(search_parameters)
[
ts_client.collections[collection_name].documents[hit["document"]["id"]].delete()
for hit in hits["hits"]
]
def index_typesense_chunks(
chunks: list[IndexChunk | EmbeddedIndexChunk],
user_id: int | None,
collection: str,
client: typesense.Client | None = None,
batch_upsert: bool = True,
) -> bool:
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
ts_client: typesense.Client = client if client else get_typesense_client()
new_documents: list[dict[str, Any]] = []
doc_user_map: dict[str, dict[str, list[str]]] = {}
for chunk in chunks:
document = chunk.source_document
doc_user_map, delete_doc = update_doc_user_map(
chunk,
doc_user_map,
partial(
get_typesense_document_whitelists,
collection_name=collection,
ts_client=ts_client,
),
user_str,
)
if delete_doc:
delete_typesense_doc_chunks(document.id, collection, ts_client)
new_documents.append(
{
"id": str(get_uuid_from_chunk(chunk)), # No minichunks for typesense
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
BLURB: chunk.blurb,
CONTENT: chunk.content,
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
}
)
if batch_upsert:
doc_batches = [
new_documents[x : x + DEFAULT_BATCH_SIZE]
for x in range(0, len(new_documents), DEFAULT_BATCH_SIZE)
]
for doc_batch in doc_batches:
results = ts_client.collections[collection].documents.import_(
doc_batch, {"action": "upsert"}
)
failures = [
doc_res["success"]
for doc_res in results
if doc_res["success"] is not True
]
logger.info(
f"Indexed {len(doc_batch)} chunks into Typesense collection '{collection}', "
f"number failed: {len(failures)}"
)
else:
[
ts_client.collections[collection].documents.upsert(document)
for document in new_documents
]
return True
def _build_typesense_filters(
user_id: int | None, filters: list[IndexFilter] | None
) -> str:
filter_str = ""
# Permissions filter
if user_id:
filter_str += f"{ALLOWED_USERS}:=[{PUBLIC_DOC_PAT}|{user_id}] && "
else:
filter_str += f"{ALLOWED_USERS}:={PUBLIC_DOC_PAT} && "
# Provided query filters
if filters:
for filter_dict in filters:
valid_filters = {
key: value for key, value in filter_dict.items() if value is not None
}
for filter_key, filter_val in valid_filters.items():
if isinstance(filter_val, str):
filter_str += f"{filter_key}:={filter_val} && "
elif isinstance(filter_val, list):
filters_or = ",".join([str(f_val) for f_val in filter_val])
filter_str += f"{filter_key}:=[{filters_or}] && "
else:
raise ValueError("Invalid filters provided")
if filter_str[-4:] == " && ":
filter_str = filter_str[:-4]
return filter_str
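
For illustration, the builder yields an "&&"-joined filter_by string; the snippet below reproduces its shape with stand-in constants (the real field name and public marker come from danswer.configs.constants):

    # Illustrative shape of the Typesense filter string built above.
    ALLOWED_USERS = "allowed_users"  # stand-in for the constant
    PUBLIC_DOC_PAT = "PUBLIC"        # stand-in for the constant
    user_id = 42
    filter_str = f"{ALLOWED_USERS}:=[{PUBLIC_DOC_PAT}|{user_id}] && source_type:=[web,slack]"
    print(filter_str)
    # -> allowed_users:=[PUBLIC|42] && source_type:=[web,slack]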
class TypesenseIndex(KeywordIndex):
def __init__(self, collection: str = TYPESENSE_DEFAULT_COLLECTION) -> None:
self.collection = collection
self.ts_client = get_typesense_client()
def index(self, chunks: list[IndexChunk], user_id: int | None) -> bool:
return index_typesense_chunks(
chunks=chunks,
user_id=user_id,
collection=self.collection,
client=self.ts_client,
)
def keyword_search(
self,
query: str,
user_id: int | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
filters_str = _build_typesense_filters(user_id, filters)
search_results = self.ts_client.collections[self.collection].documents.search(
{
"q": query,
"query_by": CONTENT,
"filter_by": filters_str,
"per_page": num_to_retrieve,
"limit_hits": num_to_retrieve,
}
)
hits = search_results["hits"]
inference_chunks = [InferenceChunk.from_dict(hit["document"]) for hit in hits]
return inference_chunks

View File

@ -332,9 +332,17 @@ class OpenAICompletionQA(OpenAIQAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
logger.info(answer)
if answer:
logger.info(answer)
else:
logger.warning(
"Answer extraction from model output failed, most likely no quotes provided"
)
yield quotes_dict
if quotes_dict is None:
yield {}
else:
yield quotes_dict
class OpenAIChatCompletionQA(OpenAIQAModel):
@ -442,6 +450,14 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
logger.info(answer)
if answer:
logger.info(answer)
else:
logger.warning(
"Answer extraction from model output failed, most likely no quotes provided"
)
yield quotes_dict
if quotes_dict is None:
yield {}
else:
yield quotes_dict

View File

@ -1,3 +1,4 @@
import nltk # type:ignore
import uvicorn
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
@ -9,8 +10,11 @@ from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.datastores.qdrant.indexing import list_collections
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.db.credentials import create_initial_public_credential
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
@ -107,24 +111,36 @@ def get_application() -> FastAPI:
@application.on_event("startup")
def startup_event() -> None:
# To avoid circular imports
from danswer.semantic_search.semantic_search import (
from danswer.search.semantic_search import (
warm_up_models,
)
from danswer.datastores.qdrant.indexing import create_collection
from danswer.datastores.qdrant.indexing import create_qdrant_collection
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
if QDRANT_DEFAULT_COLLECTION not in {
collection.name for collection in list_collections().collections
}:
logger.info(f"Creating collection with name: {QDRANT_DEFAULT_COLLECTION}")
create_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
logger.info("Warming up local NLP models.")
warm_up_models()
logger.info("Semantic Search models are ready.")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords")
nltk.download("wordnet")
logger.info("Verifying public credential exists.")
create_initial_public_credential()
logger.info("Verifying Document Indexes are available.")
if QDRANT_DEFAULT_COLLECTION not in {
collection.name for collection in list_qdrant_collections().collections
}:
logger.info(
f"Creating Qdrant collection with name: {QDRANT_DEFAULT_COLLECTION}"
)
create_qdrant_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
if not check_typesense_collection_exist(TYPESENSE_DEFAULT_COLLECTION):
logger.info(
f"Creating Typesense collection with name: {TYPESENSE_DEFAULT_COLLECTION}"
)
create_typesense_collection(collection_name=TYPESENSE_DEFAULT_COLLECTION)
return application

View File

View File

@ -0,0 +1,52 @@
import json
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
logger = setup_logger()
def lemmatize_text(text: str) -> str:
lemmatizer = WordNetLemmatizer()
word_tokens = word_tokenize(text)
lemmatized_text = [lemmatizer.lemmatize(word) for word in word_tokens]
return " ".join(lemmatized_text)
def remove_stop_words(text: str) -> str:
stop_words = set(stopwords.words("english"))
word_tokens = word_tokenize(text)
filtered_text = [word for word in word_tokens if word.casefold() not in stop_words]
return " ".join(filtered_text)
def query_processing(query: str) -> str:
query = remove_stop_words(query)
query = lemmatize_text(query)
return query
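
A hypothetical run of query_processing above (word_tokenize additionally needs the NLTK "punkt" data; exact output depends on the installed corpora):

    # Example usage of the keyword query preprocessing defined above.
    import nltk

    for pkg in ["stopwords", "wordnet", "punkt"]:
        nltk.download(pkg)
    print(query_processing("What are the benefits of using mini chunks?"))
    # -> e.g. "benefit using mini chunk ?" (stop words dropped, plurals lemmatized)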
@log_function_time()
def retrieve_keyword_documents(
query: str,
user_id: int | None,
filters: list[IndexFilter] | None,
datastore: KeywordIndex,
num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
edited_query = query_processing(query)
top_chunks = datastore.keyword_search(edited_query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
logger.warning(
f"Keyword search returned no results...\nfilters: {filters_log_msg}\nedited query: {edited_query}"
)
return None
return top_chunks

View File

@ -0,0 +1,201 @@
import json
import numpy
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.search.type_aliases import Embedder
from danswer.server.models import SearchDoc
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODELS: None | list[CrossEncoder] = None
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
search_docs = (
[
SearchDoc(
semantic_identifier=chunk.semantic_identifier,
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
)
for chunk in chunks
]
if chunks
else []
)
return search_docs
def get_default_embedding_model() -> SentenceTransformer:
global _EMBED_MODEL
if _EMBED_MODEL is None:
_EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
_EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
return _EMBED_MODEL
def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
global _RERANK_MODELS
if _RERANK_MODELS is None:
_RERANK_MODELS = [
CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
]
for model in _RERANK_MODELS:
model.max_length = CROSS_EMBED_CONTEXT_SIZE
return _RERANK_MODELS
def warm_up_models() -> None:
get_default_embedding_model().encode("Danswer is so cool")
cross_encoders = get_default_reranking_model_ensemble()
[cross_encoder.predict(("What is Danswer", "Enterprise QA")) for cross_encoder in cross_encoders] # type: ignore
@log_function_time()
def semantic_reranking(
query: str,
chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
cross_encoders = get_default_reranking_model_ensemble()
sim_scores = sum([encoder.predict([(query, chunk.content) for chunk in chunks]) for encoder in cross_encoders]) # type: ignore
scored_results = list(zip(sim_scores, chunks))
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_chunks = zip(*scored_results)
logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")
return ranked_chunks
@log_function_time()
def retrieve_ranked_documents(
query: str,
user_id: int | None,
filters: list[IndexFilter] | None,
datastore: VectorIndex,
num_hits: int = NUM_RETURNED_HITS,
num_rerank: int = NUM_RERANKED_RESULTS,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
logger.warning(
f"Semantic search returned no results with filters: {filters_log_msg}"
)
return None, None
ranked_chunks = semantic_reranking(query, top_chunks[:num_rerank])
top_docs = [
ranked_chunk.source_links[0]
for ranked_chunk in ranked_chunks
if ranked_chunk.source_links is not None
]
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
logger.info(files_log_msg)
return ranked_chunks, top_chunks[num_rerank:]
def split_chunk_text_into_mini_chunks(
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
chunks = []
start = 0
separators = [" ", "\n", "\r", "\t"]
while start < len(chunk_text):
if len(chunk_text) - start <= mini_chunk_size:
end = len(chunk_text)
else:
# Find the first separator character at or after mini_chunk_size
end_positions = [
(chunk_text[start + mini_chunk_size :]).find(sep) for sep in separators
]
# Filter out the not found cases (-1)
end_positions = [pos for pos in end_positions if pos != -1]
if not end_positions:
# If no more separators, the rest of the string becomes a chunk
end = len(chunk_text)
else:
# Offset into the full text: add start + mini_chunk_size to the found position
end = min(end_positions) + start + mini_chunk_size
chunks.append(chunk_text[start:end])
start = end + 1 # Move to the next character after the separator
return chunks
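
A traced example of the splitter above with an artificially small mini_chunk_size (the text is made up):

    # Each piece extends to the first separator at or after 6 characters.
    print(split_chunk_text_into_mini_chunks("alpha beta gamma delta", mini_chunk_size=6))
    # -> ['alpha beta', 'gamma delta']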
@log_function_time()
def encode_chunks(
chunks: list[IndexChunk],
embedding_model: SentenceTransformer | None = None,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[EmbeddedIndexChunk]:
embedded_chunks: list[EmbeddedIndexChunk] = []
if embedding_model is None:
embedding_model = get_default_embedding_model()
chunk_texts = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
mini_chunk_texts = (
split_chunk_text_into_mini_chunks(chunk.content)
if enable_mini_chunk
else []
)
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
text_batches = [
chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
]
embeddings_np: list[numpy.ndarray] = []
for text_batch in text_batches:
embeddings_np.extend(embedding_model.encode(text_batch))
embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
embedding_ind_start = 0
for chunk_ind, chunk in enumerate(chunks):
num_embeddings = chunk_mini_chunks_count[chunk_ind]
chunk_embeddings = embeddings[
embedding_ind_start : embedding_ind_start + num_embeddings
]
new_embedded_chunk = EmbeddedIndexChunk(
**{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
embeddings=chunk_embeddings,
)
embedded_chunks.append(new_embedded_chunk)
embedding_ind_start += num_embeddings
return embedded_chunks
class DefaultEmbedder(Embedder):
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
return encode_chunks(chunks)
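
A stub walk-through of the bookkeeping in encode_chunks above: all texts are embedded in one flat batch, then sliced back per chunk using chunk_mini_chunks_count (data invented; a real run would use SentenceTransformer embeddings):

    # Stand-in vectors demonstrate the flat-batch slicing, counts only.
    chunk_texts = ["full chunk A", "A mini 1", "A mini 2", "full chunk B"]
    chunk_mini_chunks_count = {0: 3, 1: 1}  # chunk 0 has 2 mini chunks, chunk 1 has none

    embeddings = [[float(i)] for i in range(len(chunk_texts))]
    embedding_ind_start = 0
    for chunk_ind in sorted(chunk_mini_chunks_count):
        num_embeddings = chunk_mini_chunks_count[chunk_ind]
        per_chunk = embeddings[embedding_ind_start : embedding_ind_start + num_embeddings]
        print(chunk_ind, per_chunk)  # chunk 0 -> 3 vectors, chunk 1 -> 1 vector
        embedding_ind_start += num_embeddings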

View File

@ -1,201 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
(Standard Apache License, Version 2.0 terms and conditions, Sections 1-9 and Appendix; full text at the URL above.)
Copyright 2019 Nils Reimers
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

View File

@ -1,43 +0,0 @@
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.semantic_search.semantic_search import get_default_embedding_model
from danswer.semantic_search.type_aliases import Embedder
from danswer.utils.logging import setup_logger
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
def encode_chunks(
chunks: list[IndexChunk],
embedding_model: SentenceTransformer | None = None,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[EmbeddedIndexChunk]:
embedded_chunks = []
if embedding_model is None:
embedding_model = get_default_embedding_model()
chunk_batches = [
chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
]
for batch_ind, chunk_batch in enumerate(chunk_batches):
embeddings_batch = embedding_model.encode(
[chunk.content for chunk in chunk_batch]
)
embedded_chunks.extend(
[
EmbeddedIndexChunk(
**{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
embedding=embeddings_batch[i].tolist()
)
for i, chunk in enumerate(chunk_batch)
]
)
return embedded_chunks
class DefaultEmbedder(Embedder):
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
return encode_chunks(chunks)

View File

@ -1,105 +0,0 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# The NLP models used here are licensed under Apache 2.0, the original author's LICENSE file is
# included in this same directory.
# Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models
# The original authors can be found at https://www.sbert.net/
import json
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import Datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODEL: None | CrossEncoder = None
def get_default_embedding_model() -> SentenceTransformer:
global _EMBED_MODEL
if _EMBED_MODEL is None:
_EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
_EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
return _EMBED_MODEL
def get_default_reranking_model() -> CrossEncoder:
global _RERANK_MODEL
if _RERANK_MODEL is None:
_RERANK_MODEL = CrossEncoder(CROSS_ENCODER_MODEL)
_RERANK_MODEL.max_length = CROSS_EMBED_CONTEXT_SIZE
return _RERANK_MODEL
def warm_up_models() -> None:
get_default_embedding_model().encode("Danswer is so cool")
get_default_reranking_model().predict(("What is Danswer", "Enterprise QA")) # type: ignore
@log_function_time()
def semantic_reranking(
query: str,
chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
cross_encoder = get_default_reranking_model()
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
scored_results = list(zip(sim_scores, chunks))
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_chunks = zip(*scored_results)
logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")
return ranked_chunks
@log_function_time()
def retrieve_ranked_documents(
query: str,
user_id: int | None,
filters: list[DatastoreFilter] | None,
datastore: Datastore,
num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
logger.warning(
f"Semantic search returned no results with filters: {filters_log_msg}"
)
return None
ranked_chunks = semantic_reranking(query, top_chunks)
top_docs = [
ranked_chunk.source_links[0]
for ranked_chunk in ranked_chunks
if ranked_chunk.source_links is not None
]
files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
logger.info(files_log_msg)
return ranked_chunks
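The module above moves to danswer/search/semantic_search.py, and the callers later in this commit unpack a (ranked, unranked) pair from retrieve_ranked_documents. A hedged sketch of that split, assuming only the first NUM_RERANKED_RESULTS hits are cross-encoded while the remainder keep retrieval order (the exact cut-off logic is an assumption):
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import VectorIndex
def retrieve_ranked_documents(
    query: str,
    user_id: int | None,
    filters: list | None,  # list[IndexFilter]; the type's home module is assumed
    datastore: VectorIndex,
    num_hits: int = NUM_RETURNED_HITS,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
    top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
    if not top_chunks:
        return None, None
    # semantic_reranking as in the module above: cross-encode only a few top hits
    ranked_chunks = semantic_reranking(query, top_chunks[:NUM_RERANKED_RESULTS])
    return list(ranked_chunks), top_chunks[NUM_RERANKED_RESULTS:]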

View File

@ -1,6 +1,7 @@
from collections import defaultdict
from typing import cast
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
@ -31,6 +32,7 @@ from danswer.db.credentials import fetch_credential_by_id
from danswer.db.credentials import fetch_credentials
from danswer.db.credentials import mask_credential_dict
from danswer.db.credentials import update_credential
from danswer.db.engine import build_async_engine
from danswer.db.engine import get_session
from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import Connector
@ -55,23 +57,45 @@ from danswer.server.models import IndexAttemptSnapshot
from danswer.server.models import ObjectCreationIdResponse
from danswer.server.models import RunConnectorRequest
from danswer.server.models import StatusResponse
from danswer.server.models import UserByEmail
from danswer.server.models import UserRoleResponse
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
router = APIRouter(prefix="/manage")
logger = setup_logger()
_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id"
"""Admin only API endpoints"""
@router.patch("/promote-user-to-admin", response_model=None)
async def promote_admin(
user_email: UserByEmail, user: User = Depends(current_admin_user)
) -> None:
if user.role != UserRole.ADMIN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with AsyncSession(build_async_engine()) as asession:
user_db = SQLAlchemyUserDatabase(asession, User) # type: ignore
user_to_promote = await user_db.get_by_email(user_email.user_email)
if not user_to_promote:
raise HTTPException(status_code=404, detail="User not found")
user_to_promote.role = UserRole.ADMIN
asession.add(user_to_promote)
await asession.commit()
return
@router.get("/admin/connector/google-drive/app-credential")
def check_google_app_credentials_exist(
_: User = Depends(current_admin_user),
@ -403,7 +427,14 @@ def delete_openai_api_key(
get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)
"""Endpoints for all!"""
"""Endpoints for basic users"""
@router.get("/get-user-role", response_model=UserRoleResponse)
async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
if user is None:
raise ValueError("Invalid or missing user.")
return UserRoleResponse(role=user.role)
@router.get("/connector/google-drive/authorize/{credential_id}")

View File

@ -3,12 +3,10 @@ from typing import Any
from typing import Generic
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING
from typing import TypeVar
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.datastores.interfaces import DatastoreFilter
from danswer.db.models import Connector
from danswer.db.models import IndexingStatus
from pydantic import BaseModel
@ -75,20 +73,25 @@ class SearchDoc(BaseModel):
source_type: str
class QAQuestion(BaseModel):
class QuestionRequest(BaseModel):
query: str
collection: str
filters: list[DatastoreFilter] | None
use_keyword: bool | None
filters: str | None # string of list[IndexFilter]
class SearchResponse(BaseModel):
# For semantic search, the top docs are reranked; the remainder keep their retrieval order
top_ranked_docs: list[SearchDoc] | None
semi_ranked_docs: list[SearchDoc] | None
class QAResponse(BaseModel):
answer: str | None
quotes: dict[str, dict[str, str | int | None]] | None
ranked_documents: list[SearchDoc] | None
class KeywordResponse(BaseModel):
results: list[str] | None
# for performance, only a few top documents are cross-encoded for rerank, the rest follow retrieval order
unranked_documents: list[SearchDoc] | None
class UserByEmail(BaseModel):
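Note that QuestionRequest.filters is now a JSON-encoded string of list[IndexFilter] rather than a structured field, so callers serialize before sending. A minimal sketch of the query parameters, mirroring the dev script later in this diff (the query and source types are placeholders):
import json
from danswer.configs.constants import SOURCE_TYPE
question_params = {
    "query": "How do I set up OAuth?",
    "collection": "danswer_index",
    "use_keyword": False,  # only consulted by the QA endpoints
    "filters": json.dumps([{SOURCE_TYPE: ["web", "slack"]}]),
}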

View File

@ -1,96 +1,105 @@
import json
import time
from collections.abc import Generator
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import KEYWORD_MAX_HITS
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import CONTENT
from danswer.configs.constants import SOURCE_LINKS
from danswer.datastores import create_datastore
from danswer.db.engine import build_async_engine
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import get_json_line
from danswer.semantic_search.semantic_search import retrieve_ranked_documents
from danswer.server.models import KeywordResponse
from danswer.server.models import QAQuestion
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import QAResponse
from danswer.server.models import SearchDoc
from danswer.server.models import UserByEmail
from danswer.server.models import UserRoleResponse
from danswer.utils.clients import TSClient
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchResponse
from danswer.utils.logging import setup_logger
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
logger = setup_logger()
router = APIRouter()
@router.get("/get-user-role", response_model=UserRoleResponse)
async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
if user is None:
raise ValueError("Invalid or missing user.")
return UserRoleResponse(role=user.role)
@router.get("/semantic-search")
def semantic_search(
question: QuestionRequest = Depends(), user: User = Depends(current_user)
) -> SearchResponse:
query = question.query
collection = question.collection
filters = json.loads(question.filters) if question.filters is not None else None
logger.info(f"Received semantic search query: {query}")
user_id = None if user is None else int(user.id)
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
top_docs = chunks_to_search_docs(ranked_chunks)
other_top_docs = chunks_to_search_docs(unranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs)
@router.patch("/promote-user-to-admin", response_model=None)
async def promote_admin(
user_email: UserByEmail, user: User = Depends(current_admin_user)
) -> None:
if user.role != UserRole.ADMIN:
raise HTTPException(status_code=401, detail="Unauthorized")
async with AsyncSession(build_async_engine()) as asession:
user_db = SQLAlchemyUserDatabase(asession, User) # type: ignore
user_to_promote = await user_db.get_by_email(user_email.user_email)
if not user_to_promote:
raise HTTPException(status_code=404, detail="User not found")
user_to_promote.role = UserRole.ADMIN
asession.add(user_to_promote)
await asession.commit()
return
@router.get("/keyword-search", response_model=SearchResponse)
def keyword_search(
question: QuestionRequest = Depends(), user: User = Depends(current_user)
) -> SearchResponse:
query = question.query
collection = question.collection
filters = json.loads(question.filters) if question.filters is not None else None
logger.info(f"Received keyword search query: {query}")
user_id = None if user is None else int(user.id)
ranked_chunks = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
top_docs = chunks_to_search_docs(ranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None)
@router.get("/direct-qa", response_model=QAResponse)
def direct_qa(
question: QAQuestion = Depends(), user: User = Depends(current_user)
question: QuestionRequest = Depends(), user: User = Depends(current_user)
) -> QAResponse:
start_time = time.time()
query = question.query
collection = question.collection
filters = question.filters
logger.info(f"Received semantic query: {query}")
filters = json.loads(question.filters) if question.filters is not None else None
use_keyword = question.use_keyword
logger.info(f"Received QA query: {query}")
user_id = None if user is None else int(user.id)
ranked_chunks = retrieve_ranked_documents(
query, user_id, filters, create_datastore(collection)
)
if not ranked_chunks:
return QAResponse(answer=None, quotes=None, ranked_documents=None)
top_docs = [
SearchDoc(
semantic_identifier=chunk.semantic_identifier,
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
return QAResponse(
answer=None, quotes=None, ranked_documents=None, unranked_documents=None
)
for chunk in ranked_chunks
]
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
try:
answer, quotes = qa_model.answer_question(
query, ranked_chunks[:NUM_RERANKED_RESULTS]
query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
)
except Exception:
# exception is logged in the answer_question method, no need to re-log
@ -98,45 +107,54 @@ def direct_qa(
logger.info(f"Total QA took {time.time() - start_time} seconds")
return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)
return QAResponse(
answer=answer,
quotes=quotes,
ranked_documents=chunks_to_search_docs(ranked_chunks),
unranked_documents=chunks_to_search_docs(unranked_chunks),
)
@router.get("/stream-direct-qa")
def stream_direct_qa(
question: QAQuestion = Depends(), user: User = Depends(current_user)
question: QuestionRequest = Depends(), user: User = Depends(current_user)
) -> StreamingResponse:
top_documents_key = "top_documents"
unranked_top_docs_key = "unranked_top_documents"
def stream_qa_portions() -> Generator[str, None, None]:
query = question.query
collection = question.collection
filters = question.filters
logger.info(f"Received semantic query: {query}")
filters = json.loads(question.filters) if question.filters is not None else None
use_keyword = question.use_keyword
logger.info(f"Received QA query: {query}")
user_id = None if user is None else int(user.id)
ranked_chunks = retrieve_ranked_documents(
query, user_id, filters, create_datastore(collection)
)
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
)
if not ranked_chunks:
yield get_json_line({top_documents_key: None})
yield get_json_line({top_documents_key: None, unranked_top_docs_key: None})
return
top_docs = [
SearchDoc(
semantic_identifier=chunk.semantic_identifier,
link=chunk.source_links.get(0) if chunk.source_links else None,
blurb=chunk.blurb,
source_type=chunk.source_type,
)
for chunk in ranked_chunks
]
top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
top_docs_dict = {
top_documents_key: [top_doc.json() for top_doc in top_docs],
unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
}
yield get_json_line(top_docs_dict)
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
try:
for response_dict in qa_model.answer_question_stream(
query, ranked_chunks[:NUM_RERANKED_RESULTS]
query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
):
if response_dict is None:
continue
@ -145,36 +163,6 @@ def stream_direct_qa(
except Exception:
# exception is logged in the answer_question method, no need to re-log
pass
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
@router.get("/keyword-search", response_model=KeywordResponse)
def keyword_search(
question: QAQuestion = Depends(), _: User = Depends(current_user)
) -> KeywordResponse:
ts_client = TSClient.get_instance()
query = question.query
collection = question.collection
logger.info(f"Received keyword query: {query}")
start_time = time.time()
search_results = ts_client.collections[collection].documents.search(
{
"q": query,
"query_by": CONTENT,
"per_page": KEYWORD_MAX_HITS,
"limit_hits": KEYWORD_MAX_HITS,
}
)
hits = search_results["hits"]
sources = [hit["document"][SOURCE_LINKS][0] for hit in hits]
total_time = time.time() - start_time
logger.info(f"Total Keyword Search took {total_time} seconds")
return KeywordResponse(results=sources)
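The endpoints above share a chunks_to_search_docs helper that this diff imports but does not show. A plausible sketch, inferred from the inline SearchDoc construction the commit removes:
from danswer.chunking.models import InferenceChunk
from danswer.server.models import SearchDoc
def chunks_to_search_docs(
    chunks: list[InferenceChunk] | None,
) -> list[SearchDoc]:
    # Tolerates None so callers can pass unranked_chunks straight through
    if not chunks:
        return []
    return [
        SearchDoc(
            semantic_identifier=chunk.semantic_identifier,
            link=chunk.source_links.get(0) if chunk.source_links else None,
            blurb=chunk.blurb,
            source_type=chunk.source_type,
        )
        for chunk in chunks
    ]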

View File

@ -1,8 +1,4 @@
from typing import Any
from typing import Optional
import typesense # type: ignore
from danswer.configs.app_configs import DB_CONN_TIMEOUT
from danswer.configs.app_configs import QDRANT_API_KEY
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT
@ -14,6 +10,7 @@ from qdrant_client import QdrantClient
_qdrant_client: QdrantClient | None = None
_typesense_client: typesense.Client | None = None
def get_qdrant_client() -> QdrantClient:
@ -29,35 +26,23 @@ def get_qdrant_client() -> QdrantClient:
return _qdrant_client
class TSClient:
__instance: Optional["TSClient"] = None
@staticmethod
def get_instance(
host: str = TYPESENSE_HOST,
port: int = TYPESENSE_PORT,
api_key: str = TYPESENSE_API_KEY,
timeout: int = DB_CONN_TIMEOUT,
) -> "TSClient":
if TSClient.__instance is None:
TSClient(host, port, api_key, timeout)
return TSClient.__instance # type: ignore
def __init__(self, host: str, port: int, api_key: str, timeout: int) -> None:
if TSClient.__instance is not None:
raise Exception(
"Singleton instance already exists. Use TSClient.get_instance() to get the instance."
)
else:
TSClient.__instance = self
self.client = typesense.Client(
def get_typesense_client() -> typesense.Client:
global _typesense_client
if _typesense_client is None:
if TYPESENSE_HOST and TYPESENSE_PORT and TYPESENSE_API_KEY:
_typesense_client = typesense.Client(
{
"api_key": api_key,
"nodes": [{"host": host, "port": str(port), "protocol": "http"}],
"connection_timeout_seconds": timeout,
"api_key": TYPESENSE_API_KEY,
"nodes": [
{
"host": TYPESENSE_HOST,
"port": str(TYPESENSE_PORT),
"protocol": "http",
}
],
}
)
else:
raise Exception("Unable to instantiate TypesenseClient")
# delegate all client operations to the third party client
def __getattr__(self, name: str) -> Any:
return getattr(self.client, name)
return _typesense_client
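With the TSClient singleton gone, call sites use the module-level client directly. For example, the keyword lookup from the removed endpoint would now read roughly as follows (the collection name and query string are illustrative):
from danswer.configs.constants import CONTENT
from danswer.utils.clients import get_typesense_client
ts_client = get_typesense_client()
search_results = ts_client.collections["danswer_index"].documents.search(
    {"q": "set up OAuth", "query_by": CONTENT, "per_page": 5}
)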

View File

@ -1,17 +1,17 @@
from collections.abc import Callable
from functools import partial
from itertools import chain
from typing import Any
from typing import Protocol
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.connectors.models import Document
from danswer.datastores.interfaces import Datastore
from danswer.datastores.qdrant.store import QdrantDatastore
from danswer.semantic_search.biencoder import DefaultEmbedder
from danswer.semantic_search.type_aliases import Embedder
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.search.semantic_search import DefaultEmbedder
from danswer.search.type_aliases import Embedder
class IndexingPipelineProtocol(Protocol):
@ -25,15 +25,18 @@ def _indexing_pipeline(
*,
chunker: Chunker,
embedder: Embedder,
datastore: Datastore,
vector_index: VectorIndex,
keyword_index: KeywordIndex,
documents: list[Document],
user_id: int | None,
) -> list[EmbeddedIndexChunk]:
# TODO: make entire indexing pipeline async to not block the entire process
# when running on async endpoints
chunks = list(chain(*[chunker.chunk(document) for document in documents]))
# TODO keyword indexing can occur at same time as embedding
keyword_index.index(chunks, user_id)
chunks_with_embeddings = embedder.embed(chunks)
datastore.index(chunks_with_embeddings, user_id)
vector_index.index(chunks_with_embeddings, user_id)
return chunks_with_embeddings
@ -41,7 +44,8 @@ def build_indexing_pipeline(
*,
chunker: Chunker | None = None,
embedder: Embedder | None = None,
datastore: Datastore | None = None,
vector_index: VectorIndex | None = None,
keyword_index: KeywordIndex | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipline which takes in a list of docs and indexes them.
@ -52,9 +56,16 @@ def build_indexing_pipeline(
if embedder is None:
embedder = DefaultEmbedder()
if datastore is None:
datastore = QdrantDatastore()
if vector_index is None:
vector_index = QdrantIndex()
if keyword_index is None:
keyword_index = TypesenseIndex()
return partial(
_indexing_pipeline, chunker=chunker, embedder=embedder, datastore=datastore
_indexing_pipeline,
chunker=chunker,
embedder=embedder,
vector_index=vector_index,
keyword_index=keyword_index,
)
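A usage sketch of the rebuilt pipeline, which now writes every chunk batch to both the keyword and vector indices; the import path is an assumption based on the surrounding module:
from danswer.connectors.models import Document
from danswer.datastores.indexing_pipeline import build_indexing_pipeline
# Defaults construct DefaultChunker, DefaultEmbedder, QdrantIndex, and
# TypesenseIndex; user_id=None marks the documents as public
indexing_pipeline = build_indexing_pipeline()
docs: list[Document] = []  # documents from any connector batch
embedded_chunks = indexing_pipeline(documents=docs, user_id=None)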

View File

@ -13,6 +13,7 @@ httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2
Mako==1.2.4
nltk==3.8.1
openai==0.27.6
playwright==1.32.1
psycopg2==2.9.6
@ -21,7 +22,7 @@ pydantic==1.10.7
PyGithub==1.58.2
PyPDF2==3.0.1
pytest-playwright==0.3.2
qdrant-client==1.1.0
qdrant-client==1.2.0
requests==2.28.2
rfc3986==1.5.0
sentence-transformers==2.2.2

View File

@ -0,0 +1,39 @@
# This file is purely for development use, not included in any builds
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
from danswer.utils.logging import setup_logger
from qdrant_client.http.models import Distance
from qdrant_client.http.models import VectorParams
from typesense.exceptions import ObjectNotFound # type: ignore
logger = setup_logger()
def recreate_qdrant_collection(
collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
logger.info(f"Attempting to recreate Qdrant collection {collection_name}")
result = get_qdrant_client().recreate_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
if not result:
raise RuntimeError("Could not create Qdrant collection")
def recreate_typesense_collection(collection_name: str) -> None:
logger.info(f"Attempting to recreate Typesense collection {collection_name}")
ts_client = get_typesense_client()
try:
ts_client.collections[collection_name].delete()
except ObjectNotFound:
logger.debug(f"Collection {collection_name} does not already exist")
create_typesense_collection(collection_name)
if __name__ == "__main__":
recreate_qdrant_collection("danswer_index")
recreate_typesense_collection("danswer_index")

View File

@ -1,13 +1,12 @@
# This file is purely for development use, not included in any builds
import argparse
import json
import urllib
from pprint import pprint
import requests
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import BLURB
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
@ -16,35 +15,44 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-k",
"--keyword-search",
action="store_true",
help="Use keyword search if set, semantic search otherwise",
"-f",
"--flow",
type=str,
default="QA",
help='"Search" or "QA", defaults to "QA"',
)
parser.add_argument(
"-t",
"--source-types",
"--type",
type=str,
help="Comma separated list of source types to filter by",
default="Semantic",
help='"Semantic" or "Keyword", defaults to "Semantic"',
)
parser.add_argument(
"-s",
"--stream",
action="store_true",
help="Enable streaming response",
help='Enable streaming response, only for flow="QA"',
)
parser.add_argument(
"--filters",
type=str,
help="Comma separated list of source types to filter by (no spaces)",
)
parser.add_argument("query", nargs="*", help="The query to process")
previous_input = None
while True:
try:
user_input = input(
"\n\nAsk any question:\n"
" - prefix with -t to add a filter by source type(s)\n"
" - prefix with -s to stream answer\n"
" - input an empty string to rerun last query\n\t"
" - Use -f (QA/Search) and -t (Semantic/Keyword) flags to set endpoint.\n"
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
" - input an empty string to rerun last query.\n\t"
)
if user_input:
@ -58,62 +66,51 @@ if __name__ == "__main__":
args = parser.parse_args(user_input.split())
keyword_search = args.keyword_search
source_types = args.source_types.split(",") if args.source_types else None
flow = str(args.flow).lower()
flow_type = str(args.type).lower()
stream = args.stream
source_types = args.filters.split(",") if args.filters else None
if source_types and len(source_types) == 1:
source_types = source_types[0]
query = " ".join(args.query)
endpoint = (
f"http://127.0.0.1:{APP_PORT}/direct-qa"
if not args.stream
else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
)
if args.keyword_search:
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
raise NotImplementedError("keyword search is not supported for now")
if flow not in ["qa", "search"]:
raise ValueError("Flow value must be QA or Search")
if flow_type not in ["keyword", "semantic"]:
raise ValueError("Type value must be keyword or semantic")
if flow != "qa" and stream:
raise ValueError("Can only stream results for QA")
if (flow, flow_type) == ("search", "keyword"):
path = "keyword-search"
elif (flow, flow_type) == ("search", "semantic"):
path = "semantic-search"
elif stream:
path = "stream-direct-qa"
else:
path = "direct-qa"
endpoint = f"http://127.0.0.1:{APP_PORT}/{path}"
query_json = {
"query": query,
"collection": QDRANT_DEFAULT_COLLECTION,
"filters": [{SOURCE_TYPE: source_types}],
"use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints
"filters": json.dumps([{SOURCE_TYPE: source_types}]),
}
if not args.stream:
response = requests.get(
endpoint, params=urllib.parse.urlencode(query_json)
)
contents = json.loads(response.content)
if keyword_search:
if contents["results"]:
for link in contents["results"]:
print(link)
else:
print("No matches found")
else:
answer = contents.get("answer")
if answer:
print("Answer: " + answer)
else:
print("Answer: ?")
if contents.get("quotes"):
for ind, (quote, quote_info) in enumerate(
contents["quotes"].items()
):
print(f"Quote {str(ind + 1)}:\n{quote}")
print(
f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
)
print(f"Blurb: {quote_info[BLURB]}")
print(f"Link: {quote_info[SOURCE_LINK]}")
print(f"Source: {quote_info[SOURCE_TYPE]}")
else:
print("No quotes found")
else:
if args.stream:
with requests.get(
endpoint, params=urllib.parse.urlencode(query_json), stream=True
) as r:
for json_response in r.iter_lines():
print(json.loads(json_response.decode()))
pprint(json.loads(json_response.decode()))
else:
response = requests.get(
endpoint, params=urllib.parse.urlencode(query_json)
)
contents = json.loads(response.content)
pprint(contents)
except Exception as e:
print(f"Failed due to {e}, retrying")

View File

@ -1,5 +1,2 @@
# For a local deployment, no additional setup is needed
# Refer to env.dev.template and env.prod.template for additional options
# Setting Auth to false for local setup convenience to avoid setting up a Google OAuth app in GCP.
DISABLE_AUTH=True
# This empty .env file is provided for compatibility with older Docker/Docker-Compose installations
# To change default values, check env.dev.template or env.prod.template

View File

@ -7,6 +7,7 @@ services:
depends_on:
- relational_db
- vector_db
- search_engine
restart: always
ports:
- "8080:8080"
@ -15,6 +16,9 @@ services:
environment:
- POSTGRES_HOST=relational_db
- QDRANT_HOST=vector_db
- TYPESENSE_HOST=search_engine
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
- DISABLE_AUTH=True
volumes:
- local_dynamic_storage:/home/storage
background:
@ -43,12 +47,13 @@ services:
- .env
environment:
- INTERNAL_URL=http://api_server:8080
- DISABLE_AUTH=True
relational_db:
image: postgres:15.2-alpine
restart: always
environment:
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password}
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
env_file:
- .env
ports:
@ -58,10 +63,24 @@ services:
vector_db:
image: qdrant/qdrant:v1.1.3
restart: always
env_file:
- .env
ports:
- "6333:6333"
volumes:
- qdrant_volume:/qdrant/storage
search_engine:
image: typesense/typesense:0.24.1
restart: always
environment:
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
- TYPESENSE_DATA_DIR=/typesense/storage
env_file:
- .env
ports:
- "8108:8108"
volumes:
- typesense_volume:/typesense/storage
nginx:
image: nginx:1.23.4-alpine
restart: always
@ -82,3 +101,4 @@ volumes:
local_dynamic_storage:
db_volume:
qdrant_volume:
typesense_volume:
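After docker compose up, a quick way to confirm the new search_engine service is reachable, assuming the dev-default API key from this file:
import typesense  # type: ignore
client = typesense.Client(
    {
        "api_key": "local_dev_typesense",  # dev default from this compose file
        "nodes": [{"host": "localhost", "port": "8108", "protocol": "http"}],
    }
)
# Returns an empty list until the first indexing run creates a collection
print(client.collections.retrieve())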

View File

@ -7,12 +7,15 @@ services:
depends_on:
- relational_db
- vector_db
- search_engine
restart: always
env_file:
- .env
environment:
- POSTGRES_HOST=relational_db
- QDRANT_HOST=vector_db
- TYPESENSE_HOST=search_engine
- TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
volumes:
- local_dynamic_storage:/home/storage
background:
@ -54,8 +57,22 @@ services:
vector_db:
image: qdrant/qdrant:v1.1.3
restart: always
env_file:
- .env
volumes:
- qdrant_volume:/qdrant/storage
search_engine:
image: typesense/typesense:0.24.1
restart: always
# TYPESENSE_API_KEY must be set in .env file
environment:
- TYPESENSE_DATA_DIR=/typesense/storage
env_file:
- .env
ports:
- "8108:8108"
volumes:
- typesense_volume:/typesense/storage
nginx:
image: nginx:1.23.4-alpine
restart: always
@ -83,3 +100,4 @@ volumes:
local_dynamic_storage:
db_volume:
qdrant_volume:
typesense_volume:

View File

@ -1,13 +1,8 @@
# Very basic .env file with options that are easy to change. Allows you to deploy everything on a single machine.
# We don't suggest using these settings for production.
# .env is not required unless you wish to change defaults
# Choose between "openai-chat-completion" and "openai-completion"
INTERNAL_MODEL_VERSION=openai-chat-completion
# Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
OPENAI_MODEL_VERSION=gpt-3.5-turbo
# Auth not necessary for local
DISABLE_AUTH=True

View File

@ -13,14 +13,8 @@ OPENAI_MODEL_VERSION=gpt-4
# Could be something like danswer.companyname.com. Requires additional setup if not localhost
WEB_DOMAIN=http://localhost:3000
# BACKEND DB can leave these as defaults
POSTGRES_USER=postgres
POSTGRES_PASSWORD=password
# AUTH CONFIGS
DISABLE_AUTH=False
# Required
TYPESENSE_API_KEY=
# Currently the frontend page doesn't have basic auth; use OAuth if user auth is enabled.
ENABLE_OAUTH=True

View File

@ -64,7 +64,7 @@ const searchRequestStreamed = async (
const url = new URL("/api/stream-direct-qa", window.location.origin);
const params = new URLSearchParams({
query,
collection: "semantic_search",
collection: "danswer_index",
}).toString();
url.search = params;