Feed in docs till we reach a token limit (#401)

Chris Weaver 2023-09-05 15:20:42 -07:00 committed by GitHub
parent 58b75122f1
commit b06e53a51e
7 changed files with 96 additions and 21 deletions

View File

@@ -138,7 +138,12 @@ WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
#####
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
NUM_GENERATIVE_AI_INPUT_DOCS = 5
# We feed in document chunks until we reach this token limit.
# The default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
# may be smaller, which can result in more total chunks being passed in
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds
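
Usage note (not part of the diff): the new setting follows the same os.environ.get(...) or default pattern as the other configs, so an unset or empty-string value falls back to 512 * 5 = 2560 tokens. A minimal sketch, assuming the variable is set in the environment:

import os

os.environ["NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL"] = "1024"
limit = int(os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5))
print(limit)  # 1024; unset or "" falls back to 2560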

View File

@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
@@ -11,6 +11,7 @@ from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
@@ -107,18 +108,19 @@ def answer_qa_query(
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
if chunk_offset >= len(filtered_ranked_chunks):
raise ValueError("Chunks offset too large, should not retry this many times")
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
)
error_msg = None
try:
answer, quotes = qa_model.answer_question(
query,
filtered_ranked_chunks[
chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
],
)
answer, quotes = qa_model.answer_question(query, usable_chunks)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
answer, quotes = None, None

View File

@@ -10,6 +10,7 @@ from typing import Tuple
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_API_KEY
@@ -21,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
@@ -254,3 +256,48 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
"""Mock streaming by generating the passed in model output, character by character"""
for token in model_out:
yield token
def _get_usable_chunks(
chunks: list[InferenceChunk], token_limit: int
) -> list[InferenceChunk]:
total_token_count = 0
usable_chunks = []
for chunk in chunks:
chunk_token_count = check_number_of_tokens(chunk.content)
if total_token_count + chunk_token_count > token_limit:
break
total_token_count += chunk_token_count
usable_chunks.append(chunk)
# try to return at least one chunk if possible. This chunk will
# get truncated later on in the pipeline. This would only occur if
# the first chunk is larger than the token limit (usually due to character
# count -> token count mismatches caused by special characters / non-ASCII
# languages)
if not usable_chunks and chunks:
usable_chunks = [chunks[0]]
return usable_chunks
def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
usable_chunks: list[InferenceChunk] = []
for _ in range(max(offset + 1, 1)):  # go through this process at least once
if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
raise ValueError(
"Chunks offset too large, should not retry this many times"
)
usable_chunks = _get_usable_chunks(
chunks=chunks[offset_into_chunks:], token_limit=token_limit
)
offset_into_chunks += len(usable_chunks)
return usable_chunks
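
For intuition, a standalone sketch of the greedy packing behavior above (not part of the diff): Chunk and the rough 4-characters-per-token estimate are stand-ins for InferenceChunk and the tiktoken-based count, used only to keep the example self-contained.

from dataclasses import dataclass


@dataclass
class Chunk:
    content: str


def pack_until_limit(chunks: list[Chunk], token_limit: int) -> list[Chunk]:
    # greedily add chunks until the token budget would be exceeded
    packed: list[Chunk] = []
    total = 0
    for chunk in chunks:
        tokens = max(1, len(chunk.content) // 4)  # rough estimate, not tiktoken
        if total + tokens > token_limit:
            break
        total += tokens
        packed.append(chunk)
    # fall back to a single (later-truncated) chunk if even the first is too large
    return packed or chunks[:1]


docs = [Chunk("a" * 4000), Chunk("b" * 4000), Chunk("c" * 4000)]
print(len(pack_until_limit(docs, token_limit=2560)))  # 2; the third chunk would exceed the budget

With the real helpers, passing offset=1 on a retry skips past the chunks consumed by the first attempt before packing again.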

View File

@@ -1,5 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import PromptValue
@@ -69,3 +71,17 @@ def convert_input(lm_input: LanguageModelInput) -> str:
def should_be_verbose() -> bool:
return LOG_LEVEL == "debug"
def check_number_of_tokens(
text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
"""Get's the number of tokens in the provided text, using the provided encoding
function. If none is provided, default to the tiktoken encoder used by GPT-3.5
and GPT-4.
"""
if encode_fn is None:
encode_fn = tiktoken.get_encoding("cl100k_base").encode
return len(encode_fn(text))
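
A quick usage sketch of the default path above (not part of the diff); this mirrors what check_number_of_tokens does when no encode_fn is passed, and the exact count depends on the cl100k_base vocabulary.

import tiktoken

sample = "Feed in document chunks until the token limit is reached."
print(len(tiktoken.get_encoding("cl100k_base").encode(sample)))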

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.db.engine import get_session
@@ -22,6 +22,7 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@@ -247,17 +248,19 @@ def stream_direct_qa(
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
if chunk_offset >= len(filtered_ranked_chunks):
raise ValueError(
"Chunks offset too large, should not retry this many times"
)
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
)
try:
for response_packet in qa_model.answer_question_stream(
query,
filtered_ranked_chunks[
chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
],
query, usable_chunks
):
if response_packet is None:
continue

View File

@@ -64,6 +64,7 @@ services:
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- POSTGRES_HOST=relational_db
- QDRANT_HOST=vector_db
- TYPESENSE_HOST=search_engine

View File

@@ -21,6 +21,7 @@ services:
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
- GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- LOG_LEVEL=${LOG_LEVEL:-info}