welcome to onyx

pablodanswer
2024-12-13 09:48:43 -08:00
parent 54dcbfa288
commit 21ec5ed795
813 changed files with 7021 additions and 6824 deletions

@@ -0,0 +1,159 @@
from collections.abc import Callable
from typing import cast

from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool


def default_build_system_message(
    prompt_config: PromptConfig,
) -> SystemMessage | None:
    system_prompt = prompt_config.system_prompt.strip()
    if prompt_config.datetime_aware:
        system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
    if not system_prompt:
        return None

    system_msg = SystemMessage(content=system_prompt)
    return system_msg


def default_build_user_message(
    user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
) -> HumanMessage:
    user_prompt = (
        CHAT_USER_CONTEXT_FREE_PROMPT.format(
            task_prompt=prompt_config.task_prompt, user_query=user_query
        )
        if prompt_config.task_prompt
        else user_query
    )
    user_prompt = user_prompt.strip()
    user_msg = HumanMessage(
        content=build_content_with_imgs(user_prompt, files) if files else user_prompt
    )
    return user_msg


class AnswerPromptBuilder:
    def __init__(
        self,
        user_message: HumanMessage,
        message_history: list[PreviousMessage],
        llm_config: LLMConfig,
        raw_user_text: str,
        single_message_history: str | None = None,
    ) -> None:
        self.max_tokens = compute_max_llm_input_tokens(llm_config)

        llm_tokenizer = get_tokenizer(
            provider_type=llm_config.model_provider,
            model_name=llm_config.model_name,
        )
        self.llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )

        self.raw_message_history = message_history
        (
            self.message_history,
            self.history_token_cnts,
        ) = translate_history_to_basemessages(message_history)

        # for cases like the QA flow where we want to condense the chat history
        # into a single message rather than a sequence of User / Assistant messages
        self.single_message_history = single_message_history

        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
        )

        self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

        self.raw_user_message = raw_user_text

    def update_system_prompt(self, system_message: SystemMessage | None) -> None:
        if not system_message:
            self.system_message_and_token_cnt = None
            return

        self.system_message_and_token_cnt = (
            system_message,
            check_message_tokens(system_message, self.llm_tokenizer_encode_func),
        )

    def update_user_prompt(self, user_message: HumanMessage) -> None:
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
        )

    def append_message(self, message: BaseMessage) -> None:
        """Append a new message to the message history."""
        token_count = check_message_tokens(message, self.llm_tokenizer_encode_func)
        self.new_messages_and_token_cnts.append((message, token_count))

    def get_user_message_content(self) -> str:
        query, _ = message_to_prompt_and_imgs(self.user_message_and_token_cnt[0])
        return query

    def build(self) -> list[BaseMessage]:
        if not self.user_message_and_token_cnt:
            raise ValueError("User message must be set before building prompt")

        final_messages_with_tokens: list[tuple[BaseMessage, int]] = []
        if self.system_message_and_token_cnt:
            final_messages_with_tokens.append(self.system_message_and_token_cnt)

        final_messages_with_tokens.extend(
            [
                (self.message_history[i], self.history_token_cnts[i])
                for i in range(len(self.message_history))
            ]
        )

        final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if self.new_messages_and_token_cnts:
            final_messages_with_tokens.extend(self.new_messages_and_token_cnts)

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
        )


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True
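
The file above centralizes prompt assembly: AnswerPromptBuilder.build() stitches together the optional system message, the translated history, the (possibly rewritten) user message, and any newly appended messages, then hands the token-annotated list to drop_messages_history_overflow to fit the model's input budget. As a rough illustration of that trimming step, here is a minimal standalone sketch; it is a stand-in, not onyx's actual drop_messages_history_overflow, whose exact policy may differ.

# --- illustrative sketch, not part of the commit ---
def drop_history_overflow(
    messages_with_tokens: list[tuple[str, int]], max_tokens: int
) -> list[str]:
    kept = list(messages_with_tokens)
    total = sum(cnt for _, cnt in kept)
    # Drop the oldest history first (index 1), always keeping the first
    # entry (system message) and the last entry (current user message).
    while total > max_tokens and len(kept) > 2:
        _, dropped_cnt = kept.pop(1)
        total -= dropped_cnt
    return [msg for msg, _ in kept]

msgs = [
    ("system: be helpful", 4),
    ("user: hi", 2),
    ("assistant: hello", 2),
    ("user: summarize the doc", 5),
]
print(drop_history_overflow(msgs, max_tokens=10))
# ['system: be helpful', 'user: summarize the doc']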

@@ -0,0 +1,179 @@
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.persona import get_default_prompt__read_only
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str
from onyx.prompts.prompt_utils import build_task_prompt_reminders
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
from onyx.prompts.token_counts import (
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
)
from onyx.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from onyx.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from onyx.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from onyx.utils.logger import setup_logger

logger = setup_logger()


def get_prompt_tokens(prompt_config: PromptConfig) -> int:
    # Note: currently only default prompts support datetime awareness;
    # custom prompts do not
    return (
        check_number_of_tokens(prompt_config.system_prompt)
        + check_number_of_tokens(prompt_config.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
    )


# buffer just to be safe so that we don't overflow the token limit due to
# a small miscalculation
_MISC_BUFFER = 40


def compute_max_document_tokens(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    actual_user_input: str | None = None,
    tool_token_count: int = 0,
    max_llm_token_override: int | None = None,
) -> int:
    """Estimates the number of tokens available for context documents. Formula is roughly:

    (
        model_context_window - reserved_output_tokens - prompt_tokens
        - (actual_user_input OR reserved_user_message_tokens) - buffer (just to be safe)
    )

    The actual_user_input is used at query time. If we are calculating this before knowing the exact input (e.g.
    if we're trying to determine if the user should be able to select another document) then we just set an
    arbitrary "upper bound".
    """
    # if we can't find a number of tokens, just assume some common default
    max_input_tokens = (
        max_llm_token_override
        if max_llm_token_override
        else get_max_input_tokens(
            model_name=llm_config.model_name, model_provider=llm_config.model_provider
        )
    )
    prompt_tokens = get_prompt_tokens(prompt_config)

    user_input_tokens = (
        check_number_of_tokens(actual_user_input)
        if actual_user_input is not None
        else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
    )

    return (
        max_input_tokens
        - prompt_tokens
        - user_input_tokens
        - tool_token_count
        - _MISC_BUFFER
    )


def compute_max_document_tokens_for_persona(
    persona: Persona,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
    return compute_max_document_tokens(
        prompt_config=PromptConfig.from_model(prompt),
        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
        actual_user_input=actual_user_input,
        max_llm_token_override=max_llm_token_override,
    )


def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int:
    """Maximum tokens allowed in the input to the LLM (of any type)."""
    input_tokens = get_max_input_tokens(
        model_name=llm_config.model_name, model_provider=llm_config.model_provider
    )
    return input_tokens - _MISC_BUFFER


def build_citations_system_message(
    prompt_config: PromptConfig,
) -> SystemMessage:
    system_prompt = prompt_config.system_prompt.strip()
    if prompt_config.include_citations:
        system_prompt += REQUIRE_CITATION_STATEMENT
    if prompt_config.datetime_aware:
        system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)

    return SystemMessage(content=system_prompt)


def build_citations_user_message(
    message: HumanMessage,
    prompt_config: PromptConfig,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    all_doc_useful: bool,
    history_message: str = "",
) -> HumanMessage:
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
        prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
    )

    history_block = (
        HISTORY_BLOCK.format(history_str=history_message) + "\n"
        if history_message
        else ""
    )
    query, img_urls = message_to_prompt_and_imgs(message)

    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
        optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

        user_prompt = CITATIONS_PROMPT.format(
            optional_ignore_statement=optional_ignore,
            context_docs_str=context_docs_str,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )
    else:
        # if no context docs provided, assume we're in the tool calling flow
        user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,
        )

    user_prompt = user_prompt.strip()
    user_msg = HumanMessage(
        content=build_content_with_imgs(user_prompt, img_urls=img_urls)
        if img_urls
        else user_prompt
    )

    return user_msg
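
The budget computed by compute_max_document_tokens above is straight subtraction once the pieces are known. A worked example with made-up numbers (the real values come from the model's context window, the tokenizer, and the token-count constants imported at the top of the file):

# --- illustrative sketch, not part of the commit; all numbers invented ---
max_input_tokens = 8192   # context window minus reserved output tokens
prompt_tokens = 350       # system + task prompts plus citation overheads
user_input_tokens = 512   # upper bound used before the real query is known
tool_token_count = 200    # tokens reserved for tool definitions/results
misc_buffer = 40          # plays the role of _MISC_BUFFER above

max_document_tokens = (
    max_input_tokens
    - prompt_tokens
    - user_input_tokens
    - tool_token_count
    - misc_buffer
)
print(max_document_tokens)  # 7090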

@@ -0,0 +1,61 @@
from langchain.schema.messages import HumanMessage

from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.configs.chat_configs import LANGUAGE_HINT
from onyx.context.search.models import InferenceChunk
from onyx.db.search_settings import get_multilingual_expansion
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
from onyx.prompts.prompt_utils import add_date_time_to_prompt
from onyx.prompts.prompt_utils import build_complete_context_str


def _build_strong_llm_quotes_prompt(
    question: str,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
) -> HumanMessage:
    use_language_hint = bool(get_multilingual_expansion())

    context_block = ""
    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
        context_block = CONTEXT_BLOCK.format(context_docs_str=context_docs_str)

    history_block = ""
    if history_str:
        history_block = HISTORY_BLOCK.format(history_str=history_str)

    full_prompt = JSON_PROMPT.format(
        system_prompt=prompt.system_prompt,
        context_block=context_block,
        history_block=history_block,
        task_prompt=prompt.task_prompt,
        user_query=question,
        language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
    ).strip()

    if prompt.datetime_aware:
        full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)

    return HumanMessage(content=full_prompt)


def build_quotes_user_message(
    message: HumanMessage,
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
) -> HumanMessage:
    query, _ = message_to_prompt_and_imgs(message)

    return _build_strong_llm_quotes_prompt(
        question=query,
        context_docs=context_docs,
        history_str=history_str,
        prompt=prompt,
    )
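
The quotes prompt above is assembled purely by string templating: the context and history blocks are rendered only when present, then substituted into JSON_PROMPT. A self-contained sketch of that composition, using hypothetical stand-in templates (the real CONTEXT_BLOCK, HISTORY_BLOCK, and JSON_PROMPT in onyx.prompts.direct_qa_prompts are more elaborate):

# --- illustrative sketch, not part of the commit; templates are stand-ins ---
CONTEXT_BLOCK = "CONTEXT:\n{context_docs_str}\n"
HISTORY_BLOCK = "HISTORY:\n{history_str}\n"
JSON_PROMPT = (
    "{system_prompt}\n"
    "{context_block}"
    "{history_block}"
    "{task_prompt}\n"
    "QUERY: {user_query}\n"
    "{language_hint_or_none}"
)

full_prompt = JSON_PROMPT.format(
    system_prompt="Answer as JSON with exact quotes.",
    context_block=CONTEXT_BLOCK.format(context_docs_str="Doc 1: ..."),
    history_block="",  # dropped entirely when there is no history
    task_prompt="Quote your sources verbatim.",
    user_query="What does the doc say?",
    language_hint_or_none="",
).strip()
print(full_prompt)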

@@ -0,0 +1,60 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from onyx.configs.constants import MessageType
from onyx.db.models import ChatMessage
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from onyx.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()


def translate_onyx_msg_to_langchain(
    msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts
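
translate_history_to_basemessages applies the same token_count != 0 filter in both comprehensions, so the returned message list and token-count list stay index-aligned; AnswerPromptBuilder.__init__ relies on that alignment when it unpacks the pair. A toy illustration of the filter, with plain tuples standing in for ChatMessage objects:

# --- illustrative sketch, not part of the commit ---
# Messages with a zero token_count (presumably empty) are dropped from
# both lists, keeping the two outputs index-aligned.
history = [("user", "hi", 2), ("assistant", "", 0), ("user", "bye", 1)]

base_msgs = [(role, text) for role, text, cnt in history if cnt != 0]
token_cnts = [cnt for _, _, cnt in history if cnt != 0]

print(base_msgs)   # [('user', 'hi'), ('user', 'bye')]
print(token_cnts)  # [2, 1]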