URL-based chat seeding
commit 32f55ddb8f
parent b8af1377ba
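The diff below adds URL-based chat seeding: a new chat can be pre-filled, and optionally auto-submitted, purely from query parameters. A minimal sketch of such a seed link, using the parameter names introduced in web/src/app/chat/searchParams.ts further down (the persona id and all other values here are placeholders):

```typescript
// Hypothetical seed link; every value is a placeholder.
const seedUrl =
  "/chat?personaId=1" +
  "&user-message=" + encodeURIComponent("What changed in the last release?") +
  "&system-prompt=" + encodeURIComponent("You are a release-notes assistant.") +
  "&model-version=gpt-4" +
  "&temperature=0.3" +
  "&submit-on-load=true";
```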
@@ -34,7 +34,9 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_tokenizer
@@ -343,8 +345,12 @@ def stream_chat_message_objects(
),
document_pruning_config=document_pruning_config,
),
prompt=final_msg.prompt,
persona=persona,
prompt_config=PromptConfig.from_model(
final_msg.prompt, prompt_override=new_msg_req.prompt_override
),
llm_config=LLMConfig.from_persona(
persona, llm_override=new_msg_req.llm_override
),
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
],

@@ -38,7 +38,9 @@ from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
@@ -247,7 +249,7 @@ def handle_message(

query_text = new_message_request.messages[0].message
if persona:
max_document_tokens = compute_max_document_tokens(
max_document_tokens = compute_max_document_tokens_for_persona(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,

@@ -10,11 +10,11 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import StreamProcessor
from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt
from danswer.llm.answering.prompts.quotes_prompt import (
@@ -51,8 +51,8 @@ class Answer:
question: str,
docs: list[LlmDoc],
answer_style_config: AnswerStyleConfig,
prompt: Prompt,
persona: Persona,
llm_config: LLMConfig,
prompt_config: PromptConfig,
# must be the same length as `docs`. If None, all docs are considered "relevant"
doc_relevance_list: list[bool] | None = None,
message_history: list[PreviousMessage] | None = None,
@@ -72,16 +72,17 @@ class Answer:
self.single_message_history = single_message_history

self.answer_style_config = answer_style_config
self.llm_config = llm_config
self.prompt_config = prompt_config

self.llm = get_default_llm(
gen_ai_model_version_override=persona.llm_model_version_override,
gen_ai_model_provider=self.llm_config.model_provider,
gen_ai_model_version_override=self.llm_config.model_version,
timeout=timeout,
temperature=self.llm_config.temperature,
)
self.llm_tokenizer = get_default_llm_tokenizer()

self.prompt = prompt
self.persona = persona

self.process_stream_fn = _get_stream_processor(docs, answer_style_config)

self._final_prompt: list[BaseMessage] | None = None
@@ -99,7 +100,8 @@ class Answer:
self._pruned_docs = prune_documents(
docs=self.docs,
doc_relevance_list=self.doc_relevance_list,
persona=self.persona,
prompt_config=self.prompt_config,
llm_config=self.llm_config,
question=self.question,
document_pruning_config=self.answer_style_config.document_pruning_config,
)
@@ -114,8 +116,8 @@ class Answer:
self._final_prompt = build_citations_prompt(
question=self.question,
message_history=self.message_history,
persona=self.persona,
prompt=self.prompt,
llm_config=self.llm_config,
prompt_config=self.prompt_config,
context_docs=self.pruned_docs,
all_doc_useful=self.answer_style_config.citation_config.all_docs_useful,
llm_tokenizer_encode_func=self.llm_tokenizer.encode,
@@ -126,7 +128,7 @@ class Answer:
question=self.question,
context_docs=self.pruned_docs,
history_str=self.single_message_history or "",
prompt=self.prompt,
prompt=self.prompt_config,
)

return cast(list[BaseMessage], self._final_prompt)

@@ -6,9 +6,10 @@ from danswer.chat.models import (
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import Persona
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
@@ -28,14 +29,15 @@ class PruningError(Exception):


def _compute_limit(
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
max_chunks: int | None,
max_window_percentage: float | None,
max_tokens: int | None,
) -> int:
llm_max_document_tokens = compute_max_document_tokens(
persona=persona, actual_user_input=question
prompt_config=prompt_config, llm_config=llm_config, actual_user_input=question
)

window_percentage_based_limit = (
@@ -183,7 +185,8 @@ def _apply_pruning(
def prune_documents(
docs: list[LlmDoc],
doc_relevance_list: list[bool] | None,
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[LlmDoc]:
@@ -191,7 +194,8 @@ def prune_documents(
assert len(docs) == len(doc_relevance_list)

doc_token_limit = _compute_limit(
persona=persona,
prompt_config=prompt_config,
llm_config=llm_config,
question=question,
max_chunks=document_pruning_config.max_chunks,
max_window_percentage=document_pruning_config.max_window_percentage,

@@ -9,9 +9,15 @@ from pydantic import root_validator

from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.llm.utils import get_default_llm_version
from danswer.server.query_and_chat.models import LLMOverride
from danswer.server.query_and_chat.models import PromptOverride

if TYPE_CHECKING:
from danswer.db.models import ChatMessage
from danswer.db.models import Prompt
from danswer.db.models import Persona


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]
@@ -75,3 +81,63 @@ class AnswerStyleConfig(BaseModel):
)

return values


class LLMConfig(BaseModel):
"""Final representation of the LLM configuration passed into
the `Answer` object."""

model_provider: str
model_version: str
temperature: float

@classmethod
def from_persona(
cls, persona: "Persona", llm_override: LLMOverride | None = None
) -> "LLMConfig":
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None

return cls(
model_provider=model_provider_override or GEN_AI_MODEL_PROVIDER,
model_version=(
model_version_override
or persona.llm_model_version_override
or get_default_llm_version()[0]
),
temperature=temperature_override or 0.0,
)

class Config:
frozen = True


class PromptConfig(BaseModel):
"""Final representation of the Prompt configuration passed
into the `Answer` object."""

system_prompt: str
task_prompt: str
datetime_aware: bool
include_citations: bool

@classmethod
def from_model(
cls, model: "Prompt", prompt_override: PromptOverride | None = None
) -> "PromptConfig":
override_system_prompt = (
prompt_override.system_prompt if prompt_override else None
)
override_task_prompt = prompt_override.task_prompt if prompt_override else None

return cls(
system_prompt=override_system_prompt or model.system_prompt,
task_prompt=override_task_prompt or model.task_prompt,
datetime_aware=model.datetime_aware,
include_citations=model.include_citations,
)

# needed so that this can be passed into lru_cache funcs
class Config:
frozen = True

@@ -11,12 +11,12 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
from danswer.llm.utils import translate_history_to_basemessages
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
@@ -92,16 +92,16 @@ def drop_messages_history_overflow(
return prompt


def get_prompt_tokens(prompt: Prompt) -> int:
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
# Note: currently custom prompts do not allow datetime aware, only default prompts
return (
check_number_of_tokens(prompt.system_prompt)
+ check_number_of_tokens(prompt.task_prompt)
check_number_of_tokens(prompt_config.system_prompt)
+ check_number_of_tokens(prompt_config.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt.datetime_aware else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
)


@@ -111,7 +111,8 @@ _MISC_BUFFER = 40


def compute_max_document_tokens(
persona: Persona,
prompt_config: PromptConfig,
llm_config: LLMConfig,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
@@ -126,21 +127,13 @@ def compute_max_document_tokens(
if we're trying to determine if the user should be able to select another document) then we just set an
arbitrary "upper bound".
"""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override

# if we can't find a number of tokens, just assume some common default
max_input_tokens = (
max_llm_token_override
if max_llm_token_override
else get_max_input_tokens(model_name=llm_name)
else get_max_input_tokens(model_name=llm_config.model_version)
)
if persona.prompts:
# TODO this may not always be the first prompt
prompt_tokens = get_prompt_tokens(persona.prompts[0])
else:
prompt_tokens = get_prompt_tokens(get_default_prompt())
prompt_tokens = get_prompt_tokens(prompt_config)

user_input_tokens = (
check_number_of_tokens(actual_user_input)
@@ -151,31 +144,44 @@ def compute_max_document_tokens(
return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER


def compute_max_llm_input_tokens(persona: Persona) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""
llm_name = get_default_llm_version()[0]
if persona.llm_model_version_override:
llm_name = persona.llm_model_version_override
def compute_max_document_tokens_for_persona(
persona: Persona,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(persona),
actual_user_input=actual_user_input,
max_llm_token_override=max_llm_token_override,
)

input_tokens = get_max_input_tokens(model_name=llm_name)

def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
"""Maximum tokens allows in the input to the LLM (of any type)."""

input_tokens = get_max_input_tokens(
model_name=llm_config.model_version, model_provider=llm_config.model_provider
)
return input_tokens - _MISC_BUFFER


@lru_cache()
def build_system_message(
prompt: Prompt,
prompt_config: PromptConfig,
context_exists: bool,
llm_tokenizer_encode_func: Callable,
citation_line: str = REQUIRE_CITATION_STATEMENT,
no_citation_line: str = NO_CITATION_STATEMENT,
) -> tuple[SystemMessage | None, int]:
system_prompt = prompt.system_prompt.strip()
if prompt.include_citations:
system_prompt = prompt_config.system_prompt.strip()
if prompt_config.include_citations:
if context_exists:
system_prompt += citation_line
else:
system_prompt += no_citation_line
if prompt.datetime_aware:
if prompt_config.datetime_aware:
if system_prompt:
system_prompt += ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
@@ -194,7 +200,7 @@ def build_system_message(

def build_user_message(
question: str,
prompt: Prompt,
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
@@ -206,9 +212,9 @@ def build_user_message(
# Simpler prompt for cases where there is no context
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
task_prompt=prompt.task_prompt, user_query=question
task_prompt=prompt_config.task_prompt, user_query=question
)
if prompt.task_prompt
if prompt_config.task_prompt
else question
)
user_prompt = user_prompt.strip()
@@ -219,7 +225,7 @@ def build_user_message(
context_docs_str = build_complete_context_str(context_docs)
optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

task_prompt_with_reminder = build_task_prompt_reminders(prompt)
task_prompt_with_reminder = build_task_prompt_reminders(prompt_config)

user_prompt = CITATIONS_PROMPT.format(
optional_ignore_statement=optional_ignore,
@@ -239,8 +245,8 @@ def build_user_message(
def build_citations_prompt(
question: str,
message_history: list[PreviousMessage],
persona: Persona,
prompt: Prompt,
prompt_config: PromptConfig,
llm_config: LLMConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
history_message: str,
@@ -249,7 +255,7 @@ def build_citations_prompt(
context_exists = len(context_docs) > 0

system_message_or_none, system_tokens = build_system_message(
prompt=prompt,
prompt_config=prompt_config,
context_exists=context_exists,
llm_tokenizer_encode_func=llm_tokenizer_encode_func,
)
@@ -262,7 +268,7 @@ def build_citations_prompt(
# Is the same as passed in later for extracting citations
user_message, user_tokens = build_user_message(
question=question,
prompt=prompt,
prompt_config=prompt_config,
context_docs=context_docs,
all_doc_useful=all_doc_useful,
history_message=history_message,
@@ -275,7 +281,7 @@ def build_citations_prompt(
history_token_counts=history_token_counts,
final_msg=user_message,
final_msg_token_count=user_tokens,
max_allowed_tokens=compute_max_llm_input_tokens(persona),
max_allowed_tokens=compute_max_llm_input_tokens(llm_config),
)

return final_prompt_msgs

@@ -4,8 +4,8 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
@@ -18,7 +18,7 @@ def _build_weak_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool,
) -> list[BaseMessage]:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
@@ -43,7 +43,7 @@ def _build_strong_llm_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool,
) -> list[BaseMessage]:
context_block = ""
@@ -70,7 +70,7 @@ def build_quotes_prompt(
question: str,
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: Prompt,
prompt: PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
prompt_builder = (

@@ -1,6 +1,7 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.exceptions import GenAIDisabledException
@@ -14,6 +15,7 @@ def get_default_llm(
gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
use_fast_llm: bool = False,
gen_ai_model_version_override: str | None = None,
) -> LLM:
@@ -34,8 +36,13 @@ def get_default_llm(
return CustomModelServer(api_key=api_key, timeout=timeout)

if gen_ai_model_provider.lower() == "gpt4all":
return DanswerGPT4All(model_version=model_version, timeout=timeout)
return DanswerGPT4All(
model_version=model_version, timeout=timeout, temperature=temperature
)

return DefaultMultiLLM(
model_version=model_version, api_key=api_key, timeout=timeout
model_version=model_version,
api_key=api_key,
timeout=timeout,
temperature=temperature,
)

@@ -4,6 +4,8 @@ from copy import copy
from functools import lru_cache
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union

import litellm # type: ignore
import tiktoken
@@ -33,10 +35,12 @@ from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger

if TYPE_CHECKING:
from danswer.llm.answering.models import PreviousMessage

logger = setup_logger()

_LLM_TOKENIZER: Any = None
@@ -116,7 +120,7 @@ def tokenizer_trim_chunks(


def translate_danswer_msg_to_langchain(
msg: ChatMessage | PreviousMessage,
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
@@ -129,7 +133,7 @@ def translate_danswer_msg_to_langchain(


def translate_history_to_basemessages(
history: list[ChatMessage] | list[PreviousMessage],
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
history_basemessages = [
translate_danswer_msg_to_langchain(msg)

@@ -28,6 +28,8 @@ from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
@@ -203,8 +205,8 @@ def stream_answer_objects(
question=query_msg.message,
docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks],
answer_style_config=answer_config,
prompt=prompt,
persona=chat_session.persona,
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(chat_session.persona),
doc_relevance_list=search_pipeline.chunk_relevance_list,
single_message_history=history_str,
timeout=timeout,

@@ -6,6 +6,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DocumentSource
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
@@ -20,7 +21,7 @@ def get_current_llm_day_time() -> str:


def build_task_prompt_reminders(
prompt: Prompt,
prompt: Prompt | PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,

@@ -24,7 +24,9 @@ from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.models import User
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)
@@ -303,5 +305,5 @@ def get_max_document_tokens(
raise HTTPException(status_code=404, detail="Persona not found")

return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens(persona),
max_tokens=compute_max_document_tokens_for_persona(persona),
)

@@ -67,6 +67,17 @@ class DocumentSearchRequest(BaseModel):
skip_rerank: bool = False


class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None


class PromptOverride(BaseModel):
system_prompt: str | None = None
task_prompt: str | None = None


"""
Currently the different branches are generated by changing the search query

@@ -98,6 +109,10 @@ class CreateChatMessageRequest(BaseModel):
query_override: str | None = None
no_ai_answer: bool = False

# allows the caller to override the Persona / Prompt
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None

@root_validator
def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(

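The LLMOverride and PromptOverride models above are what the web client fills in on /api/chat/send-message. A rough sketch of just the override portion of that JSON body (field names come from these models and from the sendMessage change further down; values are placeholders and the other request fields are omitted):

```typescript
// Override fields only; the rest of the CreateChatMessageRequest payload is omitted.
const overridePortion = {
  prompt_override: {
    system_prompt: "You are a release-notes assistant.", // task_prompt can also be set
  },
  llm_override: {
    model_version: "gpt-4",
    temperature: 0.3, // model_provider can also be set
  },
};
```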
@@ -13,9 +13,10 @@ import {
RetrievalType,
StreamingError,
} from "./interfaces";
import { useRouter } from "next/navigation";
import { useRouter, useSearchParams } from "next/navigation";
import { FeedbackType } from "./types";
import {
buildChatUrl,
createChatSession,
getCitedDocumentsFromMessage,
getHumanAndAIMessageFromMessageNumber,
@@ -46,6 +47,7 @@ import { computeAvailableFilters } from "@/lib/filters";
import { useDocumentSelection } from "./useDocumentSelection";
import { StarterMessage } from "./StarterMessage";
import { ShareChatSessionModal } from "./modal/ShareChatSessionModal";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";

const MAX_INPUT_HEIGHT = 200;

@@ -71,6 +73,13 @@ export const Chat = ({
shouldhideBeforeScroll?: boolean;
}) => {
const router = useRouter();
const searchParams = useSearchParams();
// used to track whether or not the initial "submit on load" has been performed
// this only applies if `?submit-on-load=true` or `?submit-on-load=1` is in the URL
// NOTE: this is required due to React strict mode, where all `useEffect` hooks
// are run twice on initial load during development
const submitOnLoadPerformed = useRef<boolean>(false);

const { popup, setPopup } = usePopup();

// fetch messages for the chat session
@@ -117,6 +126,16 @@ export const Chat = ({
}
setMessageHistory([]);
setChatSessionSharedStatus(ChatSessionSharedStatus.Private);

// if we're supposed to submit on initial load, then do that here
if (
shouldSubmitOnLoad(searchParams) &&
!submitOnLoadPerformed.current
) {
submitOnLoadPerformed.current = true;
onSubmit();
}

return;
}

@@ -151,7 +170,9 @@ export const Chat = ({
const [chatSessionId, setChatSessionId] = useState<number | null>(
existingChatSessionId
);
const [message, setMessage] = useState("");
const [message, setMessage] = useState(
searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || ""
);
const [messageHistory, setMessageHistory] = useState<Message[]>([]);
const [isStreaming, setIsStreaming] = useState(false);

@@ -385,6 +406,13 @@ export const Chat = ({
.map((document) => document.db_doc_id as number),
queryOverride,
forceSearch,
modelVersion:
searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) || undefined,
temperature:
parseFloat(searchParams.get(SEARCH_PARAM_NAMES.TEMPERATURE) || "") ||
undefined,
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
})) {
for (const packet of packetBunch) {
if (Object.hasOwn(packet, "answer_piece")) {
@@ -454,7 +482,7 @@ export const Chat = ({
currChatSessionId === urlChatSessionId.current ||
urlChatSessionId.current === null
) {
router.push(`/chat?chatId=${currChatSessionId}`, {
router.push(buildChatUrl(searchParams, currChatSessionId, null), {
scroll: false,
});
}
@@ -550,7 +578,9 @@ export const Chat = ({
if (persona) {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(`/chat?personaId=${persona.id}`);
router.push(
buildChatUrl(searchParams, null, persona.id)
);
}
}}
/>
@@ -577,7 +607,7 @@ export const Chat = ({
handlePersonaSelect={(persona) => {
setSelectedPersona(persona);
textareaRef.current?.focus();
router.push(`/chat?personaId=${persona.id}`);
router.push(buildChatUrl(searchParams, null, persona.id));
}}
/>
)}

@@ -15,6 +15,8 @@ import {
StreamingError,
} from "./interfaces";
import { Persona } from "../admin/personas/interfaces";
import { ReadonlyURLSearchParams } from "next/navigation";
import { SEARCH_PARAM_NAMES } from "./searchParams";

export async function createChatSession(personaId: number): Promise<number> {
const createChatSessionResponse = await fetch(
@@ -39,17 +41,6 @@ export async function createChatSession(personaId: number): Promise<number> {
return chatSessionResponseJson.chat_session_id;
}

export interface SendMessageRequest {
message: string;
parentMessageId: number | null;
chatSessionId: number;
promptId: number | null | undefined;
filters: Filters | null;
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
}

export async function* sendMessage({
message,
parentMessageId,
@@ -59,7 +50,24 @@ export async function* sendMessage({
selectedDocumentIds,
queryOverride,
forceSearch,
}: SendMessageRequest) {
modelVersion,
temperature,
systemPromptOverride,
}: {
message: string;
parentMessageId: number | null;
chatSessionId: number;
promptId: number | null | undefined;
filters: Filters | null;
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
// LLM overrides
modelVersion?: string;
temperature?: number;
// prompt overrides
systemPromptOverride?: string;
}) {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const sendMessageResponse = await fetch("/api/chat/send-message", {
@@ -87,6 +95,13 @@ export async function* sendMessage({
}
: null,
query_override: queryOverride,
prompt_override: {
system_prompt: systemPromptOverride,
},
llm_override: {
temperature,
model_version: modelVersion,
},
}),
});
if (!sendMessageResponse.ok) {

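A sketch of a call site for the reworked sendMessage signature above, forwarding the new LLM and prompt overrides (all values are placeholders and error handling is omitted; this mirrors how Chat.tsx forwards values read from the URL):

```typescript
// Hypothetical caller of the updated async generator.
for await (const packetBunch of sendMessage({
  message: "What changed in the last release?",
  parentMessageId: null,
  chatSessionId: 42,
  promptId: null,
  filters: null,
  selectedDocumentIds: null,
  modelVersion: "gpt-4",
  temperature: 0.3,
  systemPromptOverride: "You are a release-notes assistant.",
})) {
  // process streamed packets exactly as the existing Chat.tsx loop does
}
```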
@@ -354,3 +369,38 @@ export function processRawChatHistory(rawMessages: BackendMessage[]) {
export function personaIncludesRetrieval(selectedPersona: Persona) {
return selectedPersona.num_chunks !== 0;
}

const PARAMS_TO_SKIP = [
SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD,
SEARCH_PARAM_NAMES.USER_MESSAGE,
// only use these if explicitly passed in
SEARCH_PARAM_NAMES.CHAT_ID,
SEARCH_PARAM_NAMES.PERSONA_ID,
];

export function buildChatUrl(
existingSearchParams: ReadonlyURLSearchParams,
chatSessionId: number | null,
personaId: number | null
) {
const finalSearchParams: string[] = [];
if (chatSessionId) {
finalSearchParams.push(`${SEARCH_PARAM_NAMES.CHAT_ID}=${chatSessionId}`);
}
if (personaId) {
finalSearchParams.push(`${SEARCH_PARAM_NAMES.PERSONA_ID}=${personaId}`);
}

existingSearchParams.forEach((value, key) => {
if (!PARAMS_TO_SKIP.includes(key)) {
finalSearchParams.push(`${key}=${value}`);
}
});
const finalSearchParamsString = finalSearchParams.join("&");

if (finalSearchParamsString) {
return `/chat?${finalSearchParamsString}`;
}

return "/chat";
}

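Roughly how the new buildChatUrl helper behaves, as a sketch (the import path is an assumption, since the diff does not name this file): one-shot seeding params are dropped, while remaining params such as the overrides are carried into the new URL.

```typescript
import { ReadonlyURLSearchParams } from "next/navigation";
import { buildChatUrl } from "./lib"; // assumed module path for the helper above

// submit-on-load and user-message are in PARAMS_TO_SKIP; temperature is not.
const existing = new URLSearchParams(
  "submit-on-load=true&user-message=hi&temperature=0.3"
) as unknown as ReadonlyURLSearchParams;

console.log(buildChatUrl(existing, 42, null)); // "/chat?chatId=42&temperature=0.3"
```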
web/src/app/chat/searchParams.ts (new file, 22 lines)
@@ -0,0 +1,22 @@
import { ReadonlyURLSearchParams } from "next/navigation";

// search params
export const SEARCH_PARAM_NAMES = {
CHAT_ID: "chatId",
PERSONA_ID: "personaId",
// overrides
TEMPERATURE: "temperature",
MODEL_VERSION: "model-version",
SYSTEM_PROMPT: "system-prompt",
// user message
USER_MESSAGE: "user-message",
SUBMIT_ON_LOAD: "submit-on-load",
};

export function shouldSubmitOnLoad(searchParams: ReadonlyURLSearchParams) {
const rawSubmitOnLoad = searchParams.get(SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD);
if (rawSubmitOnLoad === "true" || rawSubmitOnLoad === "1") {
return true;
}
return false;
}

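A small sketch of how the component side consumes this module; the commit inlines this logic in Chat.tsx, so the helper hook below is purely hypothetical:

```typescript
import { useSearchParams } from "next/navigation";
import { SEARCH_PARAM_NAMES, shouldSubmitOnLoad } from "./searchParams";

// Hypothetical hook: surface the seeded message and the auto-submit flag to a client component.
function useChatSeed() {
  const searchParams = useSearchParams();
  return {
    seededMessage: searchParams.get(SEARCH_PARAM_NAMES.USER_MESSAGE) || "",
    autoSubmit: shouldSubmitOnLoad(searchParams), // true for "true" or "1"
  };
}
```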
@@ -117,7 +117,7 @@ export const ChatSidebar = ({
{chatSessions.map((chat) => {
const isSelected = currentChatId === chat.id;
return (
<div key={chat.id} className="mr-3">
<div key={`${chat.id}-${chat.name}`} className="mr-3">
<ChatSessionDisplay
chatSession={chat}
isSelected={isSelected}