Apply passthrough headers for chat renaming

This commit is contained in:
Weves
2024-06-10 17:53:47 -07:00
committed by Chris Weaver
parent 36afa9370f
commit cc0320b50a
2 changed files with 20 additions and 22 deletions

View File

@@ -1,8 +1,6 @@
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
@@ -14,30 +12,18 @@ logger = setup_logger()
def get_renamed_conversation_name(
full_history: list[ChatMessage],
llm: LLM | None = None,
llm: LLM,
) -> str:
def get_chat_rename_messages(history_str: str) -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": CHAT_NAMING.format(chat_history=history_str),
},
]
return messages
if llm is None:
try:
llm = get_default_llm()
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
# clear thing we can do
return full_history[0].message
history_str = combine_message_chain(
messages=full_history, token_limit=GEN_AI_HISTORY_CUTOFF
)
prompt_msgs = get_chat_rename_messages(history_str)
prompt_msgs = [
{
"role": "user",
"content": CHAT_NAMING.format(chat_history=history_str),
},
]
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs)
new_name_raw = message_to_string(llm.invoke(filled_llm_prompt))

View File

@@ -42,6 +42,8 @@ from danswer.file_store.models import FileDescriptor
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
@@ -170,6 +172,7 @@ def create_new_chat_session(
@router.put("/rename-chat-session")
def rename_chat_session(
rename_req: ChatRenameRequest,
request: Request,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RenameChatSessionResponse:
@@ -193,7 +196,16 @@ def rename_chat_session(
)
full_history = history_msgs + [final_msg]
new_name = get_renamed_conversation_name(full_history=full_history)
try:
llm = get_default_llm(
additional_headers=get_litellm_additional_request_headers(request.headers)
)
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
# clear thing we can do
return RenameChatSessionResponse(new_name=full_history[0].message)
new_name = get_renamed_conversation_name(full_history=full_history, llm=llm)
update_chat_session(
db_session=db_session,