URL-based chat seeding

Weves 2024-03-31 17:49:26 -07:00 committed by Chris Weaver
parent b8af1377ba
commit 32f55ddb8f
17 changed files with 308 additions and 89 deletions

View File

@ -34,7 +34,9 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_tokenizer
@ -343,8 +345,12 @@ def stream_chat_message_objects(
),
document_pruning_config=document_pruning_config,
),
prompt=final_msg.prompt,
persona=persona,
prompt_config=PromptConfig.from_model(
final_msg.prompt, prompt_override=new_msg_req.prompt_override
),
llm_config=LLMConfig.from_persona(
persona, llm_override=new_msg_req.llm_override
),
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
],

View File

@ -38,7 +38,9 @@ from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
@ -247,7 +249,7 @@ def handle_message(
query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens(
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,

View File

@ -10,11 +10,11 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt
from danswer.llm.answering.prompts.quotes_prompt import (
@ -51,8 +51,8 @@ class Answer:
question: str,
docs: list[LlmDoc],
answer_style_config: AnswerStyleConfig,
prompt: Prompt,
persona: Persona,
llm_config: LLMConfig,
prompt_config: PromptConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
doc_relevance_list: list[bool] | None = None,
message_history: list[PreviousMessage] | None = None,
@ -72,16 +72,17 @@ class Answer:
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.llm_config = llm_config
self.prompt_config = prompt_config
self.llm = get_default_llm(
gen_ai_model_version_override=persona.llm_model_version_override,
gen_ai_model_provider=self.llm_config.model_provider,
gen_ai_model_version_override=self.llm_config.model_version,
timeout=timeout,
temperature=self.llm_config.temperature,
)
self.llm_tokenizer = get_default_llm_tokenizer()
self.prompt = prompt
self.persona = persona
self.process_stream_fn = _get_stream_processor(docs, answer_style_config)
self._final_prompt: list[BaseMessage] | None = None
@ -99,7 +100,8 @@ class Answer:
self._pruned_docs = prune_documents(
docs=self.docs,
doc_relevance_list=self.doc_relevance_list,
persona=self.persona,
prompt_config=self.prompt_config,
llm_config=self.llm_config,
question=self.question,
document_pruning_config=self.answer_style_config.document_pruning_config,
)
@ -114,8 +116,8 @@ class Answer:
self._final_prompt = build_citations_prompt(
question=self.question,
message_history=self.message_history,
persona=self.persona,
prompt=self.prompt,
llm_config=self.llm_config,
prompt_config=self.prompt_config,
context_docs=self.pruned_docs,
all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
llm_tokenizer_encode_func=self.llm_tokenizer.encode,
@ -126,7 +128,7 @@ class Answer:
question=self.question,
context_docs=self.pruned_docs,
history_str=self.single_message_history or "",
prompt=self.prompt,
prompt=self.prompt_config,
)
return cast(list[BaseMessage], self._final_prompt)
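
The net effect of this file's changes is that Answer is now constructed from the two new config objects instead of raw Persona/Prompt rows. A minimal sketch of the new call shape, assuming the retrieval results and the persona/prompt rows already exist; the AnswerStyleConfig values below are illustrative, not taken from the diff:

from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import (
    AnswerStyleConfig,
    CitationConfig,
    DocumentPruningConfig,
    LLMConfig,
    PromptConfig,
)

answer = Answer(
    question="How do I rotate my API key?",
    docs=llm_docs,  # list[LlmDoc] from retrieval (assumed to exist)
    answer_style_config=AnswerStyleConfig(
        citation_config=CitationConfig(all_docs_useful=False),
        document_pruning_config=DocumentPruningConfig(),  # field defaults assumed
    ),
    llm_config=LLMConfig.from_persona(persona),     # persona: db Persona row (assumed)
    prompt_config=PromptConfig.from_model(prompt),  # prompt: db Prompt row (assumed)
)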

View File

@ -6,9 +6,10 @@ from danswer.chat.models import (
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import Persona
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
@ -28,14 +29,15 @@ class PruningError(Exception):
def _compute_limit(
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
max_chunks: int | None,
max_window_percentage: float | None,
max_tokens: int | None,
) -> int:
llm_max_document_tokens = compute_max_document_tokens(
persona=persona, actual_user_input=question
prompt_config=prompt_config, llm_config=llm_config, actual_user_input=question
)
window_percentage_based_limit = (
@ -183,7 +185,8 @@ def _apply_pruning(
def prune_documents(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[LlmDoc]:
@ -191,7 +194,8 @@ def prune_documents(
assert len(docs) == len(doc_relevance_list)
doc_token_limit = _compute_limit(
persona=persona,
prompt_config=prompt_config,
llm_config=llm_config,
question=question,
max_chunks=document_pruning_config.max_chunks,
max_window_percentage=document_pruning_config.max_window_percentage,

View File

@ -9,9 +9,15 @@ from pydantic import root_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.llm.utils import get_default_llm_version
from danswer.server.query_and_chat.models import LLMOverride
from danswer.server.query_and_chat.models import PromptOverride
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
from danswer.db.models import Persona
StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
@ -75,3 +81,63 @@ class AnswerStyleConfig(BaseModel):
)
return values
class LLMConfig(BaseModel):
"""Final representation of the LLM configuration passed into
the `Answer` object."""
model_provider: str
model_version: str
temperature: float
@classmethod
def from_persona(
cls, persona: "Persona", llm_override: LLMOverride | None = None
) -> "LLMConfig":
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
return cls(
model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
model_version=(
model_version_override
or persona.llm_model_version_override
or get_default_llm_version()[0]
),
temperature=temperature_override or 0.0,
)
class Config:
frozen = True
class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""
system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool
@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None
return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)
# needed so that this can be passed into lru_cache funcs
class Config:
frozen = True
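
These two models are the bridge between the new request-level overrides and the answering pipeline: when an override is present, it wins field by field over the persona/prompt row. A minimal sketch of building them with overrides (model name and prompt text are illustrative):

from danswer.llm.answering.models import LLMConfig, PromptConfig
from danswer.server.query_and_chat.models import LLMOverride, PromptOverride

llm_config = LLMConfig.from_persona(
    persona,  # db Persona row (assumed to exist)
    llm_override=LLMOverride(model_version="gpt-4", temperature=0.1),
)
prompt_config = PromptConfig.from_model(
    prompt,   # db Prompt row (assumed to exist)
    prompt_override=PromptOverride(system_prompt="Answer in one short paragraph."),
)
# unset override fields fall back to the persona/prompt row or the global defaults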

View File

@ -11,12 +11,12 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
@ -92,16 +92,16 @@ def drop_messages_history_overflow(
return prompt
def get_prompt_tokens(prompt: Prompt) -> int:
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
# Note: currently custom prompts do not allow datetime aware, only default prompts
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
check_number_of_tokens(prompt_config.system_prompt)
+ check_number_of_tokens(prompt_config.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
)
@ -111,7 +111,8 @@ _MISC_BUFFER = 40
def compute_max_document_tokens(
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
@ -126,21 +127,13 @@ def compute_max_document_tokens(
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
# if we can't find a number of tokens, just assume some common default
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_name)
else get_max_input_tokens(model_name=llm_config.model_version)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
prompt_tokens = get_prompt_tokens(get_default_prompt())
prompt_tokens = get_prompt_tokens(prompt_config)
user_input_tokens = (
check_number_of_tokens(actual_user_input)
@ -151,31 +144,44 @@ def compute_max_document_tokens(
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
def compute_max_document_tokens_for_persona(
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(persona),
actual_user_input=actual_user_input,
max_llm_token_override=max_llm_token_override,
)
input_tokens = get_max_input_tokens(model_name=llm_name)
def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
input_tokens = get_max_input_tokens(
model_name=llm_config.model_version, model_provider=llm_config.model_provider
)
return input_tokens - _MISC_BUFFER
@lru_cache()
def build_system_message(
prompt: Prompt,
prompt_config: PromptConfig,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if prompt_config.datetime_aware:
if system_prompt:
system_prompt += ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
@ -194,7 +200,7 @@ def build_system_message(
def build_user_message(
question: str,
prompt: Prompt,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
@ -206,9 +212,9 @@ def build_user_message(
# Simpler prompt for cases where there is no context
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
task_prompt=prompt.task_prompt, user_query=question
task_prompt=prompt_config.task_prompt, user_query=question
)
if prompt.task_prompt
if prompt_config.task_prompt
else question
)
user_prompt = user_prompt.strip()
@ -219,7 +225,7 @@ def build_user_message(
context_docs_str = build_complete_context_str(context_docs)
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT
task_prompt_with_reminder = build_task_prompt_reminders(prompt)
task_prompt_with_reminder = build_task_prompt_reminders(prompt_config)
user_prompt = CITATIONS_PROMPT.format(
optional_ignore_statement=optional_ignore,
@ -239,8 +245,8 @@ def build_user_message(
def build_citations_prompt(
question: str,
message_history: list[PreviousMessage],
persona: Persona,
prompt: Prompt,
prompt_config: PromptConfig,
llm_config: LLMConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
@ -249,7 +255,7 @@ def build_citations_prompt(
context_exists = len(context_docs) > 0
system_message_or_none, system_tokens = build_system_message(
prompt=prompt,
prompt_config=prompt_config,
context_exists=context_exists,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
)
@ -262,7 +268,7 @@ def build_citations_prompt(
# Is the same as passed in later for extracting citations
user_message, user_tokens = build_user_message(
question=question,
prompt=prompt,
prompt_config=prompt_config,
context_docs=context_docs,
all_doc_useful=all_doc_useful,
history_message=history_message,
@ -275,7 +281,7 @@ def build_citations_prompt(
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
max_allowed_tokens=compute_max_llm_input_tokens(persona),
max_allowed_tokens=compute_max_llm_input_tokens(llm_config),
)
return final_prompt_msgs
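
The reworked compute_max_document_tokens is, at its core, the subtraction at the end of its body: whatever the model's input window allows, minus the prompt overhead, minus the user input, minus a small buffer. Callers that still only have a Persona (the Slack bot, the max-document-tokens endpoint) go through the new compute_max_document_tokens_for_persona wrapper. A worked example of the budget arithmetic with assumed numbers:

# numbers below are assumptions for illustration, not values from the diff
max_input_tokens = 4096    # e.g. what get_max_input_tokens() reports for the model
prompt_tokens = 600        # system + task prompt plus the fixed overhead constants
user_input_tokens = 40     # tokens in the actual user question
_MISC_BUFFER = 40

max_document_tokens = (
    max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER
)  # -> 3416 tokens left for retrieved documents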

View File

@ -4,8 +4,8 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
@ -18,7 +18,7 @@ def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool,
) -> list[BaseMessage]:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
@ -43,7 +43,7 @@ def _build_strong_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool,
) -> list[BaseMessage]:
context_block = ""
@ -70,7 +70,7 @@ def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
prompt_builder = (

View File

@ -1,6 +1,7 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
@ -14,6 +15,7 @@ def get_default_llm(
gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
use_fast_llm: bool = False,
gen_ai_model_version_override: str | None = None,
) -> LLM:
@ -34,8 +36,13 @@ def get_default_llm(
return CustomModelServer(api_key=api_key, timeout=timeout)
if gen_ai_model_provider.lower() == "gpt4all":
return DanswerGPT4All(model_version=model_version, timeout=timeout)
return DanswerGPT4All(
model_version=model_version, timeout=timeout, temperature=temperature
)
return DefaultMultiLLM(
model_version=model_version, api_key=api_key, timeout=timeout
model_version=model_version,
api_key=api_key,
timeout=timeout,
temperature=temperature,
)
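
get_default_llm now also accepts a temperature, forwarded to both DanswerGPT4All and DefaultMultiLLM. A hedged usage sketch (values illustrative):

from danswer.llm.factory import get_default_llm

llm = get_default_llm(
    timeout=30,        # seconds; illustrative, the default is QA_TIMEOUT
    temperature=0.2,   # newly plumbed through; defaults to GEN_AI_TEMPERATURE
)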

View File

@ -4,6 +4,8 @@ from copy import copy
from functools import lru_cache
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union
import litellm # type: ignore
import tiktoken
@ -33,10 +35,12 @@ from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger
if TYPE_CHECKING:
from danswer.llm.answering.models import PreviousMessage
logger = setup_logger()
_LLM_TOKENIZER: Any = None
@ -116,7 +120,7 @@ def tokenizer_trim_chunks(
def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@ -129,7 +133,7 @@ def translate_danswer_msg_to_langchain(
def translate_history_to_basemessages(
history: list[ChatMessage] | list[PreviousMessage],
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
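
Moving the PreviousMessage import under TYPE_CHECKING (and switching the annotations to quoted/Union forms) keeps the type hints while deferring the import, evidently to break the cycle created now that danswer.llm.answering.models itself imports from danswer.llm.utils. The pattern in isolation, as a minimal sketch:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # only evaluated by type checkers, never executed at runtime
    from danswer.llm.answering.models import PreviousMessage

def log_message(msg: "PreviousMessage") -> None:
    # the quoted annotation is resolved lazily, so no import cycle at runtime
    print(msg)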

View File

@ -28,6 +28,8 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
@ -203,8 +205,8 @@ def stream_answer_objects(
question=query_msg.message,
docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks],
answer_style_config=answer_config,
prompt=prompt,
persona=chat_session.persona,
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(chat_session.persona),
doc_relevance_list=search_pipeline.chunk_relevance_list,
single_message_history=history_str,
timeout=timeout,

View File

@ -6,6 +6,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DocumentSource
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
@ -20,7 +21,7 @@ def get_current_llm_day_time() -> str:
def build_task_prompt_reminders(
prompt: Prompt,
prompt: Prompt | PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,

View File

@ -24,7 +24,9 @@ from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)
@ -303,5 +305,5 @@ def get_max_document_tokens(
raise HTTPException(status_code=404, detail="Persona not found")
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens(persona),
max_tokens=compute_max_document_tokens_for_persona(persona),
)

View File

@ -67,6 +67,17 @@ class DocumentSearchRequest(BaseModel):
skip_rerank: bool = False
class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
class PromptOverride(BaseModel):
system_prompt: str | None = None
task_prompt: str | None = None
"""
Currently the different branches are generated by changing the search query
@ -98,6 +109,10 @@ class CreateChatMessageRequest(BaseModel):
query_override: str | None = None
no_ai_answer: bool = False
# allows the caller to override the Persona / Prompt
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
@root_validator
def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(
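
On the wire, the overrides are just two optional objects on the existing send-message body. An illustrative fragment of the request payload the frontend now produces (field values are made up):

payload = {
    # ...the existing message fields (chat session id, message text, prompt id, etc.)...
    "llm_override": {"model_version": "gpt-4", "temperature": 0.1},
    "prompt_override": {"system_prompt": "Keep answers under three sentences."},
}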

View File

@ -13,9 +13,10 @@ import {
RetrievalType,
StreamingError,
} from "./interfaces";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { FeedbackType } from "./types";
import {
buildChatUrl,
createChatSession,
getCitedDocumentsFromMessage,
getHumanAndAIMessageFromMessageNumber,
@ -46,6 +47,7 @@ import { computeAvailableFilters } from "@/lib/filters";
import { useDocumentSelection } from "./useDocumentSelection";
import { StarterMessage } from "./StarterMessage";
import { ShareChatSessionModal } from "./modal/ShareChatSessionModal";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";
const MAX_INPUT_HEIGHT = 200;
@ -71,6 +73,13 @@ export const Chat = ({
shouldhideBeforeScroll?: boolean;
}) => {
const router = useRouter();
const searchParams = useSearchParams();
// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
// NOTE: this is required due to React strict mode, where all `useEffect` hooks
// are run twice on initial load during development
const submitOnLoadPerformed = useRef<boolean>(false);
const { popup, setPopup } = usePopup();
// fetch messages for the chat session
@ -117,6 +126,16 @@ export const Chat = ({
}
setMessageHistory([]);
setChatSessionSharedStatus(ChatSessionSharedStatus.Private);
// if we're supposed to submit on initial load, then do that here
if (
shouldSubmitOnLoad(searchParams) &&
!submitOnLoadPerformed.current
) {
submitOnLoadPerformed.current = true;
onSubmit();
}
return;
}
@ -151,7 +170,9 @@ export const Chat = ({
const [chatSessionId, setChatSessionId] = useState<number | null>(
existingChatSessionId
);
const [message, setMessage] = useState("");
const [message, setMessage] = useState(
searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || ""
);
const [messageHistory, setMessageHistory] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);
@ -385,6 +406,13 @@ export const Chat = ({
.map((document) => document.db_doc_id as number),
queryOverride,
forceSearch,
modelVersion:
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || undefined,
temperature:
parseFloat(searchParams.get(SEARCH_PARAM_NAMES.TEMPERATURE) || "") ||
undefined,
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
})) {
for (const packet of packetBunch) {
if (Object.hasOwn(packet, "answer_piece")) {
@ -454,7 +482,7 @@ export const Chat = ({
currChatSessionId === urlChatSessionId.current ||
urlChatSessionId.current === null
) {
router.push(`/chat?chatId=${currChatSessionId}`, {
router.push(buildChatUrl(searchParams, currChatSessionId, null), {
scroll: false,
});
}
@ -550,7 +578,9 @@ export const Chat = ({
if (persona) {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(`/chat?personaId=${persona.id}`);
router.push(
buildChatUrl(searchParams, null, persona.id)
);
}
}}
/>
@ -577,7 +607,7 @@ export const Chat = ({
handlePersonaSelect={(persona) => {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(`/chat?personaId=${persona.id}`);
router.push(buildChatUrl(searchParams, null, persona.id));
}}
/>
)}

View File

@ -15,6 +15,8 @@ import {
StreamingError,
} from "./interfaces";
import { Persona } from "../admin/personas/interfaces";
import { ReadonlyURLSearchParams } from "next/navigation";
import { SEARCH_PARAM_NAMES } from "./searchParams";
export async function createChatSession(personaId: number): Promise<number> {
const createChatSessionResponse = await fetch(
@ -39,17 +41,6 @@ export async function createChatSession(personaId: number): Promise<number> {
return chatSessionResponseJson.chat_session_id;
}
export interface SendMessageRequest {
message: string;
parentMessageId: number | null;
chatSessionId: number;
promptId: number | null | undefined;
filters: Filters | null;
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
}
export async function* sendMessage({
message,
parentMessageId,
@ -59,7 +50,24 @@ export async function* sendMessage({
selectedDocumentIds,
queryOverride,
forceSearch,
}: SendMessageRequest) {
modelVersion,
temperature,
systemPromptOverride,
}: {
message: string;
parentMessageId: number | null;
chatSessionId: number;
promptId: number | null | undefined;
filters: Filters | null;
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
// LLM overrides
modelVersion?: string;
temperature?: number;
// prompt overrides
systemPromptOverride?: string;
}) {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const sendMessageResponse = await fetch("/api/chat/send-message", {
@ -87,6 +95,13 @@ export async function* sendMessage({
}
: null,
query_override: queryOverride,
prompt_override: {
system_prompt: systemPromptOverride,
},
llm_override: {
temperature,
model_version: modelVersion,
},
}),
});
if (!sendMessageResponse.ok) {
@ -354,3 +369,38 @@ export function processRawChatHistory(rawMessages: BackendMessage[]) {
export function personaIncludesRetrieval(selectedPersona: Persona) {
return selectedPersona.num_chunks !== 0;
}
const PARAMS_TO_SKIP = [
SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD,
SEARCH_PARAM_NAMES.USER_MESSAGE,
// only use these if explicitly passed in
SEARCH_PARAM_NAMES.CHAT_ID,
SEARCH_PARAM_NAMES.PERSONA_ID,
];
export function buildChatUrl(
existingSearchParams: ReadonlyURLSearchParams,
chatSessionId: number | null,
personaId: number | null
) {
const finalSearchParams: string[] = [];
if (chatSessionId) {
finalSearchParams.push(`${SEARCH_PARAM_NAMES.CHAT_ID}=${chatSessionId}`);
}
if (personaId) {
finalSearchParams.push(`${SEARCH_PARAM_NAMES.PERSONA_ID}=${personaId}`);
}
existingSearchParams.forEach((value, key) => {
if (!PARAMS_TO_SKIP.includes(key)) {
finalSearchParams.push(`${key}=${value}`);
}
});
const finalSearchParamsString = finalSearchParams.join("&");
if (finalSearchParamsString) {
return `/chat?${finalSearchParamsString}`;
}
return "/chat";
}

View File

@ -0,0 +1,22 @@
import { ReadonlyURLSearchParams } from "next/navigation";
// search params
export const SEARCH_PARAM_NAMES = {
CHAT_ID: "chatId",
PERSONA_ID: "personaId",
// overrides
TEMPERATURE: "temperature",
MODEL_VERSION: "model-version",
SYSTEM_PROMPT: "system-prompt",
// user message
USER_MESSAGE: "user-message",
SUBMIT_ON_LOAD: "submit-on-load",
};
export function shouldSubmitOnLoad(searchParams: ReadonlyURLSearchParams) {
const rawSubmitOnLoad = searchParams.get(SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD);
if (rawSubmitOnLoad === "true" || rawSubmitOnLoad === "1") {
return true;
}
return false;
}
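
Taken together with the Chat component changes above, these parameters are what make URL-based seeding work: user-message pre-fills the input, system-prompt/temperature/model-version become prompt_override/llm_override on the first request, and submit-on-load sends the message automatically. A sketch of assembling such a seed link (parameter values are illustrative):

from urllib.parse import urlencode

seed_url = "/chat?" + urlencode(
    {
        "personaId": 1,                    # which persona to chat with
        "user-message": "Summarize yesterday's incident report",
        "system-prompt": "Answer as tersely as possible.",
        "temperature": 0.1,
        "model-version": "gpt-4",          # illustrative model name
        "submit-on-load": "true",          # send the seeded message immediately
    }
)
# -> /chat?personaId=1&user-message=Summarize+yesterday%27s+incident+report&...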

View File

@ -117,7 +117,7 @@ export const ChatSidebar = ({
{chatSessions.map((chat) => {
const isSelected = currentChatId === chat.id;
return (
<div key={chat.id} className="mr-3">
<div key={`${chat.id}-${chat.name}`} className="mr-3">
<ChatSessionDisplay
chatSession={chat}
isSelected={isSelected}