Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-23 12:31:30 +02:00)
DAN-81 Improve search round 2 (#82)

Includes:
- Multi vector indexing/search
- Ensemble model reranking
- Keyword Search backend
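Taken together, the change introduces two parallel document indexes (Qdrant for vector search, Typesense for keyword search) behind shared interfaces. A rough sketch of how the two new backends are meant to be queried side by side (hybrid_lookup is a hypothetical helper for illustration, not code from this commit):

from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex


def hybrid_lookup(query: str, user_id: int | None = None):
    # Semantic (vector) retrieval over multi-vector chunks stored in Qdrant
    semantic_chunks = QdrantIndex().semantic_retrieval(query, user_id, filters=None)
    # Keyword retrieval over the same chunks stored in Typesense
    keyword_chunks = TypesenseIndex().keyword_search(
        query, user_id, filters=None, num_to_retrieve=50
    )
    return semantic_chunks, keyword_chunks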
@@ -45,6 +45,8 @@ def create_indexing_jobs(db_session: Session) -> None:
         in_progress_indexing_attempts = get_incomplete_index_attempts(
             connector.id, db_session
         )
+        if in_progress_indexing_attempts:
+            logger.error("Found incomplete indexing attempts")

         # Currently single threaded so any still in-progress must have errored
         for attempt in in_progress_indexing_attempts:
@@ -113,7 +115,6 @@ def run_indexing_jobs(last_run_time: float, db_session: Session) -> None:

         document_ids: list[str] = []
         for doc_batch in doc_batch_generator:
-            # TODO introduce permissioning here
             index_user_id = (
                 None if db_credential.public_doc else db_credential.user_id
             )
@@ -1,4 +1,5 @@
 import inspect
+import json
 from dataclasses import dataclass
 from typing import Any
 from typing import cast
@@ -19,12 +20,14 @@ class BaseChunk:

 @dataclass
 class IndexChunk(BaseChunk):
+    # During indexing flow, we have access to a complete "Document"
+    # During inference we only have access to the document id and do not reconstruct the Document
     source_document: Document


 @dataclass
 class EmbeddedIndexChunk(IndexChunk):
-    embedding: list[float]
+    embeddings: list[list[float]]


 @dataclass
@@ -39,8 +42,13 @@ class InferenceChunk(BaseChunk):
             k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
         }
         if "source_links" in init_kwargs:
+            source_links = init_kwargs["source_links"]
+            source_links_dict = (
+                json.loads(source_links)
+                if isinstance(source_links, str)
+                else source_links
+            )
             init_kwargs["source_links"] = {
-                int(k): v
-                for k, v in cast(dict[str, str], init_kwargs["source_links"]).items()
+                int(k): v for k, v in cast(dict[str, str], source_links_dict).items()
             }
         return cls(**init_kwargs)
@@ -52,16 +52,23 @@ MASK_CREDENTIAL_PREFIX = (
 #####
 # DB Configs
 #####
-DEFAULT_VECTOR_STORE = os.environ.get("VECTOR_DB", "qdrant")
+# Qdrant is Semantic Search Vector DB
 # Url / Key are used to connect to a remote Qdrant instance
 QDRANT_URL = os.environ.get("QDRANT_URL", "")
 QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
 # Host / Port are used for connecting to local Qdrant instance
 QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
 QDRANT_PORT = 6333
-QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_COLLECTION", "semantic_search")
-DB_CONN_TIMEOUT = 2  # Timeout seconds connecting to DBs
-INDEX_BATCH_SIZE = 16  # File batches (not accounting file chunking)
+QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_DEFAULT_COLLECTION", "danswer_index")
+# Typesense is the Keyword Search Engine
+TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST", "localhost")
+TYPESENSE_PORT = 8108
+TYPESENSE_DEFAULT_COLLECTION = os.environ.get(
+    "TYPESENSE_DEFAULT_COLLECTION", "danswer_index"
+)
+TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
+# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
+INDEX_BATCH_SIZE = 16

 # below are intended to match the env variables names used by the official postgres docker image
 # https://hub.docker.com/_/postgres
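The new TYPESENSE_* settings mirror the existing Qdrant ones. For reference, a client wired from these values might look like the sketch below; the actual get_typesense_client helper lives in danswer.utils.clients and is not shown in this diff, so treat this as an assumption about its shape based on the typesense-python client API:

import typesense

from danswer.configs.app_configs import TYPESENSE_API_KEY
from danswer.configs.app_configs import TYPESENSE_HOST
from danswer.configs.app_configs import TYPESENSE_PORT


def build_typesense_client() -> typesense.Client:
    # Hypothetical helper; the real factory is not part of this diff.
    return typesense.Client(
        {
            "api_key": TYPESENSE_API_KEY,
            "nodes": [
                {"host": TYPESENSE_HOST, "port": str(TYPESENSE_PORT), "protocol": "http"}
            ],
            "connection_timeout_seconds": 2,
        }
    )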
@@ -81,13 +88,11 @@ GOOGLE_DRIVE_INCLUDE_SHARED = False
 #####
 # Query Configs
 #####
-DEFAULT_PROMPT = "generic-qa"
-NUM_RETURNED_HITS = 15
-NUM_RERANKED_RESULTS = 4
-KEYWORD_MAX_HITS = 5
-QUOTE_ALLOWED_ERROR_PERCENT = (
-    0.05  # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
-)
+NUM_RETURNED_HITS = 50
+NUM_RERANKED_RESULTS = 15
+NUM_GENERATIVE_AI_INPUT_DOCS = 5
+# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
+QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = 10  # 10 seconds


@@ -97,6 +102,11 @@ QA_TIMEOUT = 10  # 10 seconds
 # Chunking docs to this number of characters not including finishing the last word and the overlap words below
 # Calculated by ~500 to 512 tokens max * average 4 chars per token
 CHUNK_SIZE = 2000
+# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
+ENABLE_MINI_CHUNK = False
+# Mini chunks for fine-grained embedding, calculated as 128 tokens for 4 additional vectors for 512 chunk size above
+# Not rounded down to not lose any context in full chunk.
+MINI_CHUNK_SIZE = 512
 # Each chunk includes an additional 5 words from previous chunk
 # in extreme cases, may cause some words at the end to be truncated by embedding model
 CHUNK_OVERLAP = 5
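ENABLE_MINI_CHUNK is what turns one chunk into several vectors (the new embeddings: list[list[float]] field on EmbeddedIndexChunk). The splitting logic lives in the new search package and is not reproduced here; a simplified sketch of the idea, under the assumption that mini chunks are plain fixed-size character windows:

from danswer.configs.app_configs import MINI_CHUNK_SIZE


def split_into_mini_chunks(chunk_text: str, size: int = MINI_CHUNK_SIZE) -> list[str]:
    # Simplified: the real implementation may respect word boundaries.
    return [chunk_text[i : i + size] for i in range(0, len(chunk_text), size)]


# One full-chunk embedding plus one per mini chunk -> list[list[float]]
# embeddings = model.encode([chunk_text] + split_into_mini_chunks(chunk_text)).tolist()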
@@ -120,10 +130,6 @@ CROSS_ENCODER_PORT = 9000
 #####
 # Miscellaneous
 #####
-TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
-TYPESENSE_HOST = "localhost"
-TYPESENSE_PORT = 8108

 DYNAMIC_CONFIG_STORE = os.environ.get(
     "DYNAMIC_CONFIG_STORE", "FileSystemBackedDynamicConfigStore"
 )
@@ -5,13 +5,22 @@ import os
 # Models used must be MIT or Apache license
 # Inference/Indexing speed

-# Bi/Cross-Encoder Model Configs
+# https://www.sbert.net/docs/pretrained_models.html
 # Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
+# Context size is 256 for above though
 DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
-CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
 DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model

+# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
+# Previously using "cross-encoder/ms-marco-MiniLM-L-6-v2" alone
+CROSS_ENCODER_MODEL_ENSEMBLE = [
+    "cross-encoder/ms-marco-MiniLM-L-4-v2",
+    "cross-encoder/ms-marco-TinyBERT-L-2-v2",
+]
+
 QUERY_EMBEDDING_CONTEXT_SIZE = 256
+# The below is correlated with CHUNK_SIZE in app_configs but not strictly calculated
+# To avoid extra overhead of tokenizing for chunking during indexing.
 DOC_EMBEDDING_CONTEXT_SIZE = 512
 CROSS_EMBED_CONTEXT_SIZE = 512

@@ -79,92 +79,96 @@ class WebConnector(LoadConnector):
         if self.base_url[-1] != "/":
             visited_links.add(self.base_url + "/")

-        with sync_playwright() as playwright:
-            browser = playwright.chromium.launch(headless=True)
-            context = browser.new_context()
-
-            while to_visit:
-                current_url = to_visit.pop()
-                if current_url in visited_links:
-                    continue
-                visited_links.add(current_url)
-
-                try:
-                    if current_url.split(".")[-1] == "pdf":
-                        # PDF files are not checked for links
-                        response = requests.get(current_url)
-                        pdf_reader = PdfReader(io.BytesIO(response.content))
-                        page_text = ""
-                        for pdf_page in pdf_reader.pages:
-                            page_text += pdf_page.extract_text()
-
-                        doc_batch.append(
-                            Document(
-                                id=current_url,
-                                sections=[Section(link=current_url, text=page_text)],
-                                source=DocumentSource.WEB,
-                                semantic_identifier=current_url.split(".")[-1],
-                                metadata={},
-                            )
-                        )
-                        continue
-
-                    page = context.new_page()
-                    page.goto(current_url)
-                    content = page.content()
-                    soup = BeautifulSoup(content, "html.parser")
-
-                    internal_links = get_internal_links(
-                        self.base_url, current_url, soup
-                    )
-                    for link in internal_links:
-                        if link not in visited_links:
-                            to_visit.append(link)
-
-                    title_tag = soup.find("title")
-                    title = None
-                    if title_tag and title_tag.text:
-                        title = title_tag.text
-
-                    # Heuristics based cleaning
-                    for undesired_div in ["sidebar", "header", "footer"]:
-                        [
-                            tag.extract()
-                            for tag in soup.find_all(
-                                "div", class_=lambda x: x and undesired_div in x.split()
-                            )
-                        ]
-
-                    for undesired_tag in [
-                        "nav",
-                        "header",
-                        "footer",
-                        "meta",
-                        "script",
-                        "style",
-                    ]:
-                        [tag.extract() for tag in soup.find_all(undesired_tag)]
-
-                    page_text = soup.get_text(HTML_SEPARATOR)
-
-                    doc_batch.append(
-                        Document(
-                            id=current_url,
-                            sections=[Section(link=current_url, text=page_text)],
-                            source=DocumentSource.WEB,
-                            semantic_identifier=title,
-                            metadata={},
-                        )
-                    )
-
-                    page.close()
-                except Exception as e:
-                    logger.error(f"Failed to fetch '{current_url}': {e}")
-                    continue
-
-                if len(doc_batch) >= self.batch_size:
-                    yield doc_batch
-                    doc_batch = []
-
-            if doc_batch:
-                yield doc_batch
+        restart_playwright = True
+        while to_visit:
+            current_url = to_visit.pop()
+            if current_url in visited_links:
+                continue
+            visited_links.add(current_url)
+
+            try:
+                if restart_playwright:
+                    playwright = sync_playwright().start()
+                    browser = playwright.chromium.launch(headless=True)
+                    context = browser.new_context()
+                    restart_playwright = False
+
+                if current_url.split(".")[-1] == "pdf":
+                    # PDF files are not checked for links
+                    response = requests.get(current_url)
+                    pdf_reader = PdfReader(io.BytesIO(response.content))
+                    page_text = ""
+                    for pdf_page in pdf_reader.pages:
+                        page_text += pdf_page.extract_text()
+
+                    doc_batch.append(
+                        Document(
+                            id=current_url,
+                            sections=[Section(link=current_url, text=page_text)],
+                            source=DocumentSource.WEB,
+                            semantic_identifier=current_url.split(".")[-1],
+                            metadata={},
+                        )
+                    )
+                    continue
+
+                page = context.new_page()
+                page.goto(current_url)
+                content = page.content()
+                soup = BeautifulSoup(content, "html.parser")
+
+                internal_links = get_internal_links(self.base_url, current_url, soup)
+                for link in internal_links:
+                    if link not in visited_links:
+                        to_visit.append(link)
+
+                title_tag = soup.find("title")
+                title = None
+                if title_tag and title_tag.text:
+                    title = title_tag.text
+
+                # Heuristics based cleaning
+                for undesired_div in ["sidebar", "header", "footer"]:
+                    [
+                        tag.extract()
+                        for tag in soup.find_all(
+                            "div", class_=lambda x: x and undesired_div in x.split()
+                        )
+                    ]
+
+                for undesired_tag in [
+                    "nav",
+                    "header",
+                    "footer",
+                    "meta",
+                    "script",
+                    "style",
+                ]:
+                    [tag.extract() for tag in soup.find_all(undesired_tag)]
+
+                page_text = soup.get_text(HTML_SEPARATOR)
+
+                doc_batch.append(
+                    Document(
+                        id=current_url,
+                        sections=[Section(link=current_url, text=page_text)],
+                        source=DocumentSource.WEB,
+                        semantic_identifier=title,
+                        metadata={},
+                    )
+                )
+
+                page.close()
+            except Exception as e:
+                logger.error(f"Failed to fetch '{current_url}': {e}")
+                continue
+
+            if len(doc_batch) >= self.batch_size:
+                playwright.stop()
+                restart_playwright = True
+                yield doc_batch
+                doc_batch = []
+
+        if doc_batch:
+            playwright.stop()
+            yield doc_batch
@@ -1,25 +0,0 @@
-from typing import Type
-
-from danswer.configs.app_configs import DEFAULT_VECTOR_STORE
-from danswer.datastores.interfaces import Datastore
-from danswer.datastores.qdrant.store import QdrantDatastore
-
-
-def get_selected_datastore_cls(
-    vector_db_type: str = DEFAULT_VECTOR_STORE,
-) -> Type[Datastore]:
-    """Returns the selected Datastore cls. Only one datastore
-    should be selected for a specific deployment."""
-    if vector_db_type == "qdrant":
-        return QdrantDatastore
-    else:
-        raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
-
-
-def create_datastore(
-    collection: str, vector_db_type: str = DEFAULT_VECTOR_STORE
-) -> Datastore:
-    if vector_db_type == "qdrant":
-        return QdrantDatastore(collection=collection)
-    else:
-        raise ValueError(f"Invalid Vector DB setting: {vector_db_type}")
backend/danswer/datastores/datastore_utils.py (new file, 72 lines)
@@ -0,0 +1,72 @@
+import uuid
+from collections.abc import Callable
+from copy import deepcopy
+
+from danswer.chunking.models import EmbeddedIndexChunk
+from danswer.chunking.models import IndexChunk
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import ALLOWED_GROUPS
+from danswer.configs.constants import ALLOWED_USERS
+
+
+DEFAULT_BATCH_SIZE = 30
+
+
+def get_uuid_from_chunk(
+    chunk: IndexChunk | EmbeddedIndexChunk | InferenceChunk, mini_chunk_ind: int = 0
+) -> uuid.UUID:
+    doc_str = (
+        chunk.document_id
+        if isinstance(chunk, InferenceChunk)
+        else chunk.source_document.id
+    )
+    # Web parsing URL duplicate catching
+    if doc_str and doc_str[-1] == "/":
+        doc_str = doc_str[:-1]
+    unique_identifier_string = "_".join(
+        [doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
+    )
+    return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
+
+
+# Takes the chunk identifier returns whether the chunk exists and the user/group whitelists
+WhitelistCallable = Callable[[str], tuple[bool, list[str], list[str]]]
+
+
+def update_doc_user_map(
+    chunk: IndexChunk | EmbeddedIndexChunk,
+    doc_whitelist_map: dict[str, dict[str, list[str]]],
+    doc_store_whitelist_fnc: WhitelistCallable,
+    user_str: str,
+) -> tuple[dict[str, dict[str, list[str]]], bool]:
+    """Returns an updated document id to whitelists mapping and if the document's chunks need to be wiped."""
+    doc_whitelist_map = deepcopy(doc_whitelist_map)
+    first_chunk_uuid = str(get_uuid_from_chunk(chunk))
+    document = chunk.source_document
+    if document.id not in doc_whitelist_map:
+        first_chunk_found, whitelist_users, whitelist_groups = doc_store_whitelist_fnc(
+            first_chunk_uuid
+        )
+
+        if not first_chunk_found:
+            doc_whitelist_map[document.id] = {
+                ALLOWED_USERS: [user_str],
+                # TODO introduce groups logic here
+                ALLOWED_GROUPS: whitelist_groups,
+            }
+            # First chunk does not exist so document does not exist, no need for deletion
+            return doc_whitelist_map, False
+        else:
+            if user_str not in whitelist_users:
+                whitelist_users.append(user_str)
+            # TODO introduce groups logic here
+            doc_whitelist_map[document.id] = {
+                ALLOWED_USERS: whitelist_users,
+                ALLOWED_GROUPS: whitelist_groups,
+            }
+            # First chunk exists, but with update, there may be less total chunks now
+            # Must delete rest of document chunks
+            return doc_whitelist_map, True
+
+    # If document is already in the mapping, don't delete again
+    return doc_whitelist_map, False
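get_uuid_from_chunk gives every (document, chunk, mini chunk) triple a stable ID, which is what makes re-indexing an upsert rather than a duplicate insert. Roughly (the document id below is made up for illustration):

import uuid

# uuid5 is deterministic: the same inputs always produce the same point ID,
# so re-indexing an unchanged document overwrites its existing points.
print(uuid.uuid5(uuid.NAMESPACE_X500, "https://docs.example.com/page_0_0"))
print(uuid.uuid5(uuid.NAMESPACE_X500, "https://docs.example.com/page_0_0"))  # identical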
@@ -1,22 +1,42 @@
 import abc
+from typing import Generic
+from typing import TypeVar

+from danswer.chunking.models import BaseChunk
 from danswer.chunking.models import EmbeddedIndexChunk
+from danswer.chunking.models import IndexChunk
 from danswer.chunking.models import InferenceChunk

-DatastoreFilter = dict[str, str | list[str] | None]
+T = TypeVar("T", bound=BaseChunk)
+IndexFilter = dict[str, str | list[str] | None]


-class Datastore:
+class DocumentIndex(Generic[T], abc.ABC):
     @abc.abstractmethod
-    def index(self, chunks: list[EmbeddedIndexChunk], user_id: int | None) -> bool:
+    def index(self, chunks: list[T], user_id: int | None) -> bool:
         raise NotImplementedError


+class VectorIndex(DocumentIndex[EmbeddedIndexChunk], abc.ABC):
     @abc.abstractmethod
     def semantic_retrieval(
         self,
         query: str,
         user_id: int | None,
-        filters: list[DatastoreFilter] | None,
+        filters: list[IndexFilter] | None,
+        num_to_retrieve: int,
+    ) -> list[InferenceChunk]:
+        raise NotImplementedError
+
+
+class KeywordIndex(DocumentIndex[IndexChunk], abc.ABC):
+    @abc.abstractmethod
+    def keyword_search(
+        self,
+        query: str,
+        user_id: int | None,
+        filters: list[IndexFilter] | None,
         num_to_retrieve: int,
     ) -> list[InferenceChunk]:
         raise NotImplementedError
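The Generic[T] split encodes which chunk type each backend accepts: a KeywordIndex can take chunks that were never embedded, while a VectorIndex requires EmbeddedIndexChunk. A new backend would subclass one of these; an illustrative stub (not part of this commit):

from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex


class InMemoryKeywordIndex(KeywordIndex):
    """Hypothetical example backend used only to show the interface shape."""

    def __init__(self) -> None:
        self.chunks: list[IndexChunk] = []

    def index(self, chunks: list[IndexChunk], user_id: int | None) -> bool:
        self.chunks.extend(chunks)
        return True

    def keyword_search(
        self,
        query: str,
        user_id: int | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        raise NotImplementedError  # real matching logic omitted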
@@ -1,4 +1,4 @@
-import uuid
+from functools import partial

 from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.configs.constants import ALLOWED_GROUPS
@@ -13,9 +13,11 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import DOC_EMBEDDING_DIM
+from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
+from danswer.datastores.datastore_utils import get_uuid_from_chunk
+from danswer.datastores.datastore_utils import update_doc_user_map
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
-from danswer.utils.timing import log_function_time
 from qdrant_client import QdrantClient
 from qdrant_client.http import models
 from qdrant_client.http.exceptions import ResponseHandlingException
@@ -26,16 +28,15 @@ from qdrant_client.models import Distance
 from qdrant_client.models import PointStruct
 from qdrant_client.models import VectorParams


 logger = setup_logger()

-DEFAULT_BATCH_SIZE = 30
-

-def list_collections() -> CollectionsResponse:
+def list_qdrant_collections() -> CollectionsResponse:
     return get_qdrant_client().get_collections()


-def create_collection(
+def create_qdrant_collection(
     collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
 ) -> None:
     logger.info(f"Attempting to create collection {collection_name}")
@@ -47,25 +48,25 @@ def create_collection(
     raise RuntimeError("Could not create Qdrant collection")


-def get_document_whitelists(
+def get_qdrant_document_whitelists(
     doc_chunk_id: str, collection_name: str, q_client: QdrantClient
-) -> tuple[int, list[str], list[str]]:
+) -> tuple[bool, list[str], list[str]]:
     results = q_client.retrieve(
         collection_name=collection_name,
         ids=[doc_chunk_id],
         with_payload=[ALLOWED_USERS, ALLOWED_GROUPS],
     )
     if len(results) == 0:
-        return 0, [], []
+        return False, [], []
     payload = results[0].payload
     if not payload:
         raise RuntimeError(
             "Qdrant Index is corrupted, Document found with no access lists."
         )
-    return len(results), payload[ALLOWED_USERS], payload[ALLOWED_GROUPS]
+    return True, payload[ALLOWED_USERS], payload[ALLOWED_GROUPS]


-def delete_doc_chunks(
+def delete_qdrant_doc_chunks(
     document_id: str, collection_name: str, q_client: QdrantClient
 ) -> None:
     q_client.delete(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def recreate_collection(
|
def index_qdrant_chunks(
|
||||||
collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
|
|
||||||
) -> None:
|
|
||||||
logger.info(f"Attempting to recreate collection {collection_name}")
|
|
||||||
result = get_qdrant_client().recreate_collection(
|
|
||||||
collection_name=collection_name,
|
|
||||||
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
|
|
||||||
)
|
|
||||||
if not result:
|
|
||||||
raise RuntimeError("Could not create Qdrant collection")
|
|
||||||
|
|
||||||
|
|
||||||
def get_uuid_from_chunk(chunk: EmbeddedIndexChunk) -> uuid.UUID:
|
|
||||||
unique_identifier_string = "_".join([chunk.source_document.id, str(chunk.chunk_id)])
|
|
||||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
|
||||||
|
|
||||||
|
|
||||||
def index_chunks(
|
|
||||||
chunks: list[EmbeddedIndexChunk],
|
chunks: list[EmbeddedIndexChunk],
|
||||||
user_id: int | None,
|
user_id: int | None,
|
||||||
collection: str,
|
collection: str,
|
||||||
@@ -112,51 +96,45 @@ def index_chunks(
|
|||||||
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
|
user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
|
||||||
q_client: QdrantClient = client if client else get_qdrant_client()
|
q_client: QdrantClient = client if client else get_qdrant_client()
|
||||||
|
|
||||||
point_structs = []
|
point_structs: list[PointStruct] = []
|
||||||
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
|
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
|
||||||
doc_user_map: dict[str, dict[str, list[str]]] = {}
|
doc_user_map: dict[str, dict[str, list[str]]] = {}
|
||||||
for chunk in chunks:
|
for chunk in chunks:
|
||||||
chunk_uuid = str(get_uuid_from_chunk(chunk))
|
|
||||||
document = chunk.source_document
|
document = chunk.source_document
|
||||||
|
doc_user_map, delete_doc = update_doc_user_map(
|
||||||
|
chunk,
|
||||||
|
doc_user_map,
|
||||||
|
partial(
|
||||||
|
get_qdrant_document_whitelists,
|
||||||
|
collection_name=collection,
|
||||||
|
q_client=q_client,
|
||||||
|
),
|
||||||
|
user_str,
|
||||||
|
)
|
||||||
|
|
||||||
if document.id not in doc_user_map:
|
if delete_doc:
|
||||||
num_doc_chunks, whitelist_users, whitelist_groups = get_document_whitelists(
|
delete_qdrant_doc_chunks(document.id, collection, q_client)
|
||||||
chunk_uuid, collection, q_client
|
|
||||||
)
|
|
||||||
if num_doc_chunks == 0:
|
|
||||||
doc_user_map[document.id] = {
|
|
||||||
ALLOWED_USERS: [user_str],
|
|
||||||
# TODO introduce groups logic here
|
|
||||||
ALLOWED_GROUPS: whitelist_groups,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
if user_str not in whitelist_users:
|
|
||||||
whitelist_users.append(user_str)
|
|
||||||
# TODO introduce groups logic here
|
|
||||||
doc_user_map[document.id] = {
|
|
||||||
ALLOWED_USERS: whitelist_users,
|
|
||||||
ALLOWED_GROUPS: whitelist_groups,
|
|
||||||
}
|
|
||||||
# Need to delete document chunks because number of chunks may decrease
|
|
||||||
delete_doc_chunks(document.id, collection, q_client)
|
|
||||||
|
|
||||||
point_structs.append(
|
point_structs.extend(
|
||||||
PointStruct(
|
[
|
||||||
id=chunk_uuid,
|
PointStruct(
|
||||||
payload={
|
id=str(get_uuid_from_chunk(chunk, minichunk_ind)),
|
||||||
DOCUMENT_ID: document.id,
|
payload={
|
||||||
CHUNK_ID: chunk.chunk_id,
|
DOCUMENT_ID: document.id,
|
||||||
BLURB: chunk.blurb,
|
CHUNK_ID: chunk.chunk_id,
|
||||||
CONTENT: chunk.content,
|
BLURB: chunk.blurb,
|
||||||
SOURCE_TYPE: str(document.source.value),
|
CONTENT: chunk.content,
|
||||||
SOURCE_LINKS: chunk.source_links,
|
SOURCE_TYPE: str(document.source.value),
|
||||||
SEMANTIC_IDENTIFIER: document.semantic_identifier,
|
SOURCE_LINKS: chunk.source_links,
|
||||||
SECTION_CONTINUATION: chunk.section_continuation,
|
SEMANTIC_IDENTIFIER: document.semantic_identifier,
|
||||||
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
|
SECTION_CONTINUATION: chunk.section_continuation,
|
||||||
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
|
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
|
||||||
},
|
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
|
||||||
vector=chunk.embedding,
|
},
|
||||||
)
|
vector=embedding,
|
||||||
|
)
|
||||||
|
for minichunk_ind, embedding in enumerate(chunk.embeddings)
|
||||||
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
index_results = None
|
index_results = None
|
||||||
@@ -182,12 +160,14 @@ def index_chunks(
|
|||||||
index_results = upsert()
|
index_results = upsert()
|
||||||
log_status = index_results.status if index_results else "Failed"
|
log_status = index_results.status if index_results else "Failed"
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Indexed {len(point_struct_batch)} chunks into collection '{collection}', "
|
f"Indexed {len(point_struct_batch)} chunks into Qdrant collection '{collection}', "
|
||||||
f"status: {log_status}"
|
f"status: {log_status}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
index_results = q_client.upsert(
|
index_results = q_client.upsert(
|
||||||
collection_name=collection, points=point_structs
|
collection_name=collection, points=point_structs
|
||||||
)
|
)
|
||||||
logger.info(f"Batch indexing status: {index_results.status}")
|
logger.info(
|
||||||
|
f"Document batch of size {len(point_structs)} indexing status: {index_results.status}"
|
||||||
|
)
|
||||||
return index_results is not None and index_results.status == UpdateStatus.COMPLETED
|
return index_results is not None and index_results.status == UpdateStatus.COMPLETED
|
||||||
|
@@ -1,12 +1,17 @@
+import uuid
+
 from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_RERANKED_RESULTS
+from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.configs.constants import ALLOWED_USERS
 from danswer.configs.constants import PUBLIC_DOC_PAT
-from danswer.datastores.interfaces import Datastore
-from danswer.datastores.interfaces import DatastoreFilter
-from danswer.datastores.qdrant.indexing import index_chunks
-from danswer.semantic_search.semantic_search import get_default_embedding_model
+from danswer.datastores.datastore_utils import get_uuid_from_chunk
+from danswer.datastores.interfaces import IndexFilter
+from danswer.datastores.interfaces import VectorIndex
+from danswer.datastores.qdrant.indexing import index_qdrant_chunks
+from danswer.search.semantic_search import get_default_embedding_model
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
 from danswer.utils.timing import log_function_time
@@ -20,13 +25,60 @@ from qdrant_client.http.models import MatchValue
 logger = setup_logger()


-class QdrantDatastore(Datastore):
+def _build_qdrant_filters(
+    user_id: int | None, filters: list[IndexFilter] | None
+) -> list[FieldCondition]:
+    filter_conditions: list[FieldCondition] = []
+    # Permissions filter
+    if user_id:
+        filter_conditions.append(
+            FieldCondition(
+                key=ALLOWED_USERS,
+                match=MatchAny(any=[str(user_id), PUBLIC_DOC_PAT]),
+            )
+        )
+    else:
+        filter_conditions.append(
+            FieldCondition(
+                key=ALLOWED_USERS,
+                match=MatchValue(value=PUBLIC_DOC_PAT),
+            )
+        )
+
+    # Provided query filters
+    if filters:
+        for filter_dict in filters:
+            valid_filters = {
+                key: value for key, value in filter_dict.items() if value is not None
+            }
+            for filter_key, filter_val in valid_filters.items():
+                if isinstance(filter_val, str):
+                    filter_conditions.append(
+                        FieldCondition(
+                            key=filter_key,
+                            match=MatchValue(value=filter_val),
+                        )
+                    )
+                elif isinstance(filter_val, list):
+                    filter_conditions.append(
+                        FieldCondition(
+                            key=filter_key,
+                            match=MatchAny(any=filter_val),
+                        )
+                    )
+                else:
+                    raise ValueError("Invalid filters provided")
+
+    return filter_conditions
+
+
+class QdrantIndex(VectorIndex):
     def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None:
         self.collection = collection
         self.client = get_qdrant_client()

     def index(self, chunks: list[EmbeddedIndexChunk], user_id: int | None) -> bool:
-        return index_chunks(
+        return index_qdrant_chunks(
             chunks=chunks,
             user_id=user_id,
             collection=self.collection,
@@ -38,8 +90,9 @@ class QdrantDatastore(Datastore):
         self,
         query: str,
         user_id: int | None,
-        filters: list[DatastoreFilter] | None,
-        num_to_retrieve: int,
+        filters: list[IndexFilter] | None,
+        num_to_retrieve: int = NUM_RETURNED_HITS,
+        page_size: int = NUM_RERANKED_RESULTS,
     ) -> list[InferenceChunk]:
         query_embedding = get_default_embedding_model().encode(
             query
@@ -47,68 +100,47 @@ class QdrantDatastore(Datastore):
         if not isinstance(query_embedding, list):
             query_embedding = query_embedding.tolist()

-        hits = []
-        filter_conditions = []
-        try:
-            # Permissions filter
-            if user_id:
-                filter_conditions.append(
-                    FieldCondition(
-                        key=ALLOWED_USERS,
-                        match=MatchAny(any=[str(user_id), PUBLIC_DOC_PAT]),
-                    )
-                )
-            else:
-                filter_conditions.append(
-                    FieldCondition(
-                        key=ALLOWED_USERS,
-                        match=MatchValue(value=PUBLIC_DOC_PAT),
-                    )
-                )
-
-            # Provided query filters
-            if filters:
-                for filter_dict in filters:
-                    valid_filters = {
-                        key: value
-                        for key, value in filter_dict.items()
-                        if value is not None
-                    }
-                    for filter_key, filter_val in valid_filters.items():
-                        if isinstance(filter_val, str):
-                            filter_conditions.append(
-                                FieldCondition(
-                                    key=filter_key,
-                                    match=MatchValue(value=filter_val),
-                                )
-                            )
-                        elif isinstance(filter_val, list):
-                            filter_conditions.append(
-                                FieldCondition(
-                                    key=filter_key,
-                                    match=MatchAny(any=filter_val),
-                                )
-                            )
-                        else:
-                            raise ValueError("Invalid filters provided")
-
-            hits = self.client.search(
-                collection_name=self.collection,
-                query_vector=query_embedding,
-                query_filter=Filter(must=list(filter_conditions)),
-                limit=num_to_retrieve,
-            )
-        except ResponseHandlingException as e:
-            logger.exception(f'Qdrant querying failed due to: "{e}", is Qdrant set up?')
-        except UnexpectedResponse as e:
-            logger.exception(
-                f'Qdrant querying failed due to: "{e}", has ingestion been run?'
-            )
-        return [
-            InferenceChunk.from_dict(hit.payload)
-            for hit in hits
-            if hit.payload is not None
-        ]
+        filter_conditions = _build_qdrant_filters(user_id, filters)
+
+        page_offset = 0
+        found_inference_chunks: list[InferenceChunk] = []
+        found_chunk_uuids: set[uuid.UUID] = set()
+        while len(found_inference_chunks) < num_to_retrieve:
+            try:
+                hits = self.client.search(
+                    collection_name=self.collection,
+                    query_vector=query_embedding,
+                    query_filter=Filter(must=list(filter_conditions)),
+                    limit=page_size,
+                    offset=page_offset,
+                )
+                page_offset += page_size
+                if not hits:
+                    break
+            except ResponseHandlingException as e:
+                logger.exception(
+                    f'Qdrant querying failed due to: "{e}", is Qdrant set up?'
+                )
+                break
+            except UnexpectedResponse as e:
+                logger.exception(
+                    f'Qdrant querying failed due to: "{e}", has ingestion been run?'
+                )
+                break
+
+            inference_chunks_from_hits = [
+                InferenceChunk.from_dict(hit.payload)
+                for hit in hits
+                if hit.payload is not None
+            ]
+            for inf_chunk in inference_chunks_from_hits:
+                # remove duplicate chunks which happen if minichunks are used
+                inf_chunk_id = get_uuid_from_chunk(inf_chunk)
+                if inf_chunk_id not in found_chunk_uuids:
+                    found_inference_chunks.append(inf_chunk)
+                    found_chunk_uuids.add(inf_chunk_id)
+
+        return found_inference_chunks

     def get_from_id(self, object_id: str) -> InferenceChunk | None:
         matches, _ = self.client.scroll(
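Because each hit may be a mini-chunk vector of the same underlying chunk, semantic_retrieval now pages through results (page_size per request) and dedupes by chunk UUID until num_to_retrieve distinct chunks are collected. A typical call, with the config defaults spelled out (the query string is just an example):

from danswer.datastores.qdrant.store import QdrantIndex

chunks = QdrantIndex().semantic_retrieval(
    query="how do I configure the web connector?",
    user_id=None,        # only PUBLIC documents are matched without a user
    filters=None,
    num_to_retrieve=50,  # NUM_RETURNED_HITS
)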
backend/danswer/datastores/typesense/store.py (new file, 238 lines)
@@ -0,0 +1,238 @@
+import json
+from functools import partial
+from typing import Any
+
+import typesense  # type: ignore
+from danswer.chunking.models import EmbeddedIndexChunk
+from danswer.chunking.models import IndexChunk
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
+from danswer.configs.constants import ALLOWED_GROUPS
+from danswer.configs.constants import ALLOWED_USERS
+from danswer.configs.constants import BLURB
+from danswer.configs.constants import CHUNK_ID
+from danswer.configs.constants import CONTENT
+from danswer.configs.constants import DOCUMENT_ID
+from danswer.configs.constants import PUBLIC_DOC_PAT
+from danswer.configs.constants import SECTION_CONTINUATION
+from danswer.configs.constants import SEMANTIC_IDENTIFIER
+from danswer.configs.constants import SOURCE_LINKS
+from danswer.configs.constants import SOURCE_TYPE
+from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
+from danswer.datastores.datastore_utils import get_uuid_from_chunk
+from danswer.datastores.datastore_utils import update_doc_user_map
+from danswer.datastores.interfaces import IndexFilter
+from danswer.datastores.interfaces import KeywordIndex
+from danswer.utils.clients import get_typesense_client
+from danswer.utils.logging import setup_logger
+from typesense.exceptions import ObjectNotFound  # type: ignore
+
+
+logger = setup_logger()
+
+
+def check_typesense_collection_exist(
+    collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
+) -> bool:
+    client = get_typesense_client()
+    try:
+        client.collections[collection_name].retrieve()
+    except ObjectNotFound:
+        return False
+    return True
+
+
+def create_typesense_collection(
+    collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
+) -> None:
+    ts_client = get_typesense_client()
+    collection_schema = {
+        "name": collection_name,
+        "fields": [
+            # Typesense uses "id" type string as a special field
+            {"name": "id", "type": "string"},
+            {"name": DOCUMENT_ID, "type": "string"},
+            {"name": CHUNK_ID, "type": "int32"},
+            {"name": BLURB, "type": "string"},
+            {"name": CONTENT, "type": "string"},
+            {"name": SOURCE_TYPE, "type": "string"},
+            {"name": SOURCE_LINKS, "type": "string"},
+            {"name": SEMANTIC_IDENTIFIER, "type": "string"},
+            {"name": SECTION_CONTINUATION, "type": "bool"},
+            {"name": ALLOWED_USERS, "type": "string[]"},
+            {"name": ALLOWED_GROUPS, "type": "string[]"},
+        ],
+    }
+    ts_client.collections.create(collection_schema)
+
+
+def get_typesense_document_whitelists(
+    doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
+) -> tuple[bool, list[str], list[str]]:
+    """Returns whether the document already exists and the users/group whitelists"""
+    try:
+        document = (
+            ts_client.collections[collection_name].documents[doc_chunk_id].retrieve()
+        )
+    except ObjectNotFound:
+        return False, [], []
+    if document[ALLOWED_USERS] is None or document[ALLOWED_GROUPS] is None:
+        raise RuntimeError(
+            "Typesense Index is corrupted, Document found with no access lists."
+        )
+    return True, document[ALLOWED_USERS], document[ALLOWED_GROUPS]
+
+
+def delete_typesense_doc_chunks(
+    document_id: str, collection_name: str, ts_client: typesense.Client
+) -> None:
+    search_parameters = {
+        "q": document_id,
+        "query_by": DOCUMENT_ID,
+    }
+
+    # TODO consider race conditions if running multiple processes/threads
+    hits = ts_client.collections[collection_name].documents.search(search_parameters)
+    [
+        ts_client.collections[collection_name].documents[hit["document"]["id"]].delete()
+        for hit in hits["hits"]
+    ]
+
+
+def index_typesense_chunks(
+    chunks: list[IndexChunk | EmbeddedIndexChunk],
+    user_id: int | None,
+    collection: str,
+    client: typesense.Client | None = None,
+    batch_upsert: bool = True,
+) -> bool:
+    user_str = PUBLIC_DOC_PAT if user_id is None else str(user_id)
+    ts_client: typesense.Client = client if client else get_typesense_client()
+
+    new_documents: list[dict[str, Any]] = []
+    doc_user_map: dict[str, dict[str, list[str]]] = {}
+    for chunk in chunks:
+        document = chunk.source_document
+        doc_user_map, delete_doc = update_doc_user_map(
+            chunk,
+            doc_user_map,
+            partial(
+                get_typesense_document_whitelists,
+                collection_name=collection,
+                ts_client=ts_client,
+            ),
+            user_str,
+        )
+
+        if delete_doc:
+            delete_typesense_doc_chunks(document.id, collection, ts_client)
+
+        new_documents.append(
+            {
+                "id": str(get_uuid_from_chunk(chunk)),  # No minichunks for typesense
+                DOCUMENT_ID: document.id,
+                CHUNK_ID: chunk.chunk_id,
+                BLURB: chunk.blurb,
+                CONTENT: chunk.content,
+                SOURCE_TYPE: str(document.source.value),
+                SOURCE_LINKS: json.dumps(chunk.source_links),
+                SEMANTIC_IDENTIFIER: document.semantic_identifier,
+                SECTION_CONTINUATION: chunk.section_continuation,
+                ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
+                ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
+            }
+        )
+
+    if batch_upsert:
+        doc_batches = [
+            new_documents[x : x + DEFAULT_BATCH_SIZE]
+            for x in range(0, len(new_documents), DEFAULT_BATCH_SIZE)
+        ]
+        for doc_batch in doc_batches:
+            results = ts_client.collections[collection].documents.import_(
+                doc_batch, {"action": "upsert"}
+            )
+            failures = [
+                doc_res["success"]
+                for doc_res in results
+                if doc_res["success"] is not True
+            ]
+            logger.info(
+                f"Indexed {len(doc_batch)} chunks into Typesense collection '{collection}', "
+                f"number failed: {len(failures)}"
+            )
+    else:
+        [
+            ts_client.collections[collection].documents.upsert(document)
+            for document in new_documents
+        ]
+
+    return True
+
+
+def _build_typesense_filters(
+    user_id: int | None, filters: list[IndexFilter] | None
+) -> str:
+    filter_str = ""
+
+    # Permissions filter
+    if user_id:
+        filter_str += f"{ALLOWED_USERS}:=[{PUBLIC_DOC_PAT}|{user_id}] && "
+    else:
+        filter_str += f"{ALLOWED_USERS}:={PUBLIC_DOC_PAT} && "
+
+    # Provided query filters
+    if filters:
+        for filter_dict in filters:
+            valid_filters = {
+                key: value for key, value in filter_dict.items() if value is not None
+            }
+            for filter_key, filter_val in valid_filters.items():
+                if isinstance(filter_val, str):
+                    filter_str += f"{filter_key}:={filter_val} && "
+                elif isinstance(filter_val, list):
+                    filters_or = ",".join([str(f_val) for f_val in filter_val])
+                    filter_str += f"{filter_key}:=[{filters_or}] && "
+                else:
+                    raise ValueError("Invalid filters provided")
+    if filter_str[-4:] == " && ":
+        filter_str = filter_str[:-4]
+    return filter_str
+
+
+class TypesenseIndex(KeywordIndex):
+    def __init__(self, collection: str = TYPESENSE_DEFAULT_COLLECTION) -> None:
+        self.collection = collection
+        self.ts_client = get_typesense_client()
+
+    def index(self, chunks: list[IndexChunk], user_id: int | None) -> bool:
+        return index_typesense_chunks(
+            chunks=chunks,
+            user_id=user_id,
+            collection=self.collection,
+            client=self.ts_client,
+        )
+
+    def keyword_search(
+        self,
+        query: str,
+        user_id: int | None,
+        filters: list[IndexFilter] | None,
+        num_to_retrieve: int,
+    ) -> list[InferenceChunk]:
+        filters_str = _build_typesense_filters(user_id, filters)
+
+        search_results = self.ts_client.collections[self.collection].documents.search(
+            {
+                "q": query,
+                "query_by": CONTENT,
+                "filter_by": filters_str,
+                "per_page": num_to_retrieve,
+                "limit_hits": num_to_retrieve,
+            }
+        )
+
+        hits = search_results["hits"]
+        inference_chunks = [InferenceChunk.from_dict(hit["document"]) for hit in hits]
+
+        return inference_chunks
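_build_typesense_filters emits Typesense filter_by syntax. A worked example, assuming the constants resolve to ALLOWED_USERS == "allowed_users" and PUBLIC_DOC_PAT == "PUBLIC" (neither value is shown in this diff):

from danswer.datastores.typesense.store import _build_typesense_filters

print(_build_typesense_filters(user_id=123, filters=[{"source_type": ["web", "slack"]}]))
# -> allowed_users:=[PUBLIC|123] && source_type:=[web,slack]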
@@ -332,9 +332,17 @@ class OpenAICompletionQA(OpenAIQAModel):
         logger.debug(model_output)

         answer, quotes_dict = process_answer(model_output, context_docs)
-        logger.info(answer)
+        if answer:
+            logger.info(answer)
+        else:
+            logger.warning(
+                "Answer extraction from model output failed, most likely no quotes provided"
+            )

-        yield quotes_dict
+        if quotes_dict is None:
+            yield {}
+        else:
+            yield quotes_dict


 class OpenAIChatCompletionQA(OpenAIQAModel):
@@ -442,6 +450,14 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         logger.debug(model_output)

         answer, quotes_dict = process_answer(model_output, context_docs)
-        logger.info(answer)
+        if answer:
+            logger.info(answer)
+        else:
+            logger.warning(
+                "Answer extraction from model output failed, most likely no quotes provided"
+            )

-        yield quotes_dict
+        if quotes_dict is None:
+            yield {}
+        else:
+            yield quotes_dict
@@ -1,3 +1,4 @@
+import nltk  # type:ignore
 import uvicorn
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
@@ -9,8 +10,11 @@ from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import ENABLE_OAUTH
 from danswer.configs.app_configs import SECRET
+from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
 from danswer.configs.app_configs import WEB_DOMAIN
-from danswer.datastores.qdrant.indexing import list_collections
+from danswer.datastores.qdrant.indexing import list_qdrant_collections
+from danswer.datastores.typesense.store import check_typesense_collection_exist
+from danswer.datastores.typesense.store import create_typesense_collection
 from danswer.db.credentials import create_initial_public_credential
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
@@ -107,24 +111,36 @@ def get_application() -> FastAPI:
     @application.on_event("startup")
     def startup_event() -> None:
         # To avoid circular imports
-        from danswer.semantic_search.semantic_search import (
+        from danswer.search.semantic_search import (
             warm_up_models,
         )
-        from danswer.datastores.qdrant.indexing import create_collection
+        from danswer.datastores.qdrant.indexing import create_qdrant_collection
         from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION

-        if QDRANT_DEFAULT_COLLECTION not in {
-            collection.name for collection in list_collections().collections
-        }:
-            logger.info(f"Creating collection with name: {QDRANT_DEFAULT_COLLECTION}")
-            create_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
-
+        logger.info("Warming up local NLP models.")
         warm_up_models()
-        logger.info("Semantic Search models are ready.")
+
+        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
+        nltk.download("stopwords")
+        nltk.download("wordnet")
+
         logger.info("Verifying public credential exists.")
         create_initial_public_credential()
+
+        logger.info("Verifying Document Indexes are available.")
+        if QDRANT_DEFAULT_COLLECTION not in {
+            collection.name for collection in list_qdrant_collections().collections
+        }:
+            logger.info(
+                f"Creating Qdrant collection with name: {QDRANT_DEFAULT_COLLECTION}"
+            )
+            create_qdrant_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
+        if not check_typesense_collection_exist(TYPESENSE_DEFAULT_COLLECTION):
+            logger.info(
+                f"Creating Typesense collection with name: {TYPESENSE_DEFAULT_COLLECTION}"
+            )
+            create_typesense_collection(collection_name=TYPESENSE_DEFAULT_COLLECTION)

     return application


backend/danswer/search/__init__.py (new file, 0 lines)

backend/danswer/search/keyword_search.py (new file, 52 lines)
@@ -0,0 +1,52 @@
import json

from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from nltk.corpus import stopwords  # type:ignore
from nltk.stem import WordNetLemmatizer  # type:ignore
from nltk.tokenize import word_tokenize  # type:ignore

logger = setup_logger()


def lemmatize_text(text: str) -> str:
    lemmatizer = WordNetLemmatizer()
    word_tokens = word_tokenize(text)
    lemmatized_text = [lemmatizer.lemmatize(word) for word in word_tokens]
    return " ".join(lemmatized_text)


def remove_stop_words(text: str) -> str:
    stop_words = set(stopwords.words("english"))
    word_tokens = word_tokenize(text)
    filtered_text = [word for word in word_tokens if word.casefold() not in stop_words]
    return " ".join(filtered_text)


def query_processing(query: str) -> str:
    query = remove_stop_words(query)
    query = lemmatize_text(query)
    return query


@log_function_time()
def retrieve_keyword_documents(
    query: str,
    user_id: int | None,
    filters: list[IndexFilter] | None,
    datastore: KeywordIndex,
    num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
    edited_query = query_processing(query)
    top_chunks = datastore.keyword_search(edited_query, user_id, filters, num_hits)
    if not top_chunks:
        filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
        logger.warning(
            f"Keyword search returned no results...\nfilters: {filters_log_msg}\nedited query: {edited_query}"
        )
        return None
    return top_chunks
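A quick illustration of what the keyword query pre-processing above does, as a hypothetical sketch rather than code from this commit. It assumes the NLTK data is available locally; note that word_tokenize also needs the "punkt" tokenizer data in addition to the stopwords and wordnet downloads performed at startup.

# Hypothetical usage sketch (not part of the commit).
import nltk

nltk.download("stopwords")
nltk.download("wordnet")
nltk.download("punkt")  # required by word_tokenize

from danswer.search.keyword_search import query_processing

print(query_processing("How do I configure the Google Drive connector?"))
# Stopwords such as "How", "do", "I", "the" are dropped and the remaining tokens are
# lemmatized, yielding something like: "configure Google Drive connector ?"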
backend/danswer/search/semantic_search.py (new file, 201 lines)
@@ -0,0 +1,201 @@
import json

import numpy
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.search.type_aliases import Embedder
from danswer.server.models import SearchDoc
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore

logger = setup_logger()


_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODELS: None | list[CrossEncoder] = None


def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
    search_docs = (
        [
            SearchDoc(
                semantic_identifier=chunk.semantic_identifier,
                link=chunk.source_links.get(0) if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
            )
            for chunk in chunks
        ]
        if chunks
        else []
    )
    return search_docs


def get_default_embedding_model() -> SentenceTransformer:
    global _EMBED_MODEL
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE

    return _EMBED_MODEL


def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
    global _RERANK_MODELS
    if _RERANK_MODELS is None:
        _RERANK_MODELS = [
            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
        ]
        for model in _RERANK_MODELS:
            model.max_length = CROSS_EMBED_CONTEXT_SIZE

    return _RERANK_MODELS


def warm_up_models() -> None:
    get_default_embedding_model().encode("Danswer is so cool")
    cross_encoders = get_default_reranking_model_ensemble()
    [cross_encoder.predict(("What is Danswer", "Enterprise QA")) for cross_encoder in cross_encoders]  # type: ignore


@log_function_time()
def semantic_reranking(
    query: str,
    chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
    cross_encoders = get_default_reranking_model_ensemble()
    sim_scores = sum([encoder.predict([(query, chunk.content) for chunk in chunks]) for encoder in cross_encoders])  # type: ignore
    scored_results = list(zip(sim_scores, chunks))
    scored_results.sort(key=lambda x: x[0], reverse=True)
    ranked_sim_scores, ranked_chunks = zip(*scored_results)

    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")

    return ranked_chunks


@log_function_time()
def retrieve_ranked_documents(
    query: str,
    user_id: int | None,
    filters: list[IndexFilter] | None,
    datastore: VectorIndex,
    num_hits: int = NUM_RETURNED_HITS,
    num_rerank: int = NUM_RERANKED_RESULTS,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
    top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
    if not top_chunks:
        filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
        logger.warning(
            f"Semantic search returned no results with filters: {filters_log_msg}"
        )
        return None, None
    ranked_chunks = semantic_reranking(query, top_chunks[:num_rerank])

    top_docs = [
        ranked_chunk.source_links[0]
        for ranked_chunk in ranked_chunks
        if ranked_chunk.source_links is not None
    ]
    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
    logger.info(files_log_msg)

    return ranked_chunks, top_chunks[num_rerank:]


def split_chunk_text_into_mini_chunks(
    chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
    chunks = []
    start = 0
    separators = [" ", "\n", "\r", "\t"]

    while start < len(chunk_text):
        if len(chunk_text) - start <= mini_chunk_size:
            end = len(chunk_text)
        else:
            # Find the first separator character after min_chunk_length
            end_positions = [
                (chunk_text[start + mini_chunk_size :]).find(sep) for sep in separators
            ]
            # Filter out the not found cases (-1)
            end_positions = [pos for pos in end_positions if pos != -1]
            if not end_positions:
                # If no more separators, the rest of the string becomes a chunk
                end = len(chunk_text)
            else:
                # Add min_chunk_length and start to the end position
                end = min(end_positions) + start + mini_chunk_size

        chunks.append(chunk_text[start:end])
        start = end + 1  # Move to the next character after the separator

    return chunks


@log_function_time()
def encode_chunks(
    chunks: list[IndexChunk],
    embedding_model: SentenceTransformer | None = None,
    batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[EmbeddedIndexChunk]:
    embedded_chunks: list[EmbeddedIndexChunk] = []
    if embedding_model is None:
        embedding_model = get_default_embedding_model()

    chunk_texts = []
    chunk_mini_chunks_count = {}
    for chunk_ind, chunk in enumerate(chunks):
        chunk_texts.append(chunk.content)
        mini_chunk_texts = (
            split_chunk_text_into_mini_chunks(chunk.content)
            if enable_mini_chunk
            else []
        )
        chunk_texts.extend(mini_chunk_texts)
        chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

    text_batches = [
        chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
    ]

    embeddings_np: list[numpy.ndarray] = []
    for text_batch in text_batches:
        embeddings_np.extend(embedding_model.encode(text_batch))
    embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]

    embedding_ind_start = 0
    for chunk_ind, chunk in enumerate(chunks):
        num_embeddings = chunk_mini_chunks_count[chunk_ind]
        chunk_embeddings = embeddings[
            embedding_ind_start : embedding_ind_start + num_embeddings
        ]
        new_embedded_chunk = EmbeddedIndexChunk(
            **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
            embeddings=chunk_embeddings,
        )
        embedded_chunks.append(new_embedded_chunk)
        embedding_ind_start += num_embeddings

    return embedded_chunks


class DefaultEmbedder(Embedder):
    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
        return encode_chunks(chunks)
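One behaviour in the new module above is worth spelling out: semantic_reranking sums the score arrays produced by each cross encoder in the ensemble element-wise before sorting, so a chunk's final rank reflects its combined score across models. A self-contained sketch of that combination step, using plain numpy and invented scores rather than real model outputs:

# Illustration only; the similarity scores below are made up.
import numpy as np

chunks = ["chunk_a", "chunk_b", "chunk_c"]
# One score array per cross encoder in the ensemble (what encoder.predict(...) returns)
scores_per_model = [
    np.array([0.2, 0.9, 0.4]),
    np.array([0.3, 0.7, 0.8]),
]

ensemble_scores = sum(scores_per_model)  # element-wise: array([0.5, 1.6, 1.2])
ranked = sorted(zip(ensemble_scores, chunks), key=lambda x: x[0], reverse=True)
print([chunk for _, chunk in ranked])  # ['chunk_b', 'chunk_c', 'chunk_a']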
@@ -1,201 +0,0 @@
(Deleted file: the standard Apache License, Version 2.0 text, January 2004, http://www.apache.org/licenses/, carrying the attribution "Copyright 2019 Nils Reimers" for the bundled SBERT models. The 201 lines of boilerplate license text are omitted here.)
@@ -1,43 +0,0 @@
(Deleted file: the module previously imported as danswer.semantic_search.biencoder, all 43 lines removed.)

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.semantic_search.semantic_search import get_default_embedding_model
from danswer.semantic_search.type_aliases import Embedder
from danswer.utils.logging import setup_logger
from sentence_transformers import SentenceTransformer  # type: ignore


logger = setup_logger()


def encode_chunks(
    chunks: list[IndexChunk],
    embedding_model: SentenceTransformer | None = None,
    batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[EmbeddedIndexChunk]:
    embedded_chunks = []
    if embedding_model is None:
        embedding_model = get_default_embedding_model()

    chunk_batches = [
        chunks[i : i + batch_size] for i in range(0, len(chunks), batch_size)
    ]
    for batch_ind, chunk_batch in enumerate(chunk_batches):
        embeddings_batch = embedding_model.encode(
            [chunk.content for chunk in chunk_batch]
        )
        embedded_chunks.extend(
            [
                EmbeddedIndexChunk(
                    **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
                    embedding=embeddings_batch[i].tolist()
                )
                for i, chunk in enumerate(chunk_batch)
            ]
        )
    return embedded_chunks


class DefaultEmbedder(Embedder):
    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
        return encode_chunks(chunks)
@@ -1,105 +0,0 @@
(Deleted file: the module previously imported as danswer.semantic_search.semantic_search, all 105 lines removed.)

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
#
# The NLP models used here are licensed under Apache 2.0, the original author's LICENSE file is
# included in this same directory.
# Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models
# The original authors can be found at https://www.sbert.net/
import json

from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import Datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.utils.logging import setup_logger
from danswer.utils.timing import log_function_time
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore


logger = setup_logger()


_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODEL: None | CrossEncoder = None


def get_default_embedding_model() -> SentenceTransformer:
    global _EMBED_MODEL
    if _EMBED_MODEL is None:
        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE

    return _EMBED_MODEL


def get_default_reranking_model() -> CrossEncoder:
    global _RERANK_MODEL
    if _RERANK_MODEL is None:
        _RERANK_MODEL = CrossEncoder(CROSS_ENCODER_MODEL)
        _RERANK_MODEL.max_length = CROSS_EMBED_CONTEXT_SIZE

    return _RERANK_MODEL


def warm_up_models() -> None:
    get_default_embedding_model().encode("Danswer is so cool")
    get_default_reranking_model().predict(("What is Danswer", "Enterprise QA"))  # type: ignore


@log_function_time()
def semantic_reranking(
    query: str,
    chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
    cross_encoder = get_default_reranking_model()
    sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
    scored_results = list(zip(sim_scores, chunks))
    scored_results.sort(key=lambda x: x[0], reverse=True)
    ranked_sim_scores, ranked_chunks = zip(*scored_results)

    logger.debug(f"Reranked similarity scores: {str(ranked_sim_scores)}")

    return ranked_chunks


@log_function_time()
def retrieve_ranked_documents(
    query: str,
    user_id: int | None,
    filters: list[DatastoreFilter] | None,
    datastore: Datastore,
    num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
    top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
    if not top_chunks:
        filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
        logger.warning(
            f"Semantic search returned no results with filters: {filters_log_msg}"
        )
        return None
    ranked_chunks = semantic_reranking(query, top_chunks)

    top_docs = [
        ranked_chunk.source_links[0]
        for ranked_chunk in ranked_chunks
        if ranked_chunk.source_links is not None
    ]
    files_log_msg = f"Top links from semantic search: {', '.join(top_docs)}"
    logger.info(files_log_msg)

    return ranked_chunks
@@ -1,6 +1,7 @@
 from collections import defaultdict
 from typing import cast

+from danswer.auth.schemas import UserRole
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX

@@ -31,6 +32,7 @@ from danswer.db.credentials import fetch_credential_by_id
 from danswer.db.credentials import fetch_credentials
 from danswer.db.credentials import mask_credential_dict
 from danswer.db.credentials import update_credential
+from danswer.db.engine import build_async_engine
 from danswer.db.engine import get_session
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.models import Connector

@@ -55,23 +57,45 @@ from danswer.server.models import IndexAttemptSnapshot
 from danswer.server.models import ObjectCreationIdResponse
 from danswer.server.models import RunConnectorRequest
 from danswer.server.models import StatusResponse
+from danswer.server.models import UserByEmail
+from danswer.server.models import UserRoleResponse
 from danswer.utils.logging import setup_logger
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Request
 from fastapi import Response
+from fastapi_users.db import SQLAlchemyUserDatabase
+from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session

-router = APIRouter(prefix="/manage")

+router = APIRouter(prefix="/manage")
 logger = setup_logger()

 _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id"


 """Admin only API endpoints"""


+@router.patch("/promote-user-to-admin", response_model=None)
+async def promote_admin(
+    user_email: UserByEmail, user: User = Depends(current_admin_user)
+) -> None:
+    if user.role != UserRole.ADMIN:
+        raise HTTPException(status_code=401, detail="Unauthorized")
+    async with AsyncSession(build_async_engine()) as asession:
+        user_db = SQLAlchemyUserDatabase(asession, User)  # type: ignore
+        user_to_promote = await user_db.get_by_email(user_email.user_email)
+        if not user_to_promote:
+            raise HTTPException(status_code=404, detail="User not found")
+        user_to_promote.role = UserRole.ADMIN
+        asession.add(user_to_promote)
+        await asession.commit()
+    return


 @router.get("/admin/connector/google-drive/app-credential")
 def check_google_app_credentials_exist(
     _: User = Depends(current_admin_user),

@@ -403,7 +427,14 @@ def delete_openai_api_key(
     get_dynamic_config_store().delete(OPENAI_API_KEY_STORAGE_KEY)


-"""Endpoints for all!"""
+"""Endpoints for basic users"""


+@router.get("/get-user-role", response_model=UserRoleResponse)
+async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
+    if user is None:
+        raise ValueError("Invalid or missing user.")
+    return UserRoleResponse(role=user.role)


 @router.get("/connector/google-drive/authorize/{credential_id}")
@@ -3,12 +3,10 @@ from typing import Any
 from typing import Generic
 from typing import Literal
 from typing import Optional
-from typing import TYPE_CHECKING
 from typing import TypeVar

 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import InputType
-from danswer.datastores.interfaces import DatastoreFilter
 from danswer.db.models import Connector
 from danswer.db.models import IndexingStatus
 from pydantic import BaseModel

@@ -75,20 +73,25 @@ class SearchDoc(BaseModel):
     source_type: str


-class QAQuestion(BaseModel):
+class QuestionRequest(BaseModel):
     query: str
     collection: str
-    filters: list[DatastoreFilter] | None
+    use_keyword: bool | None
+    filters: str | None  # string of list[IndexFilter]
+
+
+class SearchResponse(BaseModel):
+    # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
+    top_ranked_docs: list[SearchDoc] | None
+    semi_ranked_docs: list[SearchDoc] | None


 class QAResponse(BaseModel):
     answer: str | None
     quotes: dict[str, dict[str, str | int | None]] | None
     ranked_documents: list[SearchDoc] | None
-
-
-class KeywordResponse(BaseModel):
-    results: list[str] | None
+    # for performance, only a few top documents are cross-encoded for rerank, the rest follow retrieval order
+    unranked_documents: list[SearchDoc] | None


 class UserByEmail(BaseModel):
@@ -1,96 +1,105 @@
+import json
 import time
 from collections.abc import Generator

-from danswer.auth.schemas import UserRole
-from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
-from danswer.configs.app_configs import KEYWORD_MAX_HITS
-from danswer.configs.app_configs import NUM_RERANKED_RESULTS
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import CONTENT
-from danswer.configs.constants import SOURCE_LINKS
-from danswer.datastores import create_datastore
-from danswer.db.engine import build_async_engine
+from danswer.datastores.qdrant.store import QdrantIndex
+from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.question_answer import get_json_line
-from danswer.semantic_search.semantic_search import retrieve_ranked_documents
-from danswer.server.models import KeywordResponse
-from danswer.server.models import QAQuestion
+from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.semantic_search import chunks_to_search_docs
+from danswer.search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import QAResponse
-from danswer.server.models import SearchDoc
-from danswer.server.models import UserByEmail
-from danswer.server.models import UserRoleResponse
-from danswer.utils.clients import TSClient
+from danswer.server.models import QuestionRequest
+from danswer.server.models import SearchResponse
 from danswer.utils.logging import setup_logger
 from fastapi import APIRouter
 from fastapi import Depends
-from fastapi import HTTPException
 from fastapi.responses import StreamingResponse
-from fastapi_users.db import SQLAlchemyUserDatabase
-from sqlalchemy.ext.asyncio import AsyncSession

 logger = setup_logger()

 router = APIRouter()


-@router.get("/get-user-role", response_model=UserRoleResponse)
-async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
-    if user is None:
-        raise ValueError("Invalid or missing user.")
-    return UserRoleResponse(role=user.role)
+@router.get("/semantic-search")
+def semantic_search(
+    question: QuestionRequest = Depends(), user: User = Depends(current_user)
+) -> SearchResponse:
+    query = question.query
+    collection = question.collection
+    filters = json.loads(question.filters) if question.filters is not None else None
+    logger.info(f"Received semantic search query: {query}")
+
+    user_id = None if user is None else int(user.id)
+    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
+        query, user_id, filters, QdrantIndex(collection)
+    )
+    if not ranked_chunks:
+        return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
+
+    top_docs = chunks_to_search_docs(ranked_chunks)
+    other_top_docs = chunks_to_search_docs(unranked_chunks)
+
+    return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=other_top_docs)


-@router.patch("/promote-user-to-admin", response_model=None)
-async def promote_admin(
-    user_email: UserByEmail, user: User = Depends(current_admin_user)
-) -> None:
-    if user.role != UserRole.ADMIN:
-        raise HTTPException(status_code=401, detail="Unauthorized")
-    async with AsyncSession(build_async_engine()) as asession:
-        user_db = SQLAlchemyUserDatabase(asession, User)  # type: ignore
-        user_to_promote = await user_db.get_by_email(user_email.user_email)
-        if not user_to_promote:
-            raise HTTPException(status_code=404, detail="User not found")
-        user_to_promote.role = UserRole.ADMIN
-        asession.add(user_to_promote)
-        await asession.commit()
-    return
+@router.get("/keyword-search", response_model=SearchResponse)
+def keyword_search(
+    question: QuestionRequest = Depends(), user: User = Depends(current_user)
+) -> SearchResponse:
+    query = question.query
+    collection = question.collection
+    filters = json.loads(question.filters) if question.filters is not None else None
+    logger.info(f"Received keyword search query: {query}")
+
+    user_id = None if user is None else int(user.id)
+    ranked_chunks = retrieve_keyword_documents(
+        query, user_id, filters, TypesenseIndex(collection)
+    )
+    if not ranked_chunks:
+        return SearchResponse(top_ranked_docs=None, semi_ranked_docs=None)
+
+    top_docs = chunks_to_search_docs(ranked_chunks)
+    return SearchResponse(top_ranked_docs=top_docs, semi_ranked_docs=None)


 @router.get("/direct-qa", response_model=QAResponse)
 def direct_qa(
-    question: QAQuestion = Depends(), user: User = Depends(current_user)
+    question: QuestionRequest = Depends(), user: User = Depends(current_user)
 ) -> QAResponse:
     start_time = time.time()

     query = question.query
     collection = question.collection
-    filters = question.filters
-    logger.info(f"Received semantic query: {query}")
+    filters = json.loads(question.filters) if question.filters is not None else None
+    use_keyword = question.use_keyword
+    logger.info(f"Received QA query: {query}")

     user_id = None if user is None else int(user.id)
-    ranked_chunks = retrieve_ranked_documents(
-        query, user_id, filters, create_datastore(collection)
-    )
-    if not ranked_chunks:
-        return QAResponse(answer=None, quotes=None, ranked_documents=None)
-
-    top_docs = [
-        SearchDoc(
-            semantic_identifier=chunk.semantic_identifier,
-            link=chunk.source_links.get(0) if chunk.source_links else None,
-            blurb=chunk.blurb,
-            source_type=chunk.source_type,
-        )
-        for chunk in ranked_chunks
-    ]
+    if use_keyword:
+        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
+            query, user_id, filters, TypesenseIndex(collection)
+        )
+        unranked_chunks: list[InferenceChunk] | None = []
+    else:
+        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
+            query, user_id, filters, QdrantIndex(collection)
+        )
+    if not ranked_chunks:
+        return QAResponse(
+            answer=None, quotes=None, ranked_documents=None, unranked_documents=None
+        )

     qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
     try:
         answer, quotes = qa_model.answer_question(
-            query, ranked_chunks[:NUM_RERANKED_RESULTS]
+            query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
         )
     except Exception:
         # exception is logged in the answer_question method, no need to re-log

@@ -98,45 +107,54 @@ def direct_qa(
     logger.info(f"Total QA took {time.time() - start_time} seconds")

-    return QAResponse(answer=answer, quotes=quotes, ranked_documents=top_docs)
+    return QAResponse(
+        answer=answer,
+        quotes=quotes,
+        ranked_documents=chunks_to_search_docs(ranked_chunks),
+        unranked_documents=chunks_to_search_docs(unranked_chunks),
+    )


 @router.get("/stream-direct-qa")
 def stream_direct_qa(
-    question: QAQuestion = Depends(), user: User = Depends(current_user)
+    question: QuestionRequest = Depends(), user: User = Depends(current_user)
 ) -> StreamingResponse:
     top_documents_key = "top_documents"
+    unranked_top_docs_key = "unranked_top_documents"

     def stream_qa_portions() -> Generator[str, None, None]:
         query = question.query
         collection = question.collection
-        filters = question.filters
-        logger.info(f"Received semantic query: {query}")
+        filters = json.loads(question.filters) if question.filters is not None else None
+        use_keyword = question.use_keyword
+        logger.info(f"Received QA query: {query}")

         user_id = None if user is None else int(user.id)
-        ranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, create_datastore(collection)
-        )
+        if use_keyword:
+            ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
+                query, user_id, filters, TypesenseIndex(collection)
+            )
+            unranked_chunks: list[InferenceChunk] | None = []
+        else:
+            ranked_chunks, unranked_chunks = retrieve_ranked_documents(
+                query, user_id, filters, QdrantIndex(collection)
+            )
         if not ranked_chunks:
-            yield get_json_line({top_documents_key: None})
+            yield get_json_line({top_documents_key: None, unranked_top_docs_key: None})
             return

-        top_docs = [
-            SearchDoc(
-                semantic_identifier=chunk.semantic_identifier,
-                link=chunk.source_links.get(0) if chunk.source_links else None,
-                blurb=chunk.blurb,
-                source_type=chunk.source_type,
-            )
-            for chunk in ranked_chunks
-        ]
-        top_docs_dict = {top_documents_key: [top_doc.json() for top_doc in top_docs]}
+        top_docs = chunks_to_search_docs(ranked_chunks)
+        unranked_top_docs = chunks_to_search_docs(unranked_chunks)
+        top_docs_dict = {
+            top_documents_key: [top_doc.json() for top_doc in top_docs],
+            unranked_top_docs_key: [doc.json() for doc in unranked_top_docs],
+        }
+
         yield get_json_line(top_docs_dict)

         qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
         try:
             for response_dict in qa_model.answer_question_stream(
-                query, ranked_chunks[:NUM_RERANKED_RESULTS]
+                query, ranked_chunks[:NUM_GENERATIVE_AI_INPUT_DOCS]
             ):
                 if response_dict is None:
                     continue

@@ -145,36 +163,6 @@ def stream_direct_qa(
         except Exception:
             # exception is logged in the answer_question method, no need to re-log
             pass

         return

     return StreamingResponse(stream_qa_portions(), media_type="application/json")
-
-
-@router.get("/keyword-search", response_model=KeywordResponse)
-def keyword_search(
-    question: QAQuestion = Depends(), _: User = Depends(current_user)
-) -> KeywordResponse:
-    ts_client = TSClient.get_instance()
-    query = question.query
-    collection = question.collection
-
-    logger.info(f"Received keyword query: {query}")
-    start_time = time.time()
-
-    search_results = ts_client.collections[collection].documents.search(
-        {
-            "q": query,
-            "query_by": CONTENT,
-            "per_page": KEYWORD_MAX_HITS,
-            "limit_hits": KEYWORD_MAX_HITS,
-        }
-    )
-
-    hits = search_results["hits"]
-    sources = [hit["document"][SOURCE_LINKS][0] for hit in hits]
-
-    total_time = time.time() - start_time
-    logger.info(f"Total Keyword Search took {total_time} seconds")
-
-    return KeywordResponse(results=sources)
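For orientation, a rough sketch of how a client could exercise the reworked endpoints above. The host/port, collection name, and filter payload are assumptions (filters must be passed as a JSON-encoded string per the new QuestionRequest model), and authentication is assumed to be disabled:

# Hypothetical client calls; endpoint paths follow the routes in this diff.
import json
import requests

base = "http://localhost:8080"  # assumed host/port
params = {
    "query": "How do I add a Slack connector?",
    "collection": "danswer_index",
    "filters": json.dumps([{"source_type": ["slack"]}]),  # string of list[IndexFilter]; contents assumed
}

semantic = requests.get(f"{base}/semantic-search", params=params).json()
keyword = requests.get(f"{base}/keyword-search", params=params).json()
qa = requests.get(f"{base}/direct-qa", params={**params, "use_keyword": "false"}).json()

print(semantic["top_ranked_docs"])
print(keyword["top_ranked_docs"])
print(qa["answer"])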
@@ -1,8 +1,4 @@
-from typing import Any
-from typing import Optional
-
 import typesense  # type: ignore
-from danswer.configs.app_configs import DB_CONN_TIMEOUT
 from danswer.configs.app_configs import QDRANT_API_KEY
 from danswer.configs.app_configs import QDRANT_HOST
 from danswer.configs.app_configs import QDRANT_PORT

@@ -14,6 +10,7 @@ from qdrant_client import QdrantClient


 _qdrant_client: QdrantClient | None = None
+_typesense_client: typesense.Client | None = None


 def get_qdrant_client() -> QdrantClient:

@@ -29,35 +26,23 @@ def get_qdrant_client() -> QdrantClient:
     return _qdrant_client


-class TSClient:
-    __instance: Optional["TSClient"] = None
-
-    @staticmethod
-    def get_instance(
-        host: str = TYPESENSE_HOST,
-        port: int = TYPESENSE_PORT,
-        api_key: str = TYPESENSE_API_KEY,
-        timeout: int = DB_CONN_TIMEOUT,
-    ) -> "TSClient":
-        if TSClient.__instance is None:
-            TSClient(host, port, api_key, timeout)
-        return TSClient.__instance  # type: ignore
-
-    def __init__(self, host: str, port: int, api_key: str, timeout: int) -> None:
-        if TSClient.__instance is not None:
-            raise Exception(
-                "Singleton instance already exists. Use TSClient.get_instance() to get the instance."
-            )
-        else:
-            TSClient.__instance = self
-            self.client = typesense.Client(
-                {
-                    "api_key": api_key,
-                    "nodes": [{"host": host, "port": str(port), "protocol": "http"}],
-                    "connection_timeout_seconds": timeout,
-                }
-            )
-
-    # delegate all client operations to the third party client
-    def __getattr__(self, name: str) -> Any:
-        return getattr(self.client, name)
+def get_typesense_client() -> typesense.Client:
+    global _typesense_client
+    if _typesense_client is None:
+        if TYPESENSE_HOST and TYPESENSE_PORT and TYPESENSE_API_KEY:
+            _typesense_client = typesense.Client(
+                {
+                    "api_key": TYPESENSE_API_KEY,
+                    "nodes": [
+                        {
+                            "host": TYPESENSE_HOST,
+                            "port": str(TYPESENSE_PORT),
+                            "protocol": "http",
+                        }
+                    ],
+                }
+            )
+        else:
+            raise Exception("Unable to instantiate TypesenseClient")
+
+    return _typesense_client
@@ -1,17 +1,17 @@
-from collections.abc import Callable
 from functools import partial
 from itertools import chain
-from typing import Any
 from typing import Protocol

 from danswer.chunking.chunk import Chunker
 from danswer.chunking.chunk import DefaultChunker
 from danswer.chunking.models import EmbeddedIndexChunk
 from danswer.connectors.models import Document
-from danswer.datastores.interfaces import Datastore
-from danswer.datastores.qdrant.store import QdrantDatastore
-from danswer.semantic_search.biencoder import DefaultEmbedder
-from danswer.semantic_search.type_aliases import Embedder
+from danswer.datastores.interfaces import KeywordIndex
+from danswer.datastores.interfaces import VectorIndex
+from danswer.datastores.qdrant.store import QdrantIndex
+from danswer.datastores.typesense.store import TypesenseIndex
+from danswer.search.semantic_search import DefaultEmbedder
+from danswer.search.type_aliases import Embedder


 class IndexingPipelineProtocol(Protocol):

@@ -25,15 +25,18 @@ def _indexing_pipeline(
     *,
     chunker: Chunker,
     embedder: Embedder,
-    datastore: Datastore,
+    vector_index: VectorIndex,
+    keyword_index: KeywordIndex,
     documents: list[Document],
     user_id: int | None,
 ) -> list[EmbeddedIndexChunk]:
     # TODO: make entire indexing pipeline async to not block the entire process
     # when running on async endpoints
     chunks = list(chain(*[chunker.chunk(document) for document in documents]))
+    # TODO keyword indexing can occur at same time as embedding
+    keyword_index.index(chunks, user_id)
     chunks_with_embeddings = embedder.embed(chunks)
-    datastore.index(chunks_with_embeddings, user_id)
+    vector_index.index(chunks_with_embeddings, user_id)
     return chunks_with_embeddings

@@ -41,7 +44,8 @@ def build_indexing_pipeline(
     *,
     chunker: Chunker | None = None,
     embedder: Embedder | None = None,
-    datastore: Datastore | None = None,
+    vector_index: VectorIndex | None = None,
+    keyword_index: KeywordIndex | None = None,
 ) -> IndexingPipelineProtocol:
     """Builds a pipline which takes in a list of docs and indexes them.

@@ -52,9 +56,16 @@ def build_indexing_pipeline(
     if embedder is None:
         embedder = DefaultEmbedder()

-    if datastore is None:
-        datastore = QdrantDatastore()
+    if vector_index is None:
+        vector_index = QdrantIndex()
+
+    if keyword_index is None:
+        keyword_index = TypesenseIndex()

     return partial(
-        _indexing_pipeline, chunker=chunker, embedder=embedder, datastore=datastore
+        _indexing_pipeline,
+        chunker=chunker,
+        embedder=embedder,
+        vector_index=vector_index,
+        keyword_index=keyword_index,
     )
|
|||||||
httpx==0.23.3
|
httpx==0.23.3
|
||||||
httpx-oauth==0.11.2
|
httpx-oauth==0.11.2
|
||||||
Mako==1.2.4
|
Mako==1.2.4
|
||||||
|
nltk==3.8.1
|
||||||
openai==0.27.6
|
openai==0.27.6
|
||||||
playwright==1.32.1
|
playwright==1.32.1
|
||||||
psycopg2==2.9.6
|
psycopg2==2.9.6
|
||||||
@@ -21,7 +22,7 @@ pydantic==1.10.7
|
|||||||
PyGithub==1.58.2
|
PyGithub==1.58.2
|
||||||
PyPDF2==3.0.1
|
PyPDF2==3.0.1
|
||||||
pytest-playwright==0.3.2
|
pytest-playwright==0.3.2
|
||||||
qdrant-client==1.1.0
|
qdrant-client==1.2.0
|
||||||
requests==2.28.2
|
requests==2.28.2
|
||||||
rfc3986==1.5.0
|
rfc3986==1.5.0
|
||||||
sentence-transformers==2.2.2
|
sentence-transformers==2.2.2
|
||||||
|
backend/scripts/reset_indexes.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# This file is purely for development use, not included in any builds
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
from danswer.utils.logging import setup_logger
from qdrant_client.http.models import Distance
from qdrant_client.http.models import VectorParams
from typesense.exceptions import ObjectNotFound  # type: ignore

logger = setup_logger()


def recreate_qdrant_collection(
    collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
    logger.info(f"Attempting to recreate Qdrant collection {collection_name}")
    result = get_qdrant_client().recreate_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
    )
    if not result:
        raise RuntimeError("Could not create Qdrant collection")


def recreate_typesense_collection(collection_name: str) -> None:
    logger.info(f"Attempting to recreate Typesense collection {collection_name}")
    ts_client = get_typesense_client()
    try:
        ts_client.collections[collection_name].delete()
    except ObjectNotFound:
        logger.debug(f"Collection {collection_name} does not already exist")

    create_typesense_collection(collection_name)


if __name__ == "__main__":
    recreate_qdrant_collection("danswer_index")
    recreate_typesense_collection("danswer_index")
@@ -1,13 +1,12 @@
+# This file is purely for development use, not included in any builds
 import argparse
 import json
 import urllib
+from pprint import pprint
 
 import requests
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
-from danswer.configs.constants import BLURB
-from danswer.configs.constants import SEMANTIC_IDENTIFIER
-from danswer.configs.constants import SOURCE_LINK
 from danswer.configs.constants import SOURCE_TYPE
 
 
@@ -16,35 +15,44 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
     parser.add_argument(
-        "-k",
-        "--keyword-search",
-        action="store_true",
-        help="Use keyword search if set, semantic search otherwise",
+        "-f",
+        "--flow",
+        type=str,
+        default="QA",
+        help='"Search" or "QA", defaults to "QA"',
     )
 
     parser.add_argument(
         "-t",
-        "--source-types",
+        "--type",
         type=str,
-        help="Comma separated list of source types to filter by",
+        default="Semantic",
+        help='"Semantic" or "Keyword", defaults to "Semantic"',
     )
 
     parser.add_argument(
         "-s",
         "--stream",
         action="store_true",
-        help="Enable streaming response",
+        help='Enable streaming response, only for flow="QA"',
+    )
+
+    parser.add_argument(
+        "--filters",
+        type=str,
+        help="Comma separated list of source types to filter by (no spaces)",
     )
 
     parser.add_argument("query", nargs="*", help="The query to process")
 
+    previous_input = None
     while True:
         try:
             user_input = input(
                 "\n\nAsk any question:\n"
-                " - prefix with -t to add a filter by source type(s)\n"
-                " - prefix with -s to stream answer\n"
-                " - input an empty string to rerun last query\n\t"
+                " - Use -f (QA/Search) and -t (Semantic/Keyword) flags to set endpoint.\n"
+                " - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
+                " - input an empty string to rerun last query.\n\t"
            )
 
             if user_input:
@@ -58,62 +66,51 @@ if __name__ == "__main__":
 
             args = parser.parse_args(user_input.split())
 
-            keyword_search = args.keyword_search
-            source_types = args.source_types.split(",") if args.source_types else None
+            flow = str(args.flow).lower()
+            flow_type = str(args.type).lower()
+            stream = args.stream
+            source_types = args.filters.split(",") if args.filters else None
             if source_types and len(source_types) == 1:
                 source_types = source_types[0]
             query = " ".join(args.query)
 
-            endpoint = (
-                f"http://127.0.0.1:{APP_PORT}/direct-qa"
-                if not args.stream
-                else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
-            )
-            if args.keyword_search:
-                endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
-                raise NotImplementedError("keyword search is not supported for now")
+            if flow not in ["qa", "search"]:
+                raise ValueError("Flow value must be QA or Search")
+            if flow_type not in ["keyword", "semantic"]:
+                raise ValueError("Type value must be keyword or semantic")
+            if flow != "qa" and stream:
+                raise ValueError("Can only stream results for QA")
+
+            if (flow, flow_type) == ("search", "keyword"):
+                path = "keyword-search"
+            elif (flow, flow_type) == ("search", "semantic"):
+                path = "semantic-search"
+            elif stream:
+                path = "stream-direct-qa"
+            else:
+                path = "direct-qa"
+
+            endpoint = f"http://127.0.0.1:{APP_PORT}/{path}"
 
             query_json = {
                 "query": query,
                 "collection": QDRANT_DEFAULT_COLLECTION,
-                "filters": [{SOURCE_TYPE: source_types}],
+                "use_keyword": flow_type == "keyword",  # Ignore if not QA Endpoints
+                "filters": json.dumps([{SOURCE_TYPE: source_types}]),
             }
-            if not args.stream:
-                response = requests.get(
-                    endpoint, params=urllib.parse.urlencode(query_json)
-                )
-                contents = json.loads(response.content)
-                if keyword_search:
-                    if contents["results"]:
-                        for link in contents["results"]:
-                            print(link)
-                    else:
-                        print("No matches found")
-                else:
-                    answer = contents.get("answer")
-                    if answer:
-                        print("Answer: " + answer)
-                    else:
-                        print("Answer: ?")
-                    if contents.get("quotes"):
-                        for ind, (quote, quote_info) in enumerate(
-                            contents["quotes"].items()
-                        ):
-                            print(f"Quote {str(ind + 1)}:\n{quote}")
-                            print(
-                                f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
-                            )
-                            print(f"Blurb: {quote_info[BLURB]}")
-                            print(f"Link: {quote_info[SOURCE_LINK]}")
-                            print(f"Source: {quote_info[SOURCE_TYPE]}")
-                    else:
-                        print("No quotes found")
-            else:
+
+            if args.stream:
                 with requests.get(
                     endpoint, params=urllib.parse.urlencode(query_json), stream=True
                 ) as r:
                     for json_response in r.iter_lines():
-                        print(json.loads(json_response.decode()))
+                        pprint(json.loads(json_response.decode()))
+            else:
+                response = requests.get(
+                    endpoint, params=urllib.parse.urlencode(query_json)
+                )
+                contents = json.loads(response.content)
+                pprint(contents)
 
         except Exception as e:
             print(f"Failed due to {e}, retrying")
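For reference, the request the updated script would send for a keyword search looks roughly like the sketch below. Endpoint path, field names, and config imports are taken from the diff above; the example query and the single "web" filter are illustrative, and the shape of the response body is not shown in this commit.

import json
import urllib.parse

import requests
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import SOURCE_TYPE

# Equivalent of typing `-f Search -t Keyword --filters web <query>` at the prompt.
query_json = {
    "query": "how to reset the danswer index",
    "collection": QDRANT_DEFAULT_COLLECTION,
    "use_keyword": True,  # ignored by non-QA endpoints per the script comment
    "filters": json.dumps([{SOURCE_TYPE: "web"}]),
}
response = requests.get(
    f"http://127.0.0.1:{APP_PORT}/keyword-search",
    params=urllib.parse.urlencode(query_json),
)
print(response.json())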
@@ -1,5 +1,2 @@
-# For a local deployment, no additional setup is needed
-# Refer to env.dev.template and env.prod.template for additional options
-
-# Setting Auth to false for local setup convenience to avoid setting up Google OAuth app in GPC.
-DISABLE_AUTH=True
+# This empty .env file is provided for compatibility with older Docker/Docker-Compose installations
+# To change default values, check env.dev.template or env.prod.template
@@ -7,6 +7,7 @@ services:
     depends_on:
       - relational_db
       - vector_db
+      - search_engine
     restart: always
     ports:
       - "8080:8080"
@@ -15,6 +16,9 @@ services:
     environment:
       - POSTGRES_HOST=relational_db
       - QDRANT_HOST=vector_db
+      - TYPESENSE_HOST=search_engine
+      - TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
+      - DISABLE_AUTH=True
     volumes:
       - local_dynamic_storage:/home/storage
   background:
@@ -43,12 +47,13 @@ services:
       - .env
     environment:
       - INTERNAL_URL=http://api_server:8080
+      - DISABLE_AUTH=True
   relational_db:
     image: postgres:15.2-alpine
     restart: always
     environment:
-      POSTGRES_USER: ${POSTGRES_USER:-postgres}
-      POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-password}
+      - POSTGRES_USER=${POSTGRES_USER:-postgres}
+      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
     env_file:
       - .env
     ports:
@@ -58,10 +63,24 @@ services:
   vector_db:
     image: qdrant/qdrant:v1.1.3
     restart: always
+    env_file:
+      - .env
     ports:
       - "6333:6333"
     volumes:
       - qdrant_volume:/qdrant/storage
+  search_engine:
+    image: typesense/typesense:0.24.1
+    restart: always
+    environment:
+      - TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
+      - TYPESENSE_DATA_DIR=/typesense/storage
+    env_file:
+      - .env
+    ports:
+      - "8108:8108"
+    volumes:
+      - typesense_volume:/typesense/storage
   nginx:
     image: nginx:1.23.4-alpine
     restart: always
@@ -82,3 +101,4 @@ volumes:
   local_dynamic_storage:
   db_volume:
   qdrant_volume:
+  typesense_volume:
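The new search_engine service is what TYPESENSE_HOST and TYPESENSE_API_KEY point the backend at. A minimal sketch of a client built from those variables using the public typesense Python package follows; the repo's own get_typesense_client helper may be implemented differently.

import os

import typesense

# Read the variables wired up in the compose files above.
typesense_client = typesense.Client(
    {
        "api_key": os.environ.get("TYPESENSE_API_KEY", "local_dev_typesense"),
        "nodes": [
            {
                "host": os.environ.get("TYPESENSE_HOST", "localhost"),
                "port": "8108",  # matches the port published by the search_engine service
                "protocol": "http",
            }
        ],
        "connection_timeout_seconds": 2,
    }
)

# Quick sanity check that the service is reachable: list existing collections.
print(typesense_client.collections.retrieve())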
@@ -7,12 +7,15 @@ services:
     depends_on:
       - relational_db
       - vector_db
+      - search_engine
     restart: always
     env_file:
       - .env
     environment:
       - POSTGRES_HOST=relational_db
       - QDRANT_HOST=vector_db
+      - TYPESENSE_HOST=search_engine
+      - TYPESENSE_API_KEY=${TYPESENSE_API_KEY:-local_dev_typesense}
     volumes:
       - local_dynamic_storage:/home/storage
   background:
@@ -54,8 +57,22 @@ services:
   vector_db:
     image: qdrant/qdrant:v1.1.3
     restart: always
+    env_file:
+      - .env
     volumes:
       - qdrant_volume:/qdrant/storage
+  search_engine:
+    image: typesense/typesense:0.24.1
+    restart: always
+    # TYPESENSE_API_KEY must be set in .env file
+    environment:
+      - TYPESENSE_DATA_DIR=/typesense/storage
+    env_file:
+      - .env
+    ports:
+      - "8108:8108"
+    volumes:
+      - typesense_volume:/typesense/storage
   nginx:
     image: nginx:1.23.4-alpine
     restart: always
@@ -83,3 +100,4 @@ volumes:
   local_dynamic_storage:
   db_volume:
   qdrant_volume:
+  typesense_volume:
@@ -1,13 +1,8 @@
 # Very basic .env file with options that are easy to change. Allows you to deploy everything on a single machine.
-# We don't suggest using these settings for production.
-
+# .env is not required unless you wish to change defaults
 
 # Choose between "openai-chat-completion" and "openai-completion"
 INTERNAL_MODEL_VERSION=openai-chat-completion
 
 # Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
 OPENAPI_MODEL_VERSION=gpt-3.5-turbo
-
-
-# Auth not necessary for local
-DISABLE_AUTH=True
@@ -13,14 +13,8 @@ OPENAI_MODEL_VERSION=gpt-4
 # Could be something like danswer.companyname.com. Requires additional setup if not localhost
 WEB_DOMAIN=http://localhost:3000
 
-# BACKEND DB can leave these as defaults
-POSTGRES_USER=postgres
-POSTGRES_PASSWORD=password
-
-
-# AUTH CONFIGS
-DISABLE_AUTH=False
-
+# Required
+TYPESENSE_API_KEY=
 
 # Currently frontend page doesn't have basic auth, use OAuth if user auth is enabled.
 ENABLE_OAUTH=True
@@ -64,7 +64,7 @@ const searchRequestStreamed = async (
   const url = new URL("/api/stream-direct-qa", window.location.origin);
   const params = new URLSearchParams({
     query,
-    collection: "semantic_search",
+    collection: "danswer_index",
   }).toString();
   url.search = params;
 