Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-04 17:00:24 +02:00)

Commit b06e53a51e: "Feed in docs till we reach a token limit (#401)"
Parent: 58b75122f1
@@ -138,7 +138,12 @@ WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
 #####
 NUM_RETURNED_HITS = 50
 NUM_RERANKED_RESULTS = 15
-NUM_GENERATIVE_AI_INPUT_DOCS = 5
+# We feed in document chunks until we reach this token limit.
+# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
+# may be smaller which could result in passing in more total chunks
+NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
+    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
+)
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10")  # 10 seconds
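The new default resolves to 2560 tokens: at the common rule of thumb of roughly 4 characters per token, a full 2000-character chunk is about 512 tokens, hence 512 * 5 for about five full chunks. The `or` fallback also interacts with the docker-compose changes further down, which can inject an empty string. A minimal sketch of the resolution, with illustrative values only:

```python
import os

# Compose's `${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}` injects an
# empty string when the variable is unset on the host; simulate that here.
os.environ["NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL"] = ""

# `"" or (512 * 5)` evaluates to 2560, so the empty override falls back to
# the default instead of crashing on int("").
limit = int(os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5))
print(limit)  # 2560
```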
@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
 
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.datastores.document_index import get_default_document_index
@@ -11,6 +11,7 @@ from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.llm_utils import get_default_qa_model
+from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -107,18 +108,19 @@ def answer_qa_query(
         chunk for chunk in ranked_chunks if chunk.metadata.get(IGNORE_FOR_QA)
     ]
 
-    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-    if chunk_offset >= len(filtered_ranked_chunks):
-        raise ValueError("Chunks offset too large, should not retry this many times")
+    # get all chunks that fit into the token limit
+    usable_chunks = get_usable_chunks(
+        chunks=filtered_ranked_chunks,
+        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+        offset=offset_count,
+    )
+    logger.debug(
+        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+    )
 
     error_msg = None
     try:
-        answer, quotes = qa_model.answer_question(
-            query,
-            filtered_ranked_chunks[
-                chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
-            ],
-        )
+        answer, quotes = qa_model.answer_question(query, usable_chunks)
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
         answer, quotes = None, None
@@ -10,6 +10,7 @@ from typing import Tuple
 import regex
 
 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.configs.model_configs import GEN_AI_API_KEY
@@ -21,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.llm.utils import check_number_of_tokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import shared_precompare_cleanup
@@ -254,3 +256,48 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
     """Mock streaming by generating the passed in model output, character by character"""
     for token in model_out:
         yield token
+
+
+def _get_usable_chunks(
+    chunks: list[InferenceChunk], token_limit: int
+) -> list[InferenceChunk]:
+    total_token_count = 0
+    usable_chunks = []
+    for chunk in chunks:
+        chunk_token_count = check_number_of_tokens(chunk.content)
+        if total_token_count + chunk_token_count > token_limit:
+            break
+
+        total_token_count += chunk_token_count
+        usable_chunks.append(chunk)
+
+    # try and return at least one chunk if possible. This chunk will
+    # get truncated later on in the pipeline. This would only occur if
+    # the first chunk is larger than the token limit (usually due to character
+    # count -> token count mismatches caused by special characters / non-ascii
+    # languages)
+    if not usable_chunks and chunks:
+        usable_chunks = [chunks[0]]
+
+    return usable_chunks
+
+
+def get_usable_chunks(
+    chunks: list[InferenceChunk],
+    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+    offset: int = 0,
+) -> list[InferenceChunk]:
+    offset_into_chunks = 0
+    usable_chunks: list[InferenceChunk] = []
+    for _ in range(min(offset + 1, 1)):  # go through this process at least once
+        if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
+            raise ValueError(
+                "Chunks offset too large, should not retry this many times"
+            )
+
+        usable_chunks = _get_usable_chunks(
+            chunks=chunks[offset_into_chunks:], token_limit=token_limit
+        )
+        offset_into_chunks += len(usable_chunks)
+
+    return usable_chunks
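A minimal usage sketch of the new helper, under stated assumptions: StubChunk is a stand-in type (only the .content attribute is read by the packing loop above, semantic_identifier is added only because callers log it), and the token counts are approximate for the cl100k_base encoder.

```python
from dataclasses import dataclass

from danswer.direct_qa.qa_utils import get_usable_chunks


@dataclass
class StubChunk:
    content: str
    semantic_identifier: str = "stub-doc"


# Each ~400-word chunk is roughly 400 tokens under cl100k_base, so a
# 1000-token budget should admit about two chunks before the loop breaks.
chunks = [StubChunk(content="word " * 400) for _ in range(3)]
selected = get_usable_chunks(chunks=chunks, token_limit=1000)  # type: ignore[arg-type]
print(len(selected))  # likely 2
```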
@@ -1,5 +1,7 @@
+from collections.abc import Callable
 from collections.abc import Iterator
 
+import tiktoken
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.schema import PromptValue
@@ -69,3 +71,17 @@ def convert_input(lm_input: LanguageModelInput) -> str:
 
 def should_be_verbose() -> bool:
     return LOG_LEVEL == "debug"
+
+
+def check_number_of_tokens(
+    text: str, encode_fn: Callable[[str], list] | None = None
+) -> int:
+    """Gets the number of tokens in the provided text, using the provided encoding
+    function. If none is provided, default to the tiktoken encoder used by GPT-3.5
+    and GPT-4.
+    """
+
+    if encode_fn is None:
+        encode_fn = tiktoken.get_encoding("cl100k_base").encode
+
+    return len(encode_fn(text))
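Two quick examples of the helper, hedged: exact counts depend on the tokenizer version, and the whitespace variant is just an assumed stand-in to show the encode_fn hook.

```python
from danswer.llm.utils import check_number_of_tokens

# Default path: tiktoken's cl100k_base encoder.
n = check_number_of_tokens("Danswer feeds in document chunks.")
print(n)  # a small positive count; the exact value depends on the encoder

# Custom path: any str -> list callable works, e.g. naive whitespace
# splitting as a cheap approximation.
n_words = check_number_of_tokens("Danswer feeds in document chunks.", encode_fn=str.split)
print(n_words)  # 5
```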
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.datastores.document_index import get_default_document_index
 from danswer.db.engine import get_session
@@ -22,6 +22,7 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.llm_utils import get_default_qa_model
+from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -247,17 +248,19 @@ def stream_direct_qa(
             chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
         ]
 
-        chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-        if chunk_offset >= len(filtered_ranked_chunks):
-            raise ValueError(
-                "Chunks offset too large, should not retry this many times"
-            )
+        # get all chunks that fit into the token limit
+        usable_chunks = get_usable_chunks(
+            chunks=filtered_ranked_chunks,
+            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+            offset=offset_count,
+        )
+        logger.debug(
+            f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+        )
 
         try:
             for response_packet in qa_model.answer_question_stream(
-                query,
-                filtered_ranked_chunks[
-                    chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
-                ],
+                query, usable_chunks
            ):
                 if response_packet is None:
                     continue
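The `if response_packet is None: continue` guard implies callers treat None as a no-op packet. A self-contained sketch with a stand-in generator (the stub is an assumption; only the call shape `(query, usable_chunks)` comes from the diff above):

```python
from collections.abc import Iterator


def fake_answer_question_stream(query: str, chunks: list) -> Iterator[str | None]:
    """Stand-in for qa_model.answer_question_stream: yields answer pieces,
    with occasional None packets the caller is expected to skip."""
    yield "Streaming "
    yield None
    yield "answer."


for response_packet in fake_answer_question_stream("example query", []):
    if response_packet is None:
        continue
    print(response_packet, end="")
# prints: Streaming answer.
```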
@@ -64,6 +64,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
       - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - QDRANT_HOST=vector_db
       - TYPESENSE_HOST=search_engine
@@ -21,6 +21,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
       - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
       - LOG_LEVEL=${LOG_LEVEL:-info}
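With the `${VAR:-}` defaults above, both compose files pass an empty string into the container when the variable is unset on the host, and the `or (512 * 5)` fallback in app_configs.py then applies. To override the budget at deploy time, something like `NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=4096 docker compose up -d` on the host should work; the exact invocation depends on the deployment setup.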