New assistants api (#3097)

parent 9d57f34c34
commit ba805f766f
@@ -288,6 +288,15 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
+    # NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
+    # below
+    op.execute("DELETE FROM chat_feedback")
+    op.execute("DELETE FROM chat_message__search_doc")
+    op.execute("DELETE FROM document_retrieval_feedback")
+    op.execute("DELETE FROM document_retrieval_feedback")
+    op.execute("DELETE FROM chat_message")
+    op.execute("DELETE FROM chat_session")
+
     op.drop_constraint(
         "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
     )
@@ -23,6 +23,56 @@ def upgrade() -> None:
 
 
 def downgrade() -> None:
+    # Delete chat messages and feedback first since they reference chat sessions
+    # Get chat messages from sessions with null persona_id
+    chat_messages_query = """
+        SELECT id
+        FROM chat_message
+        WHERE chat_session_id IN (
+            SELECT id
+            FROM chat_session
+            WHERE persona_id IS NULL
+        )
+    """
+
+    # Delete dependent records first
+    op.execute(
+        f"""
+        DELETE FROM document_retrieval_feedback
+        WHERE chat_message_id IN (
+            {chat_messages_query}
+        )
+        """
+    )
+    op.execute(
+        f"""
+        DELETE FROM chat_message__search_doc
+        WHERE chat_message_id IN (
+            {chat_messages_query}
+        )
+        """
+    )
+
+    # Delete chat messages
+    op.execute(
+        """
+        DELETE FROM chat_message
+        WHERE chat_session_id IN (
+            SELECT id
+            FROM chat_session
+            WHERE persona_id IS NULL
+        )
+        """
+    )
+
+    # Now we can safely delete the chat sessions
+    op.execute(
+        """
+        DELETE FROM chat_session
+        WHERE persona_id IS NULL
+        """
+    )
+
     op.alter_column(
         "chat_session",
         "persona_id",
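Both downgrade paths are destructive by design: as the comments note, the deletions exist to satisfy non-nullable constraints, and dependent chat rows are removed child-first so no foreign key is violated. A minimal sketch of invoking a downgrade deliberately (the `alembic.ini` path is an assumption about your deployment):

# Sketch: step one revision back programmatically.
# WARNING: per the migrations above, this deletes chat history.
from alembic import command
from alembic.config import Config

alembic_cfg = Config("alembic.ini")  # path to your alembic config (assumption)
command.downgrade(alembic_cfg, "-1")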
@@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
 from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.chat.models import StreamStopInfo
-from danswer.configs.app_configs import AZURE_DALLE_API_BASE
-from danswer.configs.app_configs import AZURE_DALLE_API_KEY
-from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
-from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
-from danswer.configs.chat_configs import BING_API_KEY
 from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import GEN_AI_TEMPERATURE
 from danswer.db.chat import attach_files_to_chat_message
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.chat import translate_db_search_doc_to_server_search_doc
 from danswer.db.engine import get_session_context_manager
-from danswer.db.llm import fetch_existing_llm_providers
 from danswer.db.models import SearchDoc as DbSearchDoc
 from danswer.db.models import ToolCall
 from danswer.db.models import User
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
 from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.factory import get_main_llm_from_tuple
-from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import litellm_exception_to_error_msg
 from danswer.natural_language_processing.utils import get_tokenizer
-from danswer.search.enums import LLMEvaluationType
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
 from danswer.search.models import InferenceSection
+from danswer.search.models import RetrievalDetails
 from danswer.search.retrieval.search_runner import inference_sections_from_ids
 from danswer.search.utils import chunks_or_sections_to_search_docs
 from danswer.search.utils import dedupe_documents
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
 from danswer.server.query_and_chat.models import ChatMessageDetail
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.utils import get_json_line
-from danswer.tools.built_in_tools import get_built_in_tool_by_id
 from danswer.tools.force import ForceUseTool
-from danswer.tools.models import DynamicSchemaInfo
 from danswer.tools.models import ToolResponse
 from danswer.tools.tool import Tool
-from danswer.tools.tool_implementations.custom.custom_tool import (
-    build_custom_tools_from_openapi_schema_and_headers,
-)
+from danswer.tools.tool_constructor import construct_tools
+from danswer.tools.tool_constructor import CustomToolConfig
+from danswer.tools.tool_constructor import ImageGenerationToolConfig
+from danswer.tools.tool_constructor import InternetSearchToolConfig
+from danswer.tools.tool_constructor import SearchToolConfig
 from danswer.tools.tool_implementations.custom.custom_tool import (
     CUSTOM_TOOL_RESPONSE_ID,
 )
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
 from danswer.tools.tool_implementations.images.image_generation_tool import (
     ImageGenerationResponse,
 )
-from danswer.tools.tool_implementations.images.image_generation_tool import (
-    ImageGenerationTool,
-)
 from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
     INTERNET_SEARCH_RESPONSE_ID,
 )
@@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
     SECTION_RELEVANCE_LIST_ID,
 )
 from danswer.tools.tool_runner import ToolCallFinalResult
-from danswer.tools.utils import compute_all_tool_tokens
-from danswer.tools.utils import explicit_tool_calling_supported
-from danswer.utils.headers import header_dict_to_header_list
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
 
@@ -295,7 +281,6 @@ def stream_chat_message_objects(
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
     # if specified, uses the last user message and does not create a new user message based
     # on the `new_msg_req.message`. Currently, requires a state where the last message is a
-    use_existing_user_message: bool = False,
     litellm_additional_headers: dict[str, str] | None = None,
     custom_tool_additional_headers: dict[str, str] | None = None,
     is_connected: Callable[[], bool] | None = None,
@@ -307,6 +292,9 @@ def stream_chat_message_objects(
     3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
     4. [always] Details on the final AI response message that is created
     """
+    use_existing_user_message = new_msg_req.use_existing_user_message
+    existing_assistant_message_id = new_msg_req.existing_assistant_message_id
+
     # Currently surrounding context is not supported for chat
     # Chat is already token heavy and harder for the model to process plus it would roll history over much faster
     new_msg_req.chunks_above = 0
@@ -428,12 +416,20 @@ def stream_chat_message_objects(
         final_msg, history_msgs = create_chat_chain(
             chat_session_id=chat_session_id, db_session=db_session
         )
-        if final_msg.message_type != MessageType.USER:
-            raise RuntimeError(
-                "The last message was not a user message. Cannot call "
-                "`stream_chat_message_objects` with `is_regenerate=True` "
-                "when the last message is not a user message."
-            )
+        if existing_assistant_message_id is None:
+            if final_msg.message_type != MessageType.USER:
+                raise RuntimeError(
+                    "The last message was not a user message. Cannot call "
+                    "`stream_chat_message_objects` with `is_regenerate=True` "
+                    "when the last message is not a user message."
+                )
+        else:
+            if final_msg.id != existing_assistant_message_id:
+                raise RuntimeError(
+                    "The last message was not the existing assistant message. "
+                    f"Final message id: {final_msg.id}, "
+                    f"existing assistant message id: {existing_assistant_message_id}"
+                )
 
         # Disable Query Rephrasing for the first message
         # This leads to a better first response since the LLM rephrasing the question
@@ -504,7 +500,12 @@ def stream_chat_message_objects(
             ),
             max_window_percentage=max_document_percentage,
         )
-        reserved_message_id = reserve_message_id(
-            db_session=db_session,
-            chat_session_id=chat_session_id,
-            parent_message=user_message.id
+
+        # we don't need to reserve a message id if we're using an existing assistant message
+        reserved_message_id = (
+            final_msg.id
+            if existing_assistant_message_id is not None
+            else reserve_message_id(
+                db_session=db_session,
+                chat_session_id=chat_session_id,
+                parent_message=user_message.id
@@ -512,6 +513,7 @@ def stream_chat_message_objects(
                 else parent_message.id,
                 message_type=MessageType.ASSISTANT,
             )
+        )
         yield MessageResponseIDInfo(
             user_message_id=user_message.id if user_message else None,
             reserved_assistant_message_id=reserved_message_id,
@@ -525,7 +527,13 @@ def stream_chat_message_objects(
         partial_response = partial(
             create_new_chat_message,
             chat_session_id=chat_session_id,
-            parent_message=final_msg,
+            # if we're using an existing assistant message, then this will just be an
+            # update operation, in which case the parent should be the parent of
+            # the latest. If we're creating a new assistant message, then the parent
+            # should be the latest message (latest user message)
+            parent_message=(
+                final_msg if existing_assistant_message_id is None else parent_message
+            ),
             prompt_id=prompt_id,
             overridden_model=overridden_model,
             # message=,
@@ -537,6 +545,7 @@ def stream_chat_message_objects(
             # reference_docs=,
             db_session=db_session,
             commit=False,
+            reserved_message_id=reserved_message_id,
         )
 
         if not final_msg.prompt:
@@ -560,142 +569,39 @@ def stream_chat_message_objects(
             structured_response_format=new_msg_req.structured_response_format,
         )
 
-        # find out what tools to use
-        search_tool: SearchTool | None = None
-        tool_dict: dict[int, list[Tool]] = {}  # tool_id to tool
-        for db_tool_model in persona.tools:
-            # handle in-code tools specially
-            if db_tool_model.in_code_tool_id:
-                tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
-                if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
-                    search_tool = SearchTool(
-                        db_session=db_session,
-                        user=user,
-                        persona=persona,
-                        retrieval_options=retrieval_options,
-                        prompt_config=prompt_config,
-                        llm=llm,
-                        fast_llm=fast_llm,
-                        pruning_config=document_pruning_config,
-                        answer_style_config=answer_style_config,
-                        selected_sections=selected_sections,
-                        chunks_above=new_msg_req.chunks_above,
-                        chunks_below=new_msg_req.chunks_below,
-                        full_doc=new_msg_req.full_doc,
-                        evaluation_type=(
-                            LLMEvaluationType.BASIC
-                            if persona.llm_relevance_filter
-                            else LLMEvaluationType.SKIP
-                        ),
-                    )
-                    tool_dict[db_tool_model.id] = [search_tool]
-                elif tool_cls.__name__ == ImageGenerationTool.__name__:
-                    img_generation_llm_config: LLMConfig | None = None
-                    if (
-                        llm
-                        and llm.config.api_key
-                        and llm.config.model_provider == "openai"
-                    ):
-                        img_generation_llm_config = LLMConfig(
-                            model_provider=llm.config.model_provider,
-                            model_name="dall-e-3",
-                            temperature=GEN_AI_TEMPERATURE,
-                            api_key=llm.config.api_key,
-                            api_base=llm.config.api_base,
-                            api_version=llm.config.api_version,
-                        )
-                    elif (
-                        llm.config.model_provider == "azure"
-                        and AZURE_DALLE_API_KEY is not None
-                    ):
-                        img_generation_llm_config = LLMConfig(
-                            model_provider="azure",
-                            model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
-                            temperature=GEN_AI_TEMPERATURE,
-                            api_key=AZURE_DALLE_API_KEY,
-                            api_base=AZURE_DALLE_API_BASE,
-                            api_version=AZURE_DALLE_API_VERSION,
-                        )
-                    else:
-                        llm_providers = fetch_existing_llm_providers(db_session)
-                        openai_provider = next(
-                            iter(
-                                [
-                                    llm_provider
-                                    for llm_provider in llm_providers
-                                    if llm_provider.provider == "openai"
-                                ]
-                            ),
-                            None,
-                        )
-                        if not openai_provider or not openai_provider.api_key:
-                            raise ValueError(
-                                "Image generation tool requires an OpenAI API key"
-                            )
-                        img_generation_llm_config = LLMConfig(
-                            model_provider=openai_provider.provider,
-                            model_name="dall-e-3",
-                            temperature=GEN_AI_TEMPERATURE,
-                            api_key=openai_provider.api_key,
-                            api_base=openai_provider.api_base,
-                            api_version=openai_provider.api_version,
-                        )
-                    tool_dict[db_tool_model.id] = [
-                        ImageGenerationTool(
-                            api_key=cast(str, img_generation_llm_config.api_key),
-                            api_base=img_generation_llm_config.api_base,
-                            api_version=img_generation_llm_config.api_version,
-                            additional_headers=litellm_additional_headers,
-                            model=img_generation_llm_config.model_name,
-                        )
-                    ]
-                elif tool_cls.__name__ == InternetSearchTool.__name__:
-                    bing_api_key = BING_API_KEY
-                    if not bing_api_key:
-                        raise ValueError(
-                            "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
-                        )
-                    tool_dict[db_tool_model.id] = [
-                        InternetSearchTool(
-                            api_key=bing_api_key,
-                            answer_style_config=answer_style_config,
-                            prompt_config=prompt_config,
-                        )
-                    ]
-
-                continue
-
-            # handle all custom tools
-            if db_tool_model.openapi_schema:
-                tool_dict[db_tool_model.id] = cast(
-                    list[Tool],
-                    build_custom_tools_from_openapi_schema_and_headers(
-                        db_tool_model.openapi_schema,
-                        dynamic_schema_info=DynamicSchemaInfo(
-                            chat_session_id=chat_session_id,
-                            message_id=user_message.id if user_message else None,
-                        ),
-                        custom_headers=(db_tool_model.custom_headers or [])
-                        + (
-                            header_dict_to_header_list(
-                                custom_tool_additional_headers or {}
-                            )
-                        ),
-                    ),
-                )
+        tool_dict = construct_tools(
+            persona=persona,
+            prompt_config=prompt_config,
+            db_session=db_session,
+            user=user,
+            llm=llm,
+            fast_llm=fast_llm,
+            search_tool_config=SearchToolConfig(
+                answer_style_config=answer_style_config,
+                document_pruning_config=document_pruning_config,
+                retrieval_options=retrieval_options or RetrievalDetails(),
+                selected_sections=selected_sections,
+                chunks_above=new_msg_req.chunks_above,
+                chunks_below=new_msg_req.chunks_below,
+                full_doc=new_msg_req.full_doc,
+                latest_query_files=latest_query_files,
+            ),
+            internet_search_tool_config=InternetSearchToolConfig(
+                answer_style_config=answer_style_config,
+            ),
+            image_generation_tool_config=ImageGenerationToolConfig(
+                additional_headers=litellm_additional_headers,
+            ),
+            custom_tool_config=CustomToolConfig(
+                chat_session_id=chat_session_id,
+                message_id=user_message.id if user_message else None,
+                additional_headers=custom_tool_additional_headers,
+            ),
+        )
 
         tools: list[Tool] = []
         for tool_list in tool_dict.values():
             tools.extend(tool_list)
 
-        # factor in tool definition size when pruning
-        document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
-            tools, llm_tokenizer
-        )
-        document_pruning_config.using_tool_message = explicit_tool_calling_supported(
-            llm_provider, llm_model_name
-        )
-
         # LLM prompt building, response capturing, etc.
         answer = Answer(
             is_connected=is_connected,
@@ -871,7 +777,6 @@ def stream_chat_message_objects(
                 tool_name_to_tool_id[tool.name] = tool_id
 
         gen_ai_response_message = partial_response(
-            reserved_message_id=reserved_message_id,
            message=answer.llm_answer,
            rephrased_query=(
                qa_docs_response.rephrased_query if qa_docs_response else None
@@ -879,9 +784,11 @@ def stream_chat_message_objects(
             reference_docs=reference_db_search_docs,
             files=ai_message_files,
             token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
-            citations=message_specific_citations.citation_map
-            if message_specific_citations
-            else None,
+            citations=(
+                message_specific_citations.citation_map
+                if message_specific_citations
+                else None
+            ),
             error=None,
             tool_call=(
                 ToolCall(
@@ -915,7 +822,6 @@ def stream_chat_message(
 def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
-    use_existing_user_message: bool = False,
     litellm_additional_headers: dict[str, str] | None = None,
     custom_tool_additional_headers: dict[str, str] | None = None,
     is_connected: Callable[[], bool] | None = None,
@@ -925,7 +831,6 @@ def stream_chat_message(
         new_msg_req=new_msg_req,
         user=user,
         db_session=db_session,
-        use_existing_user_message=use_existing_user_message,
         litellm_additional_headers=litellm_additional_headers,
         custom_tool_additional_headers=custom_tool_additional_headers,
         is_connected=is_connected,
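Taken together, the `stream_chat_message_objects` changes move `use_existing_user_message` off the function signatures and onto `CreateChatMessageRequest`, and add an `existing_assistant_message_id` field so a caller can pre-create an empty assistant message and have the stream fill it in instead of reserving a fresh id. A sketch of the caller side, mirroring how the runs API below builds its request (all ids and the session UUID are illustrative placeholders):

# Sketch: request shape for streaming into a pre-created assistant message.
from uuid import UUID

from danswer.server.query_and_chat.models import CreateChatMessageRequest

new_msg_req = CreateChatMessageRequest(
    chat_session_id=UUID("00000000-0000-0000-0000-000000000000"),
    parent_message_id=41,  # the existing user message (placeholder id)
    message="",
    file_descriptors=[],
    prompt_id=1,
    search_doc_ids=None,
    retrieval_options=None,
    query_override=None,
    regenerate=None,
    llm_override=None,
    prompt_override=None,
    use_existing_user_message=True,  # reuse the last user message
    existing_assistant_message_id=42,  # stream the answer into this message
)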
@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
     return tool
 
 
+def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
+    tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
+    if not tool:
+        raise ValueError("Tool by specified name does not exist")
+    return tool
+
+
 def create_tool(
     name: str,
     description: str | None,
@@ -37,7 +44,7 @@ def create_tool(
         description=description,
         in_code_tool_id=None,
         openapi_schema=openapi_schema,
-        custom_headers=[header.dict() for header in custom_headers]
+        custom_headers=[header.model_dump() for header in custom_headers]
         if custom_headers
         else [],
         user_id=user_id,
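`get_tool_by_name` is the lookup the new assistants API uses to resolve a request's `tools[].type` string against `Tool.name`; the `.dict()` to `.model_dump()` switch follows the Pydantic v2 rename. A small usage sketch (the tool name "SearchTool" and a seeded database are assumptions):

# Sketch: look up a tool row by name, as create_assistant does.
from danswer.db.engine import get_session_context_manager
from danswer.db.tools import get_tool_by_name

with get_session_context_manager() as db_session:
    try:
        tool = get_tool_by_name("SearchTool", db_session)
        print(tool.id, tool.name)
    except ValueError:
        print("no tool with that name")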
@@ -74,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
 from danswer.server.manage.slack_bot import router as slack_bot_management_router
 from danswer.server.manage.users import router as user_router
 from danswer.server.middleware.latency_logging import add_latency_logging_middleware
+from danswer.server.openai_assistants_api.full_openai_assistants_api import (
+    get_full_openai_assistants_api_router,
+)
 from danswer.server.query_and_chat.chat_backend import router as chat_router
 from danswer.server.query_and_chat.query_backend import (
     admin_router as admin_query_router,
@@ -270,6 +273,9 @@ def get_application() -> FastAPI:
         application, token_rate_limit_settings_router
     )
     include_router_with_global_prefix_prepended(application, indexing_router)
+    include_router_with_global_prefix_prepended(
+        application, get_full_openai_assistants_api_router()
+    )
 
     if AUTH_TYPE == AuthType.DISABLED:
         # Server logs this during auth setup verification step
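With this registration, every route defined by the new routers is reachable under the deployment's global API prefix plus `/openai-assistants` (the sub-router prefixes appear in the new files below). A sketch of the resulting URL shape, assuming a local server with no extra global prefix and auth disabled:

# Sketch: hit the newly mounted assistants listing endpoint.
# Host/port are deployment-specific assumptions.
import requests

BASE = "http://localhost:8080"

resp = requests.get(f"{BASE}/openai-assistants/assistants", params={"limit": 5})
print(resp.status_code, resp.json()["data"])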
backend/danswer/server/openai_assistants_api/asssistants_api.py (new file, 273 lines)

from typing import Any
from typing import Optional
from uuid import uuid4

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger

logger = setup_logger()


router = APIRouter(prefix="/assistants")


# Base models
class AssistantObject(BaseModel):
    id: int
    object: str = "assistant"
    created_at: int
    name: Optional[str] = None
    description: Optional[str] = None
    model: str
    instructions: Optional[str] = None
    tools: list[dict[str, Any]]
    file_ids: list[str]
    metadata: Optional[dict[str, Any]] = None


class CreateAssistantRequest(BaseModel):
    model: str
    name: Optional[str] = None
    description: Optional[str] = None
    instructions: Optional[str] = None
    tools: Optional[list[dict[str, Any]]] = None
    file_ids: Optional[list[str]] = None
    metadata: Optional[dict[str, Any]] = None


class ModifyAssistantRequest(BaseModel):
    model: Optional[str] = None
    name: Optional[str] = None
    description: Optional[str] = None
    instructions: Optional[str] = None
    tools: Optional[list[dict[str, Any]]] = None
    file_ids: Optional[list[str]] = None
    metadata: Optional[dict[str, Any]] = None


class DeleteAssistantResponse(BaseModel):
    id: int
    object: str = "assistant.deleted"
    deleted: bool


class ListAssistantsResponse(BaseModel):
    object: str = "list"
    data: list[AssistantObject]
    first_id: Optional[int] = None
    last_id: Optional[int] = None
    has_more: bool


def persona_to_assistant(persona: Persona) -> AssistantObject:
    return AssistantObject(
        id=persona.id,
        created_at=0,
        name=persona.name,
        description=persona.description,
        model=persona.llm_model_version_override or "gpt-3.5-turbo",
        instructions=persona.prompts[0].system_prompt if persona.prompts else None,
        tools=[
            {
                "type": tool.display_name,
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "schema": tool.openapi_schema,
                },
            }
            for tool in persona.tools
        ],
        file_ids=[],  # Assuming no file support for now
        metadata={},  # Assuming no metadata for now
    )


# API endpoints
@router.post("")
def create_assistant(
    request: CreateAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    prompt = None
    if request.instructions:
        prompt = upsert_prompt(
            user=user,
            name=f"Prompt for {request.name or 'New Assistant'}",
            description="Auto-generated prompt",
            system_prompt=request.instructions,
            task_prompt="",
            include_citations=True,
            datetime_aware=True,
            personas=[],
            db_session=db_session,
        )

    tool_ids = []
    for tool in request.tools or []:
        tool_type = tool.get("type")
        if not tool_type:
            continue

        try:
            tool_db = get_tool_by_name(tool_type, db_session)
            tool_ids.append(tool_db.id)
        except ValueError:
            # Skip tools that don't exist in the database
            logger.error(f"Tool {tool_type} not found in database")
            raise HTTPException(
                status_code=404, detail=f"Tool {tool_type} not found in database"
            )

    persona = upsert_persona(
        user=user,
        name=request.name or f"Assistant-{uuid4()}",
        description=request.description or "",
        num_chunks=25,
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        llm_model_provider_override=None,
        llm_model_version_override=request.model,
        starter_messages=None,
        is_public=False,
        db_session=db_session,
        prompt_ids=[prompt.id] if prompt else [0],
        document_set_ids=[],
        tool_ids=tool_ids,
        icon_color=None,
        icon_shape=None,
        is_visible=True,
    )

    if prompt:
        prompt.personas = [persona]
        db_session.commit()

    return persona_to_assistant(persona)


""


@router.get("/{assistant_id}")
def retrieve_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    try:
        persona = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    except ValueError:
        persona = None

    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return persona_to_assistant(persona)


@router.post("/{assistant_id}")
def modify_assistant(
    assistant_id: int,
    request: ModifyAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    persona = get_persona_by_id(
        persona_id=assistant_id,
        user=user,
        db_session=db_session,
        is_for_edit=True,
    )
    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")

    update_data = request.model_dump(exclude_unset=True)
    for key, value in update_data.items():
        setattr(persona, key, value)

    if "instructions" in update_data and persona.prompts:
        persona.prompts[0].system_prompt = update_data["instructions"]

    db_session.commit()
    return persona_to_assistant(persona)


@router.delete("/{assistant_id}")
def delete_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
    try:
        mark_persona_as_deleted(
            persona_id=int(assistant_id),
            user=user,
            db_session=db_session,
        )
        return DeleteAssistantResponse(id=assistant_id, deleted=True)
    except ValueError:
        raise HTTPException(status_code=404, detail="Assistant not found")


@router.get("")
def list_assistants(
    limit: int = Query(20, le=100),
    order: str = Query("desc", regex="^(asc|desc)$"),
    after: Optional[int] = None,
    before: Optional[int] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
    personas = list(
        get_personas(
            user=user,
            db_session=db_session,
            get_editable=False,
            joinedload_all=True,
        )
    )

    # Apply filtering based on after and before
    if after:
        personas = [p for p in personas if p.id > int(after)]
    if before:
        personas = [p for p in personas if p.id < int(before)]

    # Apply ordering
    personas.sort(key=lambda p: p.id, reverse=(order == "desc"))

    # Apply limit
    personas = personas[:limit]

    assistants = [persona_to_assistant(p) for p in personas]

    return ListAssistantsResponse(
        data=assistants,
        first_id=assistants[0].id if assistants else None,
        last_id=assistants[-1].id if assistants else None,
        has_more=len(personas) == limit,
    )
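A sketch of driving the endpoints above over HTTP. Note that `create_assistant` resolves each `tools[].type` against `Tool.name` via `get_tool_by_name` and returns 404 for unknown names; the base URL, the absence of auth, and the "SearchTool" name are assumptions:

# Sketch: create and fetch an assistant (paths assume the /openai-assistants mount).
import requests

BASE = "http://localhost:8080/openai-assistants"  # deployment-specific assumption

created = requests.post(
    f"{BASE}/assistants",
    json={
        "model": "gpt-4",
        "name": "Support Helper",
        "instructions": "Answer using internal docs.",
        "tools": [{"type": "SearchTool"}],  # must match a Tool.name in the DB
    },
).json()

fetched = requests.get(f"{BASE}/assistants/{created['id']}").json()
print(fetched["name"], fetched["model"])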
backend/danswer/server/openai_assistants_api/full_openai_assistants_api.py (new file, 19 lines)

from fastapi import APIRouter

from danswer.server.openai_assistants_api.asssistants_api import (
    router as assistants_router,
)
from danswer.server.openai_assistants_api.messages_api import router as messages_router
from danswer.server.openai_assistants_api.runs_api import router as runs_router
from danswer.server.openai_assistants_api.threads_api import router as threads_router


def get_full_openai_assistants_api_router() -> APIRouter:
    router = APIRouter(prefix="/openai-assistants")

    router.include_router(assistants_router)
    router.include_router(runs_router)
    router.include_router(threads_router)
    router.include_router(messages_router)

    return router
backend/danswer/server/openai_assistants_api/messages_api.py (new file, 235 lines)

import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import check_number_of_tokens

router = APIRouter(prefix="")


Role = Literal["user", "assistant"]


class MessageContent(BaseModel):
    type: Literal["text"]
    text: str


class Message(BaseModel):
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
    object: Literal["thread.message"] = "thread.message"
    created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
    thread_id: str
    role: Role
    content: list[MessageContent]
    file_ids: list[str] = []
    assistant_id: Optional[str] = None
    run_id: Optional[str] = None
    metadata: Optional[dict[str, Any]] = None  # Change this line to use dict[str, Any]


class CreateMessageRequest(BaseModel):
    role: Role
    content: str
    file_ids: list[str] = []
    metadata: Optional[dict] = None


class ListMessagesResponse(BaseModel):
    object: Literal["list"] = "list"
    data: list[Message]
    first_id: str
    last_id: str
    has_more: bool


@router.post("/threads/{thread_id}/messages")
def create_message(
    thread_id: str,
    message: CreateMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message=message.content,
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=check_number_of_tokens(message.content),
        message_type=(
            MessageType.USER if message.role == "user" else MessageType.ASSISTANT
        ),
        db_session=db_session,
    )

    return Message(
        id=str(new_message.id),
        thread_id=thread_id,
        role="user",
        content=[MessageContent(type="text", text=message.content)],
        file_ids=message.file_ids,
        metadata=message.metadata,
    )


@router.get("/threads/{thread_id}/messages")
def list_messages(
    thread_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Chat session not found")

    messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )

    # Apply filtering based on after and before
    if after:
        messages = [m for m in messages if str(m.id) >= after]
    if before:
        messages = [m for m in messages if str(m.id) <= before]

    # Apply ordering
    messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))

    # Apply limit
    messages = messages[:limit]

    data = [
        Message(
            id=str(m.id),
            thread_id=thread_id,
            role="user" if m.message_type == "user" else "assistant",
            content=[MessageContent(type="text", text=m.message)],
            created_at=int(m.time_sent.timestamp()),
        )
        for m in messages
    ]

    return ListMessagesResponse(
        data=data,
        first_id=str(data[0].id) if data else "",
        last_id=str(data[-1].id) if data else "",
        has_more=len(messages) == limit,
    )


@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
    thread_id: str,
    message_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    return Message(
        id=str(chat_message.id),
        thread_id=thread_id,
        role="user" if chat_message.message_type == "user" else "assistant",
        content=[MessageContent(type="text", text=chat_message.message)],
        created_at=int(chat_message.time_sent.timestamp()),
    )


class ModifyMessageRequest(BaseModel):
    metadata: dict


@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
    thread_id: str,
    message_id: int,
    request: ModifyMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    user_id = user.id if user else None

    try:
        chat_message = get_chat_message(
            chat_message_id=message_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    # Update metadata
    # TODO: Uncomment this once we have metadata in the chat message
    # chat_message.metadata = request.metadata
    # db_session.commit()

    return Message(
        id=str(chat_message.id),
        thread_id=thread_id,
        role="user" if chat_message.message_type == "user" else "assistant",
        content=[MessageContent(type="text", text=chat_message.message)],
        created_at=int(chat_message.time_sent.timestamp()),
        metadata=request.metadata,
    )
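A sketch of the message flow against these endpoints. A `thread_id` is a chat session UUID (threads themselves are created via the threads API, which is outside this excerpt), and message ids round-trip as stringified chat message ids; the base URL and thread UUID are placeholders:

# Sketch: add a user message to an existing thread, then page through messages.
import requests

BASE = "http://localhost:8080/openai-assistants"  # deployment-specific assumption
thread_id = "00000000-0000-0000-0000-000000000000"  # placeholder chat session UUID

msg = requests.post(
    f"{BASE}/threads/{thread_id}/messages",
    json={"role": "user", "content": "What changed in the last release?"},
).json()

listing = requests.get(
    f"{BASE}/threads/{thread_id}/messages", params={"limit": 10, "order": "asc"}
).json()
print(msg["id"], listing["has_more"])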
backend/danswer/server/openai_assistants_api/runs_api.py (new file, 344 lines)

from typing import Literal
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.utils.logger import setup_logger


logger = setup_logger()


router = APIRouter()


class RunRequest(BaseModel):
    assistant_id: int
    model: Optional[str] = None
    instructions: Optional[str] = None
    additional_instructions: Optional[str] = None
    tools: Optional[list[dict]] = None
    metadata: Optional[dict] = None


RunStatus = Literal[
    "queued",
    "in_progress",
    "requires_action",
    "cancelling",
    "cancelled",
    "failed",
    "completed",
    "expired",
]


class RunResponse(BaseModel):
    id: str
    object: Literal["thread.run"]
    created_at: int
    assistant_id: int
    thread_id: UUID
    status: RunStatus
    started_at: Optional[int] = None
    expires_at: Optional[int] = None
    cancelled_at: Optional[int] = None
    failed_at: Optional[int] = None
    completed_at: Optional[int] = None
    last_error: Optional[dict] = None
    model: str
    instructions: str
    tools: list[dict]
    file_ids: list[str]
    metadata: Optional[dict] = None


def process_run_in_background(
    message_id: int,
    parent_message_id: int,
    chat_session_id: UUID,
    assistant_id: int,
    instructions: str,
    tools: list[dict],
    user: User | None,
    db_session: Session,
) -> None:
    # Get the latest message in the chat session
    chat_session = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    search_tool_retrieval_details = RetrievalDetails()
    for tool in tools:
        if tool["type"] == SearchTool.__name__ and (
            retrieval_details := tool.get("retrieval_details")
        ):
            search_tool_retrieval_details = RetrievalDetails.model_validate(
                retrieval_details
            )
            break

    new_msg_req = CreateChatMessageRequest(
        chat_session_id=chat_session_id,
        parent_message_id=int(parent_message_id) if parent_message_id else None,
        message=instructions,
        file_descriptors=[],
        prompt_id=chat_session.persona.prompts[0].id,
        search_doc_ids=None,
        retrieval_options=search_tool_retrieval_details,  # Adjust as needed
        query_override=None,
        regenerate=None,
        llm_override=None,
        prompt_override=None,
        alternate_assistant_id=assistant_id,
        use_existing_user_message=True,
        existing_assistant_message_id=message_id,
    )

    run_message = get_chat_message(message_id, user.id if user else None, db_session)
    try:
        for packet in stream_chat_message_objects(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ):
            if isinstance(packet, ChatMessageDetail):
                # Update the run status and message content
                run_message = get_chat_message(
                    message_id, user.id if user else None, db_session
                )
                if run_message:
                    # this handles cancelling
                    if run_message.error:
                        return

                    run_message.message = packet.message
                    run_message.message_type = MessageType.ASSISTANT
                    db_session.commit()
    except Exception as e:
        logger.exception("Error processing run in background")
        run_message.error = str(e)
        db_session.commit()
        return

    db_session.refresh(run_message)
    if run_message.token_count == 0:
        run_message.error = "No tokens generated"
        db_session.commit()


@router.post("/threads/{thread_id}/runs")
def create_run(
    thread_id: UUID,
    run_request: RunRequest,
    background_tasks: BackgroundTasks,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    # Create a new "run" (chat message) in the session
    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message="",
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=0,
        message_type=MessageType.ASSISTANT,
        db_session=db_session,
        commit=False,
    )
    db_session.flush()
    latest_message.latest_child_message = new_message.id
    db_session.commit()

    # Schedule the background task
    background_tasks.add_task(
        process_run_in_background,
        new_message.id,
        latest_message.id,
        chat_session.id,
        run_request.assistant_id,
        run_request.instructions or "",
        run_request.tools or [],
        user,
        db_session,
    )

    return RunResponse(
        id=str(new_message.id),
        object="thread.run",
        created_at=int(new_message.time_sent.timestamp()),
        assistant_id=run_request.assistant_id,
        thread_id=chat_session.id,
        status="queued",
        model=run_request.model or "default_model",
        instructions=run_request.instructions or "",
        tools=run_request.tools or [],
        file_ids=[],
        metadata=run_request.metadata,
    )


@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    # Retrieve the chat message (which represents a "run" in DAnswer)
    chat_message = get_chat_message(
        chat_message_id=int(run_id),  # Convert string run_id to int
        user_id=user.id if user else None,
        db_session=db_session,
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_session = chat_message.chat_session

    # Map DAnswer status to OpenAI status
    run_status: RunStatus = "queued"
    if chat_message.message:
        run_status = "in_progress"
    if chat_message.token_count != 0:
        run_status = "completed"
    if chat_message.error:
        run_status = "cancelled"

    return RunResponse(
        id=run_id,
        object="thread.run",
        created_at=int(chat_message.time_sent.timestamp()),
        assistant_id=chat_session.persona_id or 0,
        thread_id=chat_session.id,
        status=run_status,
        started_at=int(chat_message.time_sent.timestamp()),
        completed_at=(
            int(chat_message.time_sent.timestamp()) if chat_message.message else None
        ),
        model=chat_session.current_alternate_model or "default_model",
        instructions="",  # DAnswer doesn't store per-message instructions
        tools=[],  # DAnswer doesn't have a direct equivalent for tools
        file_ids=(
            [file["id"] for file in chat_message.files] if chat_message.files else []
        ),
        metadata=None,  # DAnswer doesn't store metadata for individual messages
    )


@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    # In DAnswer, we don't have a direct equivalent to cancelling a run
    # We'll simulate it by marking the message as "cancelled"
    chat_message = (
        db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
    )
    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message.error = "Cancelled"
    db_session.commit()

    return retrieve_run(thread_id, run_id, user, db_session)


@router.get("/threads/{thread_id}/runs")
def list_runs(
    thread_id: UUID,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[RunResponse]:
    # In DAnswer, we'll treat each message in a chat session as a "run"
    chat_messages = get_chat_messages_by_session(
        chat_session_id=thread_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    # Apply pagination
    if after:
        chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
    if before:
        chat_messages = [msg for msg in chat_messages if str(msg.id) < before]

    # Apply ordering
    chat_messages = sorted(
        chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
    )

    # Apply limit
    chat_messages = chat_messages[:limit]

    return [
        retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
    ]


@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
    run_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[dict]:  # You may want to create a specific model for run steps
    # DAnswer doesn't have an equivalent to run steps
|
||||||
|
# We'll return an empty list to maintain API compatibility
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# Additional helper functions can be added here if needed
|
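Note: because of the status mapping above, a client can simply poll retrieve_run until the run leaves "queued"/"in_progress". A minimal polling sketch, not part of this commit; the base URL and the DANSWER_API_KEY env var are assumptions about a local deployment:

import os
import time

import requests

BASE_URL = "http://localhost:8080/openai-assistants"  # assumed deployment


def poll_run(thread_id: str, run_id: str, timeout_s: float = 60.0) -> dict:
    """Poll retrieve_run until the status mapping reports a terminal state."""
    headers = {"Authorization": f"Bearer {os.environ.get('DANSWER_API_KEY', '')}"}
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        resp = requests.get(
            f"{BASE_URL}/threads/{thread_id}/runs/{run_id}", headers=headers
        )
        resp.raise_for_status()
        run = resp.json()
        if run["status"] in ("completed", "cancelled"):  # terminal states above
            return run
        time.sleep(0.5)
    raise TimeoutError(f"run {run_id} did not finish within {timeout_s}s")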
156 backend/danswer/server/openai_assistants_api/threads_api.py Normal file
@ -0,0 +1,156 @@
from typing import Optional
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.db.chat import create_chat_session
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import update_chat_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter(prefix="/threads")


# Models
class Thread(BaseModel):
    id: UUID
    object: str = "thread"
    created_at: int
    metadata: Optional[dict[str, str]] = None


class CreateThreadRequest(BaseModel):
    messages: Optional[list[dict]] = None
    metadata: Optional[dict[str, str]] = None


class ModifyThreadRequest(BaseModel):
    metadata: Optional[dict[str, str]] = None


# API Endpoints
@router.post("")
def create_thread(
    request: CreateThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    new_chat_session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=user_id,
        persona_id=0,
    )

    return Thread(
        id=new_chat_session.id,
        created_at=int(new_chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.get("/{thread_id}")
def retrieve_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=None,  # Assuming we don't store metadata in our current implementation
    )


@router.post("/{thread_id}")
def modify_thread(
    thread_id: UUID,
    request: ModifyThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Thread:
    user_id = user.id if user else None
    try:
        chat_session = update_chat_session(
            db_session=db_session,
            user_id=user_id,
            chat_session_id=thread_id,
            description=None,  # Not updating description
            sharing_status=None,  # Not updating sharing status
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return Thread(
        id=chat_session.id,
        created_at=int(chat_session.time_created.timestamp()),
        metadata=request.metadata,
    )


@router.delete("/{thread_id}")
def delete_thread(
    thread_id: UUID,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> dict:
    user_id = user.id if user else None
    try:
        delete_chat_session(
            user_id=user_id,
            chat_session_id=thread_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}


@router.get("")
def list_threads(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    user_id = user.id if user else None
    chat_sessions = get_chat_sessions_by_user(
        user_id=user_id,
        deleted=False,
        db_session=db_session,
    )

    return ChatSessionsResponse(
        sessions=[
            ChatSessionDetails(
                id=chat.id,
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                shared_status=chat.shared_status,
                folder_id=chat.folder_id,
                current_alternate_model=chat.current_alternate_model,
            )
            for chat in chat_sessions
        ]
    )
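Note: together these endpoints cover the full thread lifecycle over plain HTTP. A minimal sketch, not part of this commit; the base URL and API-key env var are assumed deployment details:

import os

import requests

BASE_URL = "http://localhost:8080/openai-assistants/threads"  # assumed deployment
HEADERS = {"Authorization": f"Bearer {os.environ.get('DANSWER_API_KEY', '')}"}

# Create a thread (backed by a Danswer chat session)
thread = requests.post(BASE_URL, json={"metadata": {"k": "v"}}, headers=HEADERS).json()

# Retrieve it; note that metadata is echoed on create/modify but not persisted
fetched = requests.get(f"{BASE_URL}/{thread['id']}", headers=HEADERS).json()
assert fetched["object"] == "thread"

# Delete it
deleted = requests.delete(f"{BASE_URL}/{thread['id']}", headers=HEADERS).json()
assert deleted["object"] == "thread.deleted" and deleted["deleted"] is True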
@ -347,7 +347,6 @@ def handle_new_chat_message(
     for packet in stream_chat_message(
         new_msg_req=chat_message_req,
         user=user,
-        use_existing_user_message=chat_message_req.use_existing_user_message,
         litellm_additional_headers=extract_headers(
             request.headers, LITELLM_PASS_THROUGH_HEADERS
         ),
@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
     # used for seeded chats to kick off the generation of an AI answer
     use_existing_user_message: bool = False
 
+    # used for "OpenAI Assistants API"
+    existing_assistant_message_id: int | None = None
+
     # forces the LLM to return a structured response, see
     # https://platform.openai.com/docs/guides/structured-outputs/introduction
     structured_response_format: dict | None = None
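Note: the new existing_assistant_message_id field lets the Assistants-API run flow point stream_chat_message at the placeholder assistant message that create_run already inserted, rather than creating another one. A hypothetical request body illustrating the field (the field names come from CreateChatMessageRequest; the id values are made up):

payload = {
    "chat_session_id": "7f4e1c1e-0000-0000-0000-000000000000",  # the thread
    "message": "",                            # no new user text
    "use_existing_user_message": True,        # reuse the latest user message
    "existing_assistant_message_id": 42,      # placeholder row from create_run
}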
255 backend/danswer/tools/tool_constructor.py Normal file
@ -0,0 +1,255 @@
from typing import cast
from uuid import UUID

from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
    build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
    ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
    InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger

logger = setup_logger()


def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
    """Helper function to get image generation LLM config based on available providers"""
    if llm and llm.config.api_key and llm.config.model_provider == "openai":
        return LLMConfig(
            model_provider=llm.config.model_provider,
            model_name="dall-e-3",
            temperature=GEN_AI_TEMPERATURE,
            api_key=llm.config.api_key,
            api_base=llm.config.api_base,
            api_version=llm.config.api_version,
        )

    if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
        return LLMConfig(
            model_provider="azure",
            model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
            temperature=GEN_AI_TEMPERATURE,
            api_key=AZURE_DALLE_API_KEY,
            api_base=AZURE_DALLE_API_BASE,
            api_version=AZURE_DALLE_API_VERSION,
        )

    # Fallback to checking for OpenAI provider in database
    llm_providers = fetch_existing_llm_providers(db_session)
    openai_provider = next(
        iter(
            [
                llm_provider
                for llm_provider in llm_providers
                if llm_provider.provider == "openai"
            ]
        ),
        None,
    )

    if not openai_provider or not openai_provider.api_key:
        raise ValueError("Image generation tool requires an OpenAI API key")

    return LLMConfig(
        model_provider=openai_provider.provider,
        model_name="dall-e-3",
        temperature=GEN_AI_TEMPERATURE,
        api_key=openai_provider.api_key,
        api_base=openai_provider.api_base,
        api_version=openai_provider.api_version,
    )


class SearchToolConfig(BaseModel):
    answer_style_config: AnswerStyleConfig = Field(
        default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
    )
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
    selected_sections: list[InferenceSection] | None = None
    chunks_above: int = 0
    chunks_below: int = 0
    full_doc: bool = False
    latest_query_files: list[InMemoryChatFile] | None = None


class InternetSearchToolConfig(BaseModel):
    answer_style_config: AnswerStyleConfig = Field(
        default_factory=lambda: AnswerStyleConfig(
            citation_config=CitationConfig(all_docs_useful=True)
        )
    )


class ImageGenerationToolConfig(BaseModel):
    additional_headers: dict[str, str] | None = None


class CustomToolConfig(BaseModel):
    chat_session_id: UUID | None = None
    message_id: int | None = None
    additional_headers: dict[str, str] | None = None


def construct_tools(
    persona: Persona,
    prompt_config: PromptConfig,
    db_session: Session,
    user: User | None,
    llm: LLM,
    fast_llm: LLM,
    search_tool_config: SearchToolConfig | None = None,
    internet_search_tool_config: InternetSearchToolConfig | None = None,
    image_generation_tool_config: ImageGenerationToolConfig | None = None,
    custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
    """Constructs tools based on persona configuration and available APIs"""
    tool_dict: dict[int, list[Tool]] = {}

    for db_tool_model in persona.tools:
        if db_tool_model.in_code_tool_id:
            tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)

            # Handle Search Tool
            if tool_cls.__name__ == SearchTool.__name__:
                if not search_tool_config:
                    search_tool_config = SearchToolConfig()

                search_tool = SearchTool(
                    db_session=db_session,
                    user=user,
                    persona=persona,
                    retrieval_options=search_tool_config.retrieval_options,
                    prompt_config=prompt_config,
                    llm=llm,
                    fast_llm=fast_llm,
                    pruning_config=search_tool_config.document_pruning_config,
                    answer_style_config=search_tool_config.answer_style_config,
                    selected_sections=search_tool_config.selected_sections,
                    chunks_above=search_tool_config.chunks_above,
                    chunks_below=search_tool_config.chunks_below,
                    full_doc=search_tool_config.full_doc,
                    evaluation_type=(
                        LLMEvaluationType.BASIC
                        if persona.llm_relevance_filter
                        else LLMEvaluationType.SKIP
                    ),
                )
                tool_dict[db_tool_model.id] = [search_tool]

            # Handle Image Generation Tool
            elif tool_cls.__name__ == ImageGenerationTool.__name__:
                if not image_generation_tool_config:
                    image_generation_tool_config = ImageGenerationToolConfig()

                img_generation_llm_config = _get_image_generation_config(
                    llm, db_session
                )

                tool_dict[db_tool_model.id] = [
                    ImageGenerationTool(
                        api_key=cast(str, img_generation_llm_config.api_key),
                        api_base=img_generation_llm_config.api_base,
                        api_version=img_generation_llm_config.api_version,
                        additional_headers=image_generation_tool_config.additional_headers,
                        model=img_generation_llm_config.model_name,
                    )
                ]

            # Handle Internet Search Tool
            elif tool_cls.__name__ == InternetSearchTool.__name__:
                if not internet_search_tool_config:
                    internet_search_tool_config = InternetSearchToolConfig()

                if not BING_API_KEY:
                    raise ValueError(
                        "Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
                    )
                tool_dict[db_tool_model.id] = [
                    InternetSearchTool(
                        api_key=BING_API_KEY,
                        answer_style_config=internet_search_tool_config.answer_style_config,
                        prompt_config=prompt_config,
                    )
                ]

        # Handle custom tools
        elif db_tool_model.openapi_schema:
            if not custom_tool_config:
                custom_tool_config = CustomToolConfig()

            tool_dict[db_tool_model.id] = cast(
                list[Tool],
                build_custom_tools_from_openapi_schema_and_headers(
                    db_tool_model.openapi_schema,
                    dynamic_schema_info=DynamicSchemaInfo(
                        chat_session_id=custom_tool_config.chat_session_id,
                        message_id=custom_tool_config.message_id,
                    ),
                    custom_headers=(db_tool_model.custom_headers or [])
                    + (
                        header_dict_to_header_list(
                            custom_tool_config.additional_headers or {}
                        )
                    ),
                ),
            )

    tools: list[Tool] = []
    for tool_list in tool_dict.values():
        tools.extend(tool_list)

    # factor in tool definition size when pruning
    if search_tool_config:
        search_tool_config.document_pruning_config.tool_num_tokens = (
            compute_all_tool_tokens(
                tools,
                get_tokenizer(
                    model_name=llm.config.model_name,
                    provider_type=llm.config.model_provider,
                ),
            )
        )
        search_tool_config.document_pruning_config.using_tool_message = (
            explicit_tool_calling_supported(
                llm.config.model_provider, llm.config.model_name
            )
        )

    return tool_dict
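Note: construct_tools is the single entry point that turns a persona's tool rows into live Tool instances. A usage sketch, not part of this commit; the imported helpers (get_session_context_manager, get_persona_by_id, get_default_llms, PromptConfig.from_model) are assumptions based on the surrounding codebase:

from danswer.db.engine import get_session_context_manager
from danswer.db.persona import get_persona_by_id
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_default_llms

with get_session_context_manager() as db_session:
    persona = get_persona_by_id(persona_id=0, user=None, db_session=db_session)
    llm, fast_llm = get_default_llms()
    tool_dict = construct_tools(
        persona=persona,
        prompt_config=PromptConfig.from_model(persona.prompts[0]),
        db_session=db_session,
        user=None,
        llm=llm,
        fast_llm=fast_llm,
    )
    # keys are DB tool ids; values are the instantiated Tool objects
    all_tools = [tool for tools in tool_dict.values() for tool in tools]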
@ -13,6 +13,14 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
 from tests.integration.common_utils.test_models import DATestUser
 
 
+DOMAIN = "test.com"
+DEFAULT_PASSWORD = "test"
+
+
+def build_email(name: str) -> str:
+    return f"{name}@test.com"
+
+
 class UserManager:
     @staticmethod
     def create(
@ -23,9 +31,9 @@ class UserManager:
         name = f"test{str(uuid4())}"
 
         if email is None:
-            email = f"{name}@test.com"
+            email = build_email(name)
 
-        password = "test"
+        password = DEFAULT_PASSWORD
 
         body = {
             "email": email,
55 backend/tests/integration/openai_assistants_api/conftest.py Normal file
@ -0,0 +1,55 @@
from typing import Optional
from uuid import UUID

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants"


@pytest.fixture
def admin_user() -> DATestUser | None:
    try:
        return UserManager.create("admin_user")
    except Exception:
        pass

    try:
        return UserManager.login_as_user(
            DATestUser(
                id="",
                email=build_email("admin_user"),
                password=DEFAULT_PASSWORD,
                headers=GENERAL_HEADERS,
            )
        )
    except Exception:
        pass

    return None


@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
    return LLMProviderManager.create(user_performing_action=admin_user)


@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
    # Create a thread to use in the tests
    response = requests.post(
        f"{BASE_URL}/threads",  # Updated endpoint path
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return UUID(response.json()["id"])
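Note: a test that declares all three fixtures gets an authenticated user, a configured LLM provider, and a fresh thread, resolved in that dependency order. A sketch, not part of this commit:

def test_run_lifecycle(
    admin_user: DATestUser | None,
    llm_provider: DATestLLMProvider,
    thread_id: UUID,
) -> None:
    # thread_id itself depends on admin_user, so the thread is created
    # with the admin's credentials before the test body executes
    ...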
@ -0,0 +1,151 @@
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"


def test_create_assistant(admin_user: DATestUser | None) -> None:
    response = requests.post(
        ASSISTANTS_URL,
        json={
            "model": "gpt-3.5-turbo",
            "name": "Test Assistant",
            "description": "A test assistant",
            "instructions": "You are a helpful assistant.",
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["name"] == "Test Assistant"
    assert data["description"] == "A test assistant"
    assert data["model"] == "gpt-3.5-turbo"
    assert data["instructions"] == "You are a helpful assistant."


def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, retrieve the assistant
    response = requests.get(
        f"{ASSISTANTS_URL}/{assistant_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["name"] == "Retrieve Test"


def test_modify_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, modify the assistant
    response = requests.post(
        f"{ASSISTANTS_URL}/{assistant_id}",
        json={"name": "Modified Assistant", "instructions": "New instructions"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["name"] == "Modified Assistant"
    assert data["instructions"] == "New instructions"


def test_delete_assistant(admin_user: DATestUser | None) -> None:
    # First, create an assistant
    create_response = requests.post(
        ASSISTANTS_URL,
        json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    assistant_id = create_response.json()["id"]

    # Now, delete the assistant
    response = requests.delete(
        f"{ASSISTANTS_URL}/{assistant_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["id"] == assistant_id
    assert data["deleted"] is True


def test_list_assistants(admin_user: DATestUser | None) -> None:
    # Create multiple assistants
    for i in range(3):
        requests.post(
            ASSISTANTS_URL,
            json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the assistants
    response = requests.get(
        ASSISTANTS_URL,
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert data["object"] == "list"
    assert len(data["data"]) >= 3  # At least the 3 we just created
    assert all(assistant["object"] == "assistant" for assistant in data["data"])


def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
    # Create 5 assistants
    for i in range(5):
        requests.post(
            ASSISTANTS_URL,
            json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # List assistants with limit
    response = requests.get(
        f"{ASSISTANTS_URL}?limit=2",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data["data"]) == 2
    assert data["has_more"] is True

    # Get next page
    before = data["last_id"]
    response = requests.get(
        f"{ASSISTANTS_URL}?limit=2&before={before}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    data = response.json()
    assert len(data["data"]) == 2


def test_assistant_not_found(admin_user: DATestUser | None) -> None:
    non_existent_id = -99
    response = requests.get(
        f"{ASSISTANTS_URL}/{non_existent_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404
133 backend/tests/integration/openai_assistants_api/test_messages.py Normal file
@ -0,0 +1,133 @@
import uuid
from typing import Optional

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"


@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
    response = requests.post(
        BASE_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return response.json()["id"]


def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",  # URL structure matches API
        json={
            "role": "user",
            "content": "Hello, world!",
            "file_ids": [],
            "metadata": {"key": "value"},
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["thread_id"] == thread_id
    assert response_json["role"] == "user"
    assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
    assert response_json["metadata"] == {"key": "value"}


def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )

    # Now, list the messages
    response = requests.get(
        f"{BASE_URL}/{thread_id}/messages",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["object"] == "list"
    assert isinstance(response_json["data"], list)
    assert len(response_json["data"]) > 0
    assert "first_id" in response_json
    assert "last_id" in response_json
    assert "has_more" in response_json


def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    create_response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    message_id = create_response.json()["id"]

    # Now, retrieve the message
    response = requests.get(
        f"{BASE_URL}/{thread_id}/messages/{message_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["id"] == message_id
    assert response_json["thread_id"] == thread_id
    assert response_json["role"] == "user"
    assert response_json["content"] == [{"type": "text", "text": "Test message"}]


def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
    # Create a message first
    create_response = requests.post(
        f"{BASE_URL}/{thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    message_id = create_response.json()["id"]

    # Now, modify the message
    response = requests.post(
        f"{BASE_URL}/{thread_id}/messages/{message_id}",
        json={"metadata": {"new_key": "new_value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert response_json["id"] == message_id
    assert response_json["thread_id"] == thread_id
    assert response_json["metadata"] == {"new_key": "new_value"}


def test_error_handling(admin_user: Optional[DATestUser]) -> None:
    non_existent_thread_id = str(uuid.uuid4())
    non_existent_message_id = -99

    # Test with non-existent thread
    response = requests.post(
        f"{BASE_URL}/{non_existent_thread_id}/messages",
        json={"role": "user", "content": "Test message"},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404

    # Test with non-existent message
    response = requests.get(
        f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 404
137 backend/tests/integration/openai_assistants_api/test_runs.py Normal file
@ -0,0 +1,137 @@
from uuid import UUID

import pytest
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser

BASE_URL = f"{API_SERVER_URL}/openai-assistants"


@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
    """Create a run and return its ID."""
    response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs",
        json={
            "assistant_id": 0,
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200
    return response.json()["id"]


def test_create_run(
    admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
    response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs",
        json={
            "assistant_id": 0,
            "model": "gpt-3.5-turbo",
            "instructions": "Test instructions",
        },
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["object"] == "thread.run"
    assert "created_at" in response_json
    assert response_json["assistant_id"] == 0
    assert UUID(response_json["thread_id"]) == thread_id
    assert response_json["status"] == "queued"
    assert response_json["model"] == "gpt-3.5-turbo"
    assert response_json["instructions"] == "Test instructions"


def test_retrieve_run(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    retrieve_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert retrieve_response.status_code == 200

    response_json = retrieve_response.json()
    assert response_json["id"] == run_id
    assert response_json["object"] == "thread.run"
    assert "created_at" in response_json
    assert UUID(response_json["thread_id"]) == thread_id


def test_cancel_run(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    cancel_response = requests.post(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert cancel_response.status_code == 200

    response_json = cancel_response.json()
    assert response_json["id"] == run_id
    assert response_json["status"] == "cancelled"


def test_list_runs(
    admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
    # Create a few runs
    for _ in range(3):
        requests.post(
            f"{BASE_URL}/threads/{thread_id}/runs",
            json={
                "assistant_id": 0,
            },
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the runs
    list_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert list_response.status_code == 200

    response_json = list_response.json()
    assert isinstance(response_json, list)
    assert len(response_json) >= 3

    for run in response_json:
        assert "id" in run
        assert run["object"] == "thread.run"
        assert "created_at" in run
        assert UUID(run["thread_id"]) == thread_id
        assert "status" in run
        assert "model" in run


def test_list_run_steps(
    admin_user: DATestUser | None,
    thread_id: UUID,
    run_id: str,
    llm_provider: DATestLLMProvider,
) -> None:
    steps_response = requests.get(
        f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert steps_response.status_code == 200

    response_json = steps_response.json()
    assert isinstance(response_json, list)
    # Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
    assert len(response_json) == 0
132 backend/tests/integration/openai_assistants_api/test_threads.py Normal file
@ -0,0 +1,132 @@
from uuid import UUID

import requests

from danswer.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser

THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"


def test_create_thread(admin_user: DATestUser | None) -> None:
    response = requests.post(
        THREADS_URL,
        json={"messages": None, "metadata": {"key": "value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert response.status_code == 200

    response_json = response.json()
    assert "id" in response_json
    assert response_json["object"] == "thread"
    assert "created_at" in response_json
    assert response_json["metadata"] == {"key": "value"}


def test_retrieve_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, retrieve the thread
    retrieve_response = requests.get(
        f"{THREADS_URL}/{thread_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert retrieve_response.status_code == 200

    response_json = retrieve_response.json()
    assert response_json["id"] == thread_id
    assert response_json["object"] == "thread"
    assert "created_at" in response_json


def test_modify_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, modify the thread
    modify_response = requests.post(
        f"{THREADS_URL}/{thread_id}",
        json={"metadata": {"new_key": "new_value"}},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert modify_response.status_code == 200

    response_json = modify_response.json()
    assert response_json["id"] == thread_id
    assert response_json["metadata"] == {"new_key": "new_value"}


def test_delete_thread(admin_user: DATestUser | None) -> None:
    # First, create a thread
    create_response = requests.post(
        THREADS_URL,
        json={},
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert create_response.status_code == 200
    thread_id = create_response.json()["id"]

    # Now, delete the thread
    delete_response = requests.delete(
        f"{THREADS_URL}/{thread_id}",
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert delete_response.status_code == 200

    response_json = delete_response.json()
    assert response_json["id"] == thread_id
    assert response_json["object"] == "thread.deleted"
    assert response_json["deleted"] is True


def test_list_threads(admin_user: DATestUser | None) -> None:
    # Create a few threads
    for _ in range(3):
        requests.post(
            THREADS_URL,
            json={},
            headers=admin_user.headers if admin_user else GENERAL_HEADERS,
        )

    # Now, list the threads
    list_response = requests.get(
        THREADS_URL,
        headers=admin_user.headers if admin_user else GENERAL_HEADERS,
    )
    assert list_response.status_code == 200

    response_json = list_response.json()
    assert "sessions" in response_json
    assert len(response_json["sessions"]) >= 3

    for session in response_json["sessions"]:
        assert "id" in session
        assert "name" in session
        assert "persona_id" in session
        assert "time_created" in session
        assert "shared_status" in session
        assert "folder_id" in session
        assert "current_alternate_model" in session

        # Validate UUID
        UUID(session["id"])

        # Validate shared_status
        assert session["shared_status"] in [
            status.value for status in ChatSessionSharedStatus
        ]
125 examples/assistants-api/topics_analyzer.py Normal file
@ -0,0 +1,125 @@
import argparse
import os
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from openai import OpenAI


ASSISTANT_NAME = "Topic Analyzer"
SYSTEM_PROMPT = """
You are a helpful assistant that analyzes topics by searching through available \
documents and providing insights. These available documents come from common \
workplace tools like Slack, emails, Confluence, Google Drive, etc.

When analyzing a topic:
1. Search for relevant information using the search tool
2. Synthesize the findings into clear insights
3. Highlight key trends, patterns, or notable developments
4. Maintain objectivity and cite sources where relevant
"""
USER_PROMPT = """
Please analyze and provide insights about this topic: {topic}.

IMPORTANT: do not mention things that are not relevant to the specified topic. \
If there is no relevant information, just say "No relevant information found."
"""


def wait_on_run(client: OpenAI, run, thread):
    while run.status == "queued" or run.status == "in_progress":
        run = client.beta.threads.runs.retrieve(
            thread_id=thread.id,
            run_id=run.id,
        )
        time.sleep(0.5)
    return run


def show_response(messages) -> None:
    # Get only the assistant's response text
    for message in messages.data[::-1]:
        if message.role == "assistant":
            for content in message.content:
                if content.type == "text":
                    print(content.text)
            break


def analyze_topics(topics: list[str]) -> None:
    openai_api_key = os.environ.get(
        "OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
    )
    danswer_api_key = os.environ.get(
        "DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
    )
    client = OpenAI(
        api_key=openai_api_key,
        base_url="http://localhost:8080/openai-assistants",
        default_headers={
            "Authorization": f"Bearer {danswer_api_key}",
        },
    )

    # Delete any pre-existing Topic Analyzer assistant so we start fresh
    try:
        assistants = client.beta.assistants.list(limit=100)
        # Find the Topic Analyzer assistant if it exists
        assistant = next(a for a in assistants.data if a.name == ASSISTANT_NAME)
        client.beta.assistants.delete(assistant.id)
    except Exception:
        pass

    assistant = client.beta.assistants.create(
        name=ASSISTANT_NAME,
        instructions=SYSTEM_PROMPT,
        tools=[{"type": "SearchTool"}],  # type: ignore
        model="gpt-4o",
    )

    # Process each topic individually
    for topic in topics:
        thread = client.beta.threads.create()
        message = client.beta.threads.messages.create(
            thread_id=thread.id,
            role="user",
            content=USER_PROMPT.format(topic=topic),
        )

        run = client.beta.threads.runs.create(
            thread_id=thread.id,
            assistant_id=assistant.id,
            tools=[
                {  # type: ignore
                    "type": "SearchTool",
                    "retrieval_details": {
                        "run_search": "always",
                        "filters": {
                            "time_cutoff": str(
                                datetime.now(timezone.utc) - timedelta(days=7)
                            )
                        },
                    },
                }
            ],
        )

        run = wait_on_run(client, run, thread)
        messages = client.beta.threads.messages.list(
            thread_id=thread.id, order="asc", after=message.id
        )
        print(f"\nAnalysis for topic: {topic}")
        print("-" * 40)
        show_response(messages)
        print()


# Example usage
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze specific topics")
    parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)")

    args = parser.parse_args()
    analyze_topics(args.topics)
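Note: with a Danswer API server listening locally, the example can be run directly, e.g.

    OPENAI_API_KEY=... DANSWER_API_KEY=... python examples/assistants-api/topics_analyzer.py "hiring" "infra costs"

Each positional argument becomes one topic to analyze, searched against the last seven days of indexed documents per the time_cutoff filter above.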