Default LLM Update (#1042)

Yuhong Sun 2024-02-05 01:25:51 -08:00 committed by GitHub
parent b3b88f05d3
commit 6768c24723
23 changed files with 201 additions and 93 deletions

@@ -14,12 +14,9 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
@@ -28,7 +25,7 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_llm_max_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -239,7 +236,7 @@ def _get_usable_chunks(
def get_usable_chunks(
chunks: list[InferenceChunk],
token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
token_limit: int,
offset: int = 0,
) -> list[InferenceChunk]:
offset_into_chunks = 0
@@ -261,7 +258,7 @@ def get_usable_chunks(
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: float | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
token_limit: int | None,
batch_offset: int = 0,
) -> list[int]:
"""
@@ -363,10 +360,10 @@ def create_chat_chain(
def combine_message_chain(
messages: list[ChatMessage],
msg_limit: int | None = 10,
token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
token_limit: int,
msg_limit: int | None = None,
) -> str:
"""Used for secondary LLM flows that require the chat history"""
"""Used for secondary LLM flows that require the chat history,"""
message_strs: list[str] = []
total_token_count = 0
@@ -376,10 +373,7 @@ def combine_message_chain(
for message in reversed(messages):
message_token_count = message.token_count
if (
token_limit is not None
and total_token_count + message_token_count > token_limit
):
if total_token_count + message_token_count > token_limit:
break
role = message.message_type.value.upper()
@@ -557,7 +551,9 @@ _MISC_BUFFER = 40
def compute_max_document_tokens(
persona: Persona, actual_user_input: str | None = None
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
"""Estimates the number of tokens available for context documents. Formula is roughly:
@@ -575,8 +571,13 @@ def compute_max_document_tokens(
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
model_full_context_window = get_llm_max_tokens(llm_name) or 4096
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(llm_name)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
raise RuntimeError("Persona has no prompts - this should never happen")
@@ -586,13 +587,7 @@ def compute_max_document_tokens(
else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
)
return (
model_full_context_window
- GEN_AI_MAX_OUTPUT_TOKENS
- prompt_tokens
- user_input_tokens
- _MISC_BUFFER
)
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
def compute_max_llm_input_tokens(persona: Persona) -> int:
@@ -601,5 +596,5 @@ def compute_max_llm_input_tokens(persona: Persona) -> int:
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
model_full_context_window = get_llm_max_tokens(llm_name) or 4096
return model_full_context_window - GEN_AI_MAX_OUTPUT_TOKENS - _MISC_BUFFER
input_tokens = get_max_input_tokens(model_name=llm_name)
return input_tokens - _MISC_BUFFER
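To make the new budget arithmetic above concrete, here is a minimal sketch using only values visible in this commit (the 1024-token output reserve and the 40-token `_MISC_BUFFER`); the ~16k context window and the prompt/user token counts are illustrative assumptions.

```python
# Illustrative sketch of the new budgets; constants mirror values shown in this commit.
GEN_AI_MAX_OUTPUT_TOKENS = 1024  # reserved for the model's answer
_MISC_BUFFER = 40                # small safety margin (separators, metadata, etc.)


def max_input_tokens(model_context_window: int) -> int:
    # mirrors get_max_input_tokens: full context window minus the output reserve
    return model_context_window - GEN_AI_MAX_OUTPUT_TOKENS


def max_document_tokens(
    model_context_window: int, prompt_tokens: int, user_input_tokens: int
) -> int:
    # mirrors compute_max_document_tokens: what is left for context documents
    return (
        max_input_tokens(model_context_window)
        - prompt_tokens
        - user_input_tokens
        - _MISC_BUFFER
    )


# assumed ~16k-token model, 600-token prompt, 300-token user message
print(max_document_tokens(16_385, 600, 300))  # -> 14421
```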

@@ -3,7 +3,7 @@ from typing import cast
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.chat import get_prompt_by_name
@@ -42,7 +42,7 @@ def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_personas_from_yaml(
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)

@@ -13,9 +13,8 @@ personas:
- "Answer-Question"
# Default number of chunks to include as context, set to 0 to disable retrieval
# Remove the field to set to the system default number of chunks/tokens to pass to Gen AI
# If selecting documents, user can bypass this up until NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
# Each chunk is 512 tokens long
num_chunks: 5
num_chunks: 10
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
# if the chunk is useful or not towards the latest user query
# This feature can be overridden for all personas via the DISABLE_LLM_CHUNK_FILTER env variable
@@ -46,7 +45,7 @@ personas:
extrapolate any answers for you.
prompts:
- "Summarize"
num_chunks: 5
num_chunks: 10
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
@@ -58,7 +57,7 @@ personas:
The least creative default assistant that only provides quotes from the documents.
prompts:
- "Paraphrase"
num_chunks: 5
num_chunks: 10
llm_relevance_filter: true
llm_filter_extraction: true
recency_bias: "auto"
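For scale, a back-of-the-envelope sketch of what the new persona default asks for in document tokens; the 512-token chunk size comes from the comment in this file.

```python
CHUNK_SIZE = 512  # per the comment above: each chunk is 512 tokens
num_chunks = 10   # new default in this commit (previously 5)

requested_doc_tokens = num_chunks * CHUNK_SIZE
print(requested_doc_tokens)  # -> 5120; the chat flow later clamps this to the model's real budget
```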

@@ -22,10 +22,12 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import CHUNK_SIZE
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import CHUNK_SIZE
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -46,6 +48,7 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import tokenizer_trim_content
from danswer.llm.utils import translate_history_to_basemessages
from danswer.search.models import OptionalSearchSetting
@@ -156,8 +159,11 @@ def stream_chat_message(
user: User | None,
db_session: Session,
# Needed to translate persona num_chunks into tokens fed to the LLM
default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = CHUNK_SIZE,
# For flows with search, don't include as many chunks as possible since we need to leave space
# for the chat history. For smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
) -> Iterator[str]:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -260,6 +266,10 @@ def stream_chat_message(
query_message=final_msg, history=history_msgs, llm=llm
)
max_document_tokens = compute_max_document_tokens(
persona=persona, actual_user_input=message_text
)
rephrased_query = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
@@ -277,9 +287,6 @@ def stream_chat_message(
)
# truncate the last document if it exceeds the token limit
max_document_tokens = compute_max_document_tokens(
persona, actual_user_input=message_text
)
tokens_per_doc = [
len(
llm_tokenizer_encode_func(
@@ -431,10 +438,26 @@ def stream_chat_message(
if persona.num_chunks is not None
else default_num_chunks
)
llm_name = GEN_AI_MODEL_VERSION
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
llm_max_input_tokens = get_max_input_tokens(llm_name)
llm_token_based_chunk_lim = max_document_percentage * llm_max_input_tokens
chunk_token_limit = int(
min(
num_llm_chunks * default_chunk_size,
max_document_tokens,
llm_token_based_chunk_lim,
)
)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=num_llm_chunks * default_chunk_size,
token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
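A hedged sketch of the new clamp applied above when computing chunk_token_limit; every number below is a placeholder standing in for the persona config and the model lookup.

```python
# Placeholder numbers for illustration only.
default_chunk_size = 512
num_llm_chunks = 10                  # persona.num_chunks or MAX_CHUNKS_FED_TO_CHAT
max_document_tokens = 14_421         # compute_max_document_tokens(persona, user input)
max_document_percentage = 0.5        # CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
llm_max_input_tokens = 15_361        # get_max_input_tokens(llm_name), assumed ~16k model

chunk_token_limit = int(
    min(
        num_llm_chunks * default_chunk_size,             # 5120: persona / chunk-count cap
        max_document_tokens,                             # 14421: space left after prompts + user input
        max_document_percentage * llm_max_input_tokens,  # 7680.5: keep room for chat history
    )
)
print(chunk_token_limit)  # -> 5120; on small-context models the percentage term wins instead
```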

@@ -1,28 +1,18 @@
import os
from danswer.configs.model_configs import CHUNK_SIZE
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
# significantly smaller which could result in passing in more total chunks.
# There is also a slight bit of overhead, not accounted for here such as separator patterns
# between the docs, metadata for the docs, etc.
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
# model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (CHUNK_SIZE * 5)
)
DEFAULT_NUM_CHUNKS_FED_TO_CHAT: float = (
float(NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL) / CHUNK_SIZE
)
NUM_DOCUMENT_TOKENS_FED_TO_CHAT = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_CHAT") or (CHUNK_SIZE * 3)
)
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# For Chat, need to keep enough space for history and other prompt pieces
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# For selecting a different LLM question-answering prompt format
# Valid values: default, cot, weak
QA_PROMPT_OVERRIDE = os.environ.get("QA_PROMPT_OVERRIDE") or None
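The 512 * 3 / 3072 expression above evaluates to 0.5; a quick worked example of what that buys on a small model (the 4096-token window is illustrative):

```python
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072  # = 0.5, i.e. half the input budget for documents
GEN_AI_MAX_OUTPUT_TOKENS = 1024                # output reserve from model_configs

context_window = 4096                          # illustrative small model
input_budget = context_window - GEN_AI_MAX_OUTPUT_TOKENS       # 3072
doc_budget = int(CHAT_TARGET_CHUNK_PERCENTAGE * input_budget)  # 1536, roughly three 512-token chunks
print(input_budget, doc_budget)
```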
@@ -60,7 +50,7 @@ if os.environ.get("EDIT_KEYWORD_QUERY"):
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.66)))
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
# Weighting factor between Title and Content of documents during search, 1 for completely
# Title based. Default heavily favors Content because Title is also included at the top of
# Content. This is to avoid cases where the Content is very relevant but it may not be clear

@@ -7,6 +7,8 @@ DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")

@@ -78,7 +78,7 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
# Set GEN_AI_MODEL_PROVIDER to "gpt4all" to use gpt4all models running locally
GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
# If using Azure, it's the engine name, for example: Danswer
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo-0125"
# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
# as powerful a model as, say, GPT-4, so we can use an alternative that is faster and cheaper
FAST_GEN_AI_MODEL_VERSION = (
@@ -96,14 +96,15 @@ GEN_AI_API_ENDPOINT = os.environ.get("GEN_AI_API_ENDPOINT") or None
GEN_AI_API_VERSION = os.environ.get("GEN_AI_API_VERSION") or None
# LiteLLM custom_llm_provider
GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
# If the max tokens can't be found from the name, use this as the backup
# This happens if the user is configuring a different LLM to use
GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 4096)
# Set this to be enough for an answer + quotes. Also used for Chat
GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
# This next restriction is only used for chat ATM, used to expire old messages as needed
GEN_AI_MAX_INPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_INPUT_TOKENS") or 3000)
# History for secondary LLM flows, not primary chat flow, generally we don't need to
# include as much as possible as this just bumps up the cost unnecessarily
GEN_AI_HISTORY_CUTOFF = int(0.5 * GEN_AI_MAX_INPUT_TOKENS)
# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
# as this drives up the cost unnecessarily
GEN_AI_HISTORY_CUTOFF = 3000
# This is used when computing how much context space is available for documents
# ahead of time in order to let the user know if they can "select" more documents
# It represents a maximum "expected" number of input tokens from the latest user
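Taken together, the defaults in this hunk imply the following fallback arithmetic when the model name cannot be resolved (a sketch of the numbers only):

```python
GEN_AI_MAX_TOKENS = 4096         # backup context window for unrecognized models
GEN_AI_MAX_OUTPUT_TOKENS = 1024  # reserved for the answer (and quotes)
GEN_AI_HISTORY_CUTOFF = 3000     # flat history cap for secondary LLM flows

backup_input_budget = GEN_AI_MAX_TOKENS - GEN_AI_MAX_OUTPUT_TOKENS
print(backup_input_budget)       # -> 3072; the flat 3000-token history cutoff fits within even this backup
```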

@@ -11,14 +11,17 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -33,6 +36,8 @@ from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
@@ -98,6 +103,7 @@ def handle_message(
disable_auto_detect_filters: bool = DISABLE_DANSWER_BOT_FILTER_DETECT,
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -215,11 +221,36 @@ def handle_message(
slack_usage_report(action=action, sender_id=sender_id, client=client)
max_document_tokens: int | None = None
max_history_tokens: int | None = None
if len(new_message_request.messages) > 1:
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(GEN_AI_MODEL_VERSION)
max_history_tokens = int(input_tokens * thread_context_percent)
remaining_tokens = input_tokens - max_history_tokens
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
)
else:
max_document_tokens = (
remaining_tokens
- 512 # Needs to be more than any of the QA prompts
- check_number_of_tokens(query_text)
)
with Session(get_sqlalchemy_engine()) as db_session:
# This also handles creating the query event in postgres
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
enable_reflexion=reflexion,
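A minimal sketch of the thread/document split introduced above, with placeholder token counts; the percentage is the 512 * 2 / 3072 value from the DanswerBot config change, and the flat 512 reserve mirrors the comment in the hunk.

```python
# Placeholder numbers; real values come from get_max_input_tokens and the tokenizer.
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072  # ~1/3 of the input budget for thread history

input_tokens = 15_361                                 # assumed get_max_input_tokens(GEN_AI_MODEL_VERSION)
max_history_tokens = int(input_tokens * DANSWER_BOT_TARGET_CHUNK_PERCENTAGE)  # 5120
remaining_tokens = input_tokens - max_history_tokens                          # 10241

query_tokens = 25                                     # assumed check_number_of_tokens(query_text)
# Persona-less fallback: keep a flat 512-token reserve (larger than any QA prompt) plus the query itself
max_document_tokens = remaining_tokens - 512 - query_tokens                   # 9704
print(max_history_tokens, max_document_tokens)
```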

@@ -723,7 +723,6 @@ class Persona(Base):
Enum(SearchType), default=SearchType.HYBRID
)
# Number of chunks to pass to the LLM for generation.
# If unspecified, uses the default DEFAULT_NUM_CHUNKS_FED_TO_CHAT set in the env variable
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
# Pass every chunk through LLM for evaluation, fairly expensive
# Can be turned off globally by admin, in which case, this setting is ignored

@@ -3,7 +3,7 @@ from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.db.chat import upsert_persona
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
from danswer.db.document_set import get_document_sets_by_ids
@@ -35,7 +35,7 @@ def create_slack_bot_persona(
channel_names: list[str],
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> Persona:
"""NOTE: does not commit changes"""
document_sets = list(

@@ -22,6 +22,9 @@ from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.models import ChatMessage
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -59,7 +62,6 @@ def get_default_llm_token_encode() -> Callable[[str], Any]:
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokenizer = get_default_llm_tokenizer()
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
@@ -201,9 +203,24 @@ def test_llm(llm: LLM) -> bool:
return False
def get_llm_max_tokens(model_name: str) -> int | None:
def get_llm_max_tokens(model_name: str | None = GEN_AI_MODEL_VERSION) -> int:
"""Best effort attempt to get the max tokens for the LLM"""
if not model_name:
return GEN_AI_MAX_TOKENS
try:
return get_max_tokens(model_name)
except Exception:
return None
return GEN_AI_MAX_TOKENS
def get_max_input_tokens(
model_name: str | None = GEN_AI_MODEL_VERSION,
output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
) -> int:
input_toks = get_llm_max_tokens(model_name) - output_tokens
if input_toks <= 0:
raise RuntimeError("No tokens for input for the LLM given settings")
return input_toks
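Rough usage sketch of the two helpers above; the per-model figures depend on litellm's model map, and an unrecognized model name falls back to the GEN_AI_MAX_TOKENS backup (4096 by default).

```python
# Assumes the danswer package is importable; printed values are indicative, not exact.
from danswer.llm.utils import get_llm_max_tokens, get_max_input_tokens

# Known model: litellm supplies the context window, then the 1024-token output reserve is subtracted
print(get_llm_max_tokens("gpt-3.5-turbo-0125"))    # full context window per litellm's model map
print(get_max_input_tokens("gpt-3.5-turbo-0125"))  # same figure minus GEN_AI_MAX_OUTPUT_TOKENS

# Unknown / self-hosted model name: falls back to GEN_AI_MAX_TOKENS
print(get_max_input_tokens("my-custom-llm"))       # 4096 - 1024 = 3072 with default settings
```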

@@ -5,6 +5,7 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import compute_max_document_tokens
from danswer.chat.chat_utils import get_chunks_for_qa
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContext
@@ -14,7 +15,7 @@ from danswer.chat.models import LLMMetricsContainer
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DEFAULT_NUM_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import CHUNK_SIZE
@@ -54,9 +55,14 @@ logger = setup_logger()
def stream_answer_objects(
query_req: DirectQARequest,
user: User | None,
# These need to be passed in because in the Web UI one-shot flow,
# we can include many more documents as there is no history.
# For the Slack flow, we need to save more tokens for the thread context
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
# Needed to translate persona num_chunks into tokens fed to the LLM
default_num_chunks: float = DEFAULT_NUM_CHUNKS_FED_TO_CHAT,
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
default_chunk_size: int = CHUNK_SIZE,
timeout: int = QA_TIMEOUT,
bypass_acl: bool = False,
@@ -106,7 +112,9 @@ def stream_answer_objects(
chat_session_id=chat_session.id, db_session=db_session
)
history_str = combine_message_thread(history)
history_str = combine_message_thread(
messages=history, max_tokens=max_history_tokens
)
rephrased_query = thread_based_query_rephrase(
user_query=query_msg.message,
@@ -174,10 +182,20 @@ def stream_answer_objects(
if chat_session.persona.num_chunks is not None
else default_num_chunks
)
chunk_token_limit = int(num_llm_chunks * default_chunk_size)
if max_document_tokens:
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
else:
max_document_tokens = compute_max_document_tokens(
persona=chat_session.persona, actual_user_input=query_msg.message
)
chunk_token_limit = min(chunk_token_limit, max_document_tokens)
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=num_llm_chunks * default_chunk_size,
token_limit=chunk_token_limit,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
@@ -288,10 +306,16 @@ def stream_answer_objects(
def stream_search_answer(
query_req: DirectQARequest,
user: User | None,
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
) -> Iterator[str]:
objects = stream_answer_objects(
query_req=query_req, user=user, db_session=db_session
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.dict())
@@ -300,6 +324,8 @@ def stream_search_answer(
def get_search_answer(
query_req: DirectQARequest,
user: User | None,
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
answer_generation_timeout: int = QA_TIMEOUT,
enable_reflexion: bool = False,
@@ -315,6 +341,8 @@ def get_search_answer(
results = stream_answer_objects(
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
bypass_acl=bypass_acl,
timeout=answer_generation_timeout,
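To summarize the new plumbing above: Slack passes an explicit document cap (plus a history budget), while the web one-shot endpoint passes None and the persona-based estimate kicks in. A small sketch of that selection with placeholder numbers:

```python
# Placeholder sketch of how the one-shot flow picks its chunk token budget.
def chunk_budget(
    num_llm_chunks: float,
    chunk_size: int,
    caller_cap: int | None,  # max_document_tokens from the caller (Slack) or None (web UI)
    computed_cap: int,       # compute_max_document_tokens(persona, user input) fallback
) -> int:
    limit = int(num_llm_chunks * chunk_size)
    cap = caller_cap if caller_cap else computed_cap
    return min(limit, cap)


print(chunk_budget(10, 512, caller_cap=3_000, computed_cap=14_421))  # Slack thread, tight budget -> 3000
print(chunk_budget(10, 512, caller_cap=None, computed_cap=14_421))   # web one-shot flow -> 5120
```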

@@ -15,7 +15,6 @@ from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import ThreadMessage
@@ -279,10 +278,13 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
def combine_message_thread(
messages: list[ThreadMessage],
token_limit: int | None = GEN_AI_HISTORY_CUTOFF,
max_tokens: int | None,
llm_tokenizer: Callable | None = None,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
return ""
message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
@@ -304,8 +306,8 @@ def combine_message_thread(
message_token_count = len(llm_tokenizer(msg_str))
if (
token_limit is not None
and total_token_count + message_token_count > token_limit
max_tokens is not None
and total_token_count + message_token_count > max_tokens
):
break
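The hunk above keeps the same newest-first truncation loop but now takes the budget as an argument; a self-contained sketch of the pattern, with a whitespace word count standing in for the real tokenizer:

```python
def combine_thread(messages: list[str], max_tokens: int | None) -> str:
    # len(s.split()) is a stand-in for the LLM tokenizer used in combine_message_thread
    def count_tokens(s: str) -> int:
        return len(s.split())

    kept: list[str] = []
    total = 0
    for msg in reversed(messages):  # walk newest-first
        cost = count_tokens(msg)
        if max_tokens is not None and total + cost > max_tokens:
            break                   # budget exhausted: drop this message and everything older
        kept.append(msg)
        total += cost
    return "\n\n".join(reversed(kept))  # restore chronological order


print(combine_thread(["first question", "a long detailed answer here", "short follow up"], max_tokens=6))
# -> "short follow up" (the older, longer messages no longer fit)
```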

@@ -1,4 +1,5 @@
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -31,7 +32,9 @@ def get_renamed_conversation_name(
# clear thing we can do
return full_history[0].message
history_str = combine_message_chain(full_history)
history_str = combine_message_chain(
messages=full_history, token_limit=GEN_AI_HISTORY_CUTOFF
)
prompt_msgs = get_chat_rename_messages(history_str)

@@ -4,6 +4,7 @@ from langchain.schema import SystemMessage
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -77,7 +78,9 @@ def check_if_need_search(
# as just a search engine
return True
history_str = combine_message_chain(history)
history_str = combine_message_chain(
messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
)
prompt_msgs = _get_search_messages(
question=query_message.message, history_str=history_str

@@ -2,6 +2,7 @@ from collections.abc import Callable
from typing import cast
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
@@ -119,7 +120,9 @@ def history_based_query_rephrase(
if count_punctuation(user_query) >= punctuation_heuristic:
return user_query
history_str = combine_message_chain(history)
history_str = combine_message_chain(
messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
)
prompt_msgs = get_contextual_rephrase_messages(
question=user_query, history_str=history_str

@@ -153,6 +153,10 @@ def get_answer_with_quote(
query = query_request.messages[0].message
logger.info(f"Received query for one shot answer with quotes: {query}")
packets = stream_search_answer(
query_req=query_request, user=user, db_session=db_session
query_req=query_request,
user=user,
max_document_tokens=None,
max_history_tokens=0,
db_session=db_session,
)
return StreamingResponse(packets, media_type="application/json")

@@ -10,6 +10,7 @@ VALID_MODEL_LIST = [
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-32k-0613",
"gpt-3.5-turbo-0125",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",

@@ -108,6 +108,8 @@ def get_answer_for_question(
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=None,
max_history_tokens=None,
db_session=db_session,
answer_generation_timeout=100,
enable_reflexion=False,

@@ -41,6 +41,8 @@ def get_answer_for_question(query: str, db_session: Session) -> OneShotQARespons
answer = get_search_answer(
query_req=new_message_request,
user=None,
max_document_tokens=None,
max_history_tokens=None,
db_session=db_session,
answer_generation_timeout=100,
enable_reflexion=False,

@@ -30,14 +30,15 @@ services:
- EMAIL_FROM=${EMAIL_FROM:-}
# Gen AI Settings
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
@@ -93,14 +94,15 @@ services:
environment:
# Gen AI Settings (Needed by DanswerBot)
- GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
- GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-gpt-3.5-turbo-0125}
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
- GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
- GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
- GEN_AI_LLM_PROVIDER_TYPE=${GEN_AI_LLM_PROVIDER_TYPE:-}
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}

@@ -14,14 +14,15 @@ data:
EMAIL_FROM: "" # e.g. 'your-email@company.com'; if unset, SMTP_USER is used instead
# Gen AI Settings
GEN_AI_MODEL_PROVIDER: "openai"
GEN_AI_MODEL_VERSION: "gpt-3.5-turbo" # Use GPT-4 if you have it
FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo"
GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125" # Use GPT-4 if you have it
FAST_GEN_AI_MODEL_VERSION: "gpt-3.5-turbo-0125"
GEN_AI_API_KEY: ""
GEN_AI_API_ENDPOINT: ""
GEN_AI_API_VERSION: ""
GEN_AI_LLM_PROVIDER_TYPE: ""
GEN_AI_MAX_TOKENS: ""
QA_TIMEOUT: "60"
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL: ""
MAX_CHUNKS_FED_TO_CHAT: ""
DISABLE_LLM_FILTER_EXTRACTION: ""
DISABLE_LLM_CHUNK_FILTER: ""
DISABLE_LLM_CHOOSE_SEARCH: ""

@@ -86,7 +86,7 @@ export function PersonaEditor({
description: existingPersona?.description ?? "",
system_prompt: existingPrompt?.system_prompt ?? "",
task_prompt: existingPrompt?.task_prompt ?? "",
disable_retrieval: (existingPersona?.num_chunks ?? 5) === 0,
disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0,
document_set_ids:
existingPersona?.document_sets?.map(
(documentSet) => documentSet.id
@@ -148,7 +148,7 @@ export function PersonaEditor({
// to tell the backend to not fetch any documents
const numChunks = values.disable_retrieval
? 0
: values.num_chunks || 5;
: values.num_chunks || 10;
let promptResponse;
let personaResponse;
@@ -414,7 +414,7 @@ export function PersonaEditor({
input length limit.
<br />
<br />
If unspecified, will use 5 chunks.
If unspecified, will use 10 chunks.
</div>
}
onChange={(e) => {