Feed in docs till we reach a token limit (#401)

Chris Weaver 2023-09-05 15:20:42 -07:00 committed by GitHub
parent 58b75122f1
commit b06e53a51e
7 changed files with 96 additions and 21 deletions


@@ -138,7 +138,12 @@ WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
 #####
 NUM_RETURNED_HITS = 50
 NUM_RERANKED_RESULTS = 15
-NUM_GENERATIVE_AI_INPUT_DOCS = 5
+# We feed in document chunks until we reach this token limit.
+# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
+# may be smaller which could result in passing in more total chunks
+NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
+    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
+)
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10")  # 10 seconds
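For context, the new default works out to 512 tokens per full-size chunk times 5 chunks, roughly the same amount of text the old five-document setting fed to the model. Below is a minimal standalone sketch (plain Python, no danswer imports) of how the override resolves; note that the `${VAR:-}` defaults in the compose files at the end of this diff pass an empty string, which the `or` fallback treats the same as unset:

import os

def document_token_budget() -> int:
    # Mirrors the config pattern above: an unset *or* empty env var falls back to
    # 512 tokens/chunk * 5 chunks = 2560 tokens.
    return int(
        os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
    )

os.environ["NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL"] = ""
print(document_token_budget())  # 2560 (empty string is falsy)

os.environ["NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL"] = "4000"
print(document_token_budget())  # 4000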


@@ -2,7 +2,7 @@ from sqlalchemy.orm import Session
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.datastores.document_index import get_default_document_index
@@ -11,6 +11,7 @@ from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.llm_utils import get_default_qa_model
+from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -107,18 +108,19 @@ def answer_qa_query(
         chunk for chunk in ranked_chunks if chunk.metadata.get(IGNORE_FOR_QA)
     ]

-    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-    if chunk_offset >= len(filtered_ranked_chunks):
-        raise ValueError("Chunks offset too large, should not retry this many times")
+    # get all chunks that fit into the token limit
+    usable_chunks = get_usable_chunks(
+        chunks=filtered_ranked_chunks,
+        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+        offset=offset_count,
+    )
+    logger.debug(
+        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+    )

     error_msg = None
     try:
-        answer, quotes = qa_model.answer_question(
-            query,
-            filtered_ranked_chunks[
-                chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
-            ],
-        )
+        answer, quotes = qa_model.answer_question(query, usable_chunks)
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
         answer, quotes = None, None
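A short usage sketch of the new selection helper from a caller's point of view; `get_usable_chunks` and the config constant come from this commit, while the wrapper name and the mapping of retry attempts to offsets are illustrative assumptions:

from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.direct_qa.qa_utils import get_usable_chunks

def chunks_for_attempt(filtered_ranked_chunks, attempt_number=0):
    # attempt_number plays the same role as offset_count above: attempt 0 feeds the
    # top-ranked chunks that fit the token budget, attempt 1 feeds the next batch, etc.
    return get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
        offset=attempt_number,
    )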


@@ -10,6 +10,7 @@ from typing import Tuple
 import regex

 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.configs.model_configs import GEN_AI_API_KEY
@@ -21,6 +22,7 @@ from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.llm.utils import check_number_of_tokens
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import shared_precompare_cleanup
@@ -254,3 +256,48 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
     """Mock streaming by generating the passed in model output, character by character"""
     for token in model_out:
         yield token
+
+
+def _get_usable_chunks(
+    chunks: list[InferenceChunk], token_limit: int
+) -> list[InferenceChunk]:
+    total_token_count = 0
+    usable_chunks = []
+    for chunk in chunks:
+        chunk_token_count = check_number_of_tokens(chunk.content)
+        if total_token_count + chunk_token_count > token_limit:
+            break
+
+        total_token_count += chunk_token_count
+        usable_chunks.append(chunk)
+
+    # try and return at least one chunk if possible. This chunk will
+    # get truncated later on in the pipeline. This would only occur if
+    # the first chunk is larger than the token limit (usually due to character
+    # count -> token count mismatches caused by special characters / non-ascii
+    # languages)
+    if not usable_chunks and chunks:
+        usable_chunks = [chunks[0]]
+
+    return usable_chunks
+
+
+def get_usable_chunks(
+    chunks: list[InferenceChunk],
+    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+    offset: int = 0,
+) -> list[InferenceChunk]:
+    offset_into_chunks = 0
+    usable_chunks: list[InferenceChunk] = []
+    for _ in range(max(offset + 1, 1)):  # go through this process at least once
+        if offset_into_chunks >= len(chunks) and offset_into_chunks > 0:
+            raise ValueError(
+                "Chunks offset too large, should not retry this many times"
+            )
+
+        usable_chunks = _get_usable_chunks(
+            chunks=chunks[offset_into_chunks:], token_limit=token_limit
+        )
+        offset_into_chunks += len(usable_chunks)
+
+    return usable_chunks
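To see the packing behavior and the single-oversized-chunk fallback in isolation, here is a self-contained toy sketch; the chunk class, the whitespace token estimate, and every name in it are stand-ins for illustration, not part of this commit:

from dataclasses import dataclass

@dataclass
class FakeChunk:
    content: str

def pack_until_limit(chunks: list[FakeChunk], token_limit: int) -> list[FakeChunk]:
    total = 0
    packed: list[FakeChunk] = []
    for chunk in chunks:
        count = len(chunk.content.split())  # crude stand-in for check_number_of_tokens
        if total + count > token_limit:
            break
        total += count
        packed.append(chunk)
    # fall back to the single top chunk if even it exceeds the limit
    if not packed and chunks:
        packed = [chunks[0]]
    return packed

chunks = [FakeChunk("alpha " * 300), FakeChunk("beta " * 300), FakeChunk("gamma " * 300)]
print(len(pack_until_limit(chunks, token_limit=650)))  # 2 -- the third chunk no longer fits
print(len(pack_until_limit(chunks, token_limit=100)))  # 1 -- oversized first chunk returned anyway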


@@ -1,5 +1,7 @@
+from collections.abc import Callable
 from collections.abc import Iterator

+import tiktoken
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
 from langchain.schema import PromptValue
@@ -69,3 +71,17 @@ def convert_input(lm_input: LanguageModelInput) -> str:

 def should_be_verbose() -> bool:
     return LOG_LEVEL == "debug"
+
+
+def check_number_of_tokens(
+    text: str, encode_fn: Callable[[str], list] | None = None
+) -> int:
+    """Gets the number of tokens in the provided text, using the provided encoding
+    function. If none is provided, default to the tiktoken encoder used by GPT-3.5
+    and GPT-4.
+    """
+    if encode_fn is None:
+        encode_fn = tiktoken.get_encoding("cl100k_base").encode
+
+    return len(encode_fn(text))
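Usage is straightforward; a short sketch, assuming the danswer backend and tiktoken are importable in your environment:

import tiktoken

from danswer.llm.utils import check_number_of_tokens

# Default path: counts tokens with the cl100k_base encoding (GPT-3.5 / GPT-4).
print(check_number_of_tokens("Feed in docs till we reach a token limit"))

# Custom path: any callable that maps text to a list of tokens works.
enc = tiktoken.get_encoding("cl100k_base")
print(check_number_of_tokens("same text, explicit encoder", encode_fn=enc.encode))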


@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.datastores.document_index import get_default_document_index
 from danswer.db.engine import get_session
@@ -22,6 +22,7 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.llm_utils import get_default_qa_model
+from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -247,17 +248,19 @@ def stream_direct_qa(
             chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
         ]

-        chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-        if chunk_offset >= len(filtered_ranked_chunks):
-            raise ValueError(
-                "Chunks offset too large, should not retry this many times"
-            )
+        # get all chunks that fit into the token limit
+        usable_chunks = get_usable_chunks(
+            chunks=filtered_ranked_chunks,
+            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+            offset=offset_count,
+        )
+        logger.debug(
+            f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+        )

         try:
             for response_packet in qa_model.answer_question_stream(
-                query,
-                filtered_ranked_chunks[
-                    chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
-                ],
+                query, usable_chunks
             ):
                 if response_packet is None:
                     continue
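For completeness, a toy sketch of the consumption pattern in the loop above: the stream may yield None packets, which the endpoint skips rather than forwards. The generator here is a stand-in, not the real qa_model.answer_question_stream:

from collections.abc import Iterator

def fake_answer_question_stream(query: str, chunks: list[str]) -> Iterator[str | None]:
    # stand-in for qa_model.answer_question_stream
    yield "The "
    yield None  # e.g. a packet that should not be forwarded to the client
    yield "answer."

packets = [p for p in fake_answer_question_stream("q", ["chunk"]) if p is not None]
print("".join(packets))  # "The answer."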


@@ -64,6 +64,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
       - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - QDRANT_HOST=vector_db
       - TYPESENSE_HOST=search_engine


@@ -21,6 +21,7 @@ services:
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_ENDPOINT=${GEN_AI_ENDPOINT:-}
       - GEN_AI_HOST_TYPE=${GEN_AI_HOST_TYPE:-}
+      - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
       - LOG_LEVEL=${LOG_LEVEL:-info}