mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-16 23:00:31 +02:00
New assistants api (#3097)
This commit is contained in:
parent
9d57f34c34
commit
ba805f766f
@ -288,6 +288,15 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
|
||||
# below
|
||||
op.execute("DELETE FROM chat_feedback")
|
||||
op.execute("DELETE FROM chat_message__search_doc")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM document_retrieval_feedback")
|
||||
op.execute("DELETE FROM chat_message")
|
||||
op.execute("DELETE FROM chat_session")
|
||||
|
||||
op.drop_constraint(
|
||||
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
|
||||
)
|
||||
|
@ -23,6 +23,56 @@ def upgrade() -> None:
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Delete chat messages and feedback first since they reference chat sessions
|
||||
# Get chat messages from sessions with null persona_id
|
||||
chat_messages_query = """
|
||||
SELECT id
|
||||
FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
|
||||
# Delete dependent records first
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM document_retrieval_feedback
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
f"""
|
||||
DELETE FROM chat_message__search_doc
|
||||
WHERE chat_message_id IN (
|
||||
{chat_messages_query}
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Delete chat messages
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_message
|
||||
WHERE chat_session_id IN (
|
||||
SELECT id
|
||||
FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Now we can safely delete the chat sessions
|
||||
op.execute(
|
||||
"""
|
||||
DELETE FROM chat_session
|
||||
WHERE persona_id IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
|
@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_constructor import construct_tools
|
||||
from danswer.tools.tool_constructor import CustomToolConfig
|
||||
from danswer.tools.tool_constructor import ImageGenerationToolConfig
|
||||
from danswer.tools.tool_constructor import InternetSearchToolConfig
|
||||
from danswer.tools.tool_constructor import SearchToolConfig
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
CUSTOM_TOOL_RESPONSE_ID,
|
||||
)
|
||||
@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationResponse,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
INTERNET_SEARCH_RESPONSE_ID,
|
||||
)
|
||||
@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@ -295,7 +281,6 @@ def stream_chat_message_objects(
|
||||
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
|
||||
# if specified, uses the last user message and does not create a new user message based
|
||||
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@ -307,6 +292,9 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
# Currently surrounding context is not supported for chat
|
||||
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
|
||||
new_msg_req.chunks_above = 0
|
||||
@ -428,12 +416,20 @@ def stream_chat_message_objects(
|
||||
final_msg, history_msgs = create_chat_chain(
|
||||
chat_session_id=chat_session_id, db_session=db_session
|
||||
)
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
if existing_assistant_message_id is None:
|
||||
if final_msg.message_type != MessageType.USER:
|
||||
raise RuntimeError(
|
||||
"The last message was not a user message. Cannot call "
|
||||
"`stream_chat_message_objects` with `is_regenerate=True` "
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
else:
|
||||
if final_msg.id != existing_assistant_message_id:
|
||||
raise RuntimeError(
|
||||
"The last message was not the existing assistant message. "
|
||||
f"Final message id: {final_msg.id}, "
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
@ -504,13 +500,19 @@ def stream_chat_message_objects(
|
||||
),
|
||||
max_window_percentage=max_document_percentage,
|
||||
)
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
|
||||
# we don't need to reserve a message id if we're using an existing assistant message
|
||||
reserved_message_id = (
|
||||
final_msg.id
|
||||
if existing_assistant_message_id is not None
|
||||
else reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=user_message.id
|
||||
if user_message is not None
|
||||
else parent_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id if user_message else None,
|
||||
@ -525,7 +527,13 @@ def stream_chat_message_objects(
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=final_msg,
|
||||
# if we're using an existing assistant message, then this will just be an
|
||||
# update operation, in which case the parent should be the parent of
|
||||
# the latest. If we're creating a new assistant message, then the parent
|
||||
# should be the latest message (latest user message)
|
||||
parent_message=(
|
||||
final_msg if existing_assistant_message_id is None else parent_message
|
||||
),
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
# message=,
|
||||
@ -537,6 +545,7 @@ def stream_chat_message_objects(
|
||||
# reference_docs=,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
if not final_msg.prompt:
|
||||
@ -560,142 +569,39 @@ def stream_chat_message_objects(
|
||||
structured_response_format=new_msg_req.structured_response_format,
|
||||
)
|
||||
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_style_config,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
elif (
|
||||
llm.config.model_provider == "azure"
|
||||
and AZURE_DALLE_API_KEY is not None
|
||||
):
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
bing_api_key = BING_API_KEY
|
||||
if not bing_api_key:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=bing_api_key,
|
||||
answer_style_config=answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tool_dict = construct_tools(
|
||||
persona=persona,
|
||||
prompt_config=prompt_config,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
search_tool_config=SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
),
|
||||
image_generation_tool_config=ImageGenerationToolConfig(
|
||||
additional_headers=litellm_additional_headers,
|
||||
),
|
||||
custom_tool_config=CustomToolConfig(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
is_connected=is_connected,
|
||||
@ -871,7 +777,6 @@ def stream_chat_message_objects(
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
qa_docs_response.rephrased_query if qa_docs_response else None
|
||||
@ -879,9 +784,11 @@ def stream_chat_message_objects(
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None,
|
||||
citations=(
|
||||
message_specific_citations.citation_map
|
||||
if message_specific_citations
|
||||
else None
|
||||
),
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
@ -915,7 +822,6 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
@ -925,7 +831,6 @@ def stream_chat_message(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
is_connected=is_connected,
|
||||
|
@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
return tool
|
||||
|
||||
|
||||
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified name does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
@ -37,7 +44,7 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
custom_headers=[header.model_dump() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
|
@ -74,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
|
||||
from danswer.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from danswer.server.manage.users import router as user_router
|
||||
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
)
|
||||
from danswer.server.query_and_chat.chat_backend import router as chat_router
|
||||
from danswer.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@ -270,6 +273,9 @@ def get_application() -> FastAPI:
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, get_full_openai_assistants_api_router()
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
273
backend/danswer/server/openai_assistants_api/asssistants_api.py
Normal file
273
backend/danswer/server/openai_assistants_api/asssistants_api.py
Normal file
@ -0,0 +1,273 @@
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
from danswer.db.tools import get_tool_by_name
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/assistants")
|
||||
|
||||
|
||||
# Base models
|
||||
class AssistantObject(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant"
|
||||
created_at: int
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
model: str
|
||||
instructions: Optional[str] = None
|
||||
tools: list[dict[str, Any]]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class CreateAssistantRequest(BaseModel):
|
||||
model: str
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class ModifyAssistantRequest(BaseModel):
|
||||
model: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
tools: Optional[list[dict[str, Any]]] = None
|
||||
file_ids: Optional[list[str]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
|
||||
|
||||
class DeleteAssistantResponse(BaseModel):
|
||||
id: int
|
||||
object: str = "assistant.deleted"
|
||||
deleted: bool
|
||||
|
||||
|
||||
class ListAssistantsResponse(BaseModel):
|
||||
object: str = "list"
|
||||
data: list[AssistantObject]
|
||||
first_id: Optional[int] = None
|
||||
last_id: Optional[int] = None
|
||||
has_more: bool
|
||||
|
||||
|
||||
def persona_to_assistant(persona: Persona) -> AssistantObject:
|
||||
return AssistantObject(
|
||||
id=persona.id,
|
||||
created_at=0,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
model=persona.llm_model_version_override or "gpt-3.5-turbo",
|
||||
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
|
||||
tools=[
|
||||
{
|
||||
"type": tool.display_name,
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"schema": tool.openapi_schema,
|
||||
},
|
||||
}
|
||||
for tool in persona.tools
|
||||
],
|
||||
file_ids=[], # Assuming no file support for now
|
||||
metadata={}, # Assuming no metadata for now
|
||||
)
|
||||
|
||||
|
||||
# API endpoints
|
||||
@router.post("")
|
||||
def create_assistant(
|
||||
request: CreateAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
prompt = None
|
||||
if request.instructions:
|
||||
prompt = upsert_prompt(
|
||||
user=user,
|
||||
name=f"Prompt for {request.name or 'New Assistant'}",
|
||||
description="Auto-generated prompt",
|
||||
system_prompt=request.instructions,
|
||||
task_prompt="",
|
||||
include_citations=True,
|
||||
datetime_aware=True,
|
||||
personas=[],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
tool_ids = []
|
||||
for tool in request.tools or []:
|
||||
tool_type = tool.get("type")
|
||||
if not tool_type:
|
||||
continue
|
||||
|
||||
try:
|
||||
tool_db = get_tool_by_name(tool_type, db_session)
|
||||
tool_ids.append(tool_db.id)
|
||||
except ValueError:
|
||||
# Skip tools that don't exist in the database
|
||||
logger.error(f"Tool {tool_type} not found in database")
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tool {tool_type} not found in database"
|
||||
)
|
||||
|
||||
persona = upsert_persona(
|
||||
user=user,
|
||||
name=request.name or f"Assistant-{uuid4()}",
|
||||
description=request.description or "",
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=request.model,
|
||||
starter_messages=None,
|
||||
is_public=False,
|
||||
db_session=db_session,
|
||||
prompt_ids=[prompt.id] if prompt else [0],
|
||||
document_set_ids=[],
|
||||
tool_ids=tool_ids,
|
||||
icon_color=None,
|
||||
icon_shape=None,
|
||||
is_visible=True,
|
||||
)
|
||||
|
||||
if prompt:
|
||||
prompt.personas = [persona]
|
||||
db_session.commit()
|
||||
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
""
|
||||
|
||||
|
||||
@router.get("/{assistant_id}")
|
||||
def retrieve_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
try:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
except ValueError:
|
||||
persona = None
|
||||
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.post("/{assistant_id}")
|
||||
def modify_assistant(
|
||||
assistant_id: int,
|
||||
request: ModifyAssistantRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantObject:
|
||||
persona = get_persona_by_id(
|
||||
persona_id=assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=True,
|
||||
)
|
||||
if not persona:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
update_data = request.model_dump(exclude_unset=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(persona, key, value)
|
||||
|
||||
if "instructions" in update_data and persona.prompts:
|
||||
persona.prompts[0].system_prompt = update_data["instructions"]
|
||||
|
||||
db_session.commit()
|
||||
return persona_to_assistant(persona)
|
||||
|
||||
|
||||
@router.delete("/{assistant_id}")
|
||||
def delete_assistant(
|
||||
assistant_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DeleteAssistantResponse:
|
||||
try:
|
||||
mark_persona_as_deleted(
|
||||
persona_id=int(assistant_id),
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
)
|
||||
return DeleteAssistantResponse(id=assistant_id, deleted=True)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Assistant not found")
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_assistants(
|
||||
limit: int = Query(20, le=100),
|
||||
order: str = Query("desc", regex="^(asc|desc)$"),
|
||||
after: Optional[int] = None,
|
||||
before: Optional[int] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListAssistantsResponse:
|
||||
personas = list(
|
||||
get_personas(
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
get_editable=False,
|
||||
joinedload_all=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
personas = [p for p in personas if p.id > int(after)]
|
||||
if before:
|
||||
personas = [p for p in personas if p.id < int(before)]
|
||||
|
||||
# Apply ordering
|
||||
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
personas = personas[:limit]
|
||||
|
||||
assistants = [persona_to_assistant(p) for p in personas]
|
||||
|
||||
return ListAssistantsResponse(
|
||||
data=assistants,
|
||||
first_id=assistants[0].id if assistants else None,
|
||||
last_id=assistants[-1].id if assistants else None,
|
||||
has_more=len(personas) == limit,
|
||||
)
|
@ -0,0 +1,19 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.server.openai_assistants_api.asssistants_api import (
|
||||
router as assistants_router,
|
||||
)
|
||||
from danswer.server.openai_assistants_api.messages_api import router as messages_router
|
||||
from danswer.server.openai_assistants_api.runs_api import router as runs_router
|
||||
from danswer.server.openai_assistants_api.threads_api import router as threads_router
|
||||
|
||||
|
||||
def get_full_openai_assistants_api_router() -> APIRouter:
|
||||
router = APIRouter(prefix="/openai-assistants")
|
||||
|
||||
router.include_router(assistants_router)
|
||||
router.include_router(runs_router)
|
||||
router.include_router(threads_router)
|
||||
router.include_router(messages_router)
|
||||
|
||||
return router
|
235
backend/danswer/server/openai_assistants_api/messages_api.py
Normal file
235
backend/danswer/server/openai_assistants_api/messages_api.py
Normal file
@ -0,0 +1,235 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
|
||||
router = APIRouter(prefix="")
|
||||
|
||||
|
||||
Role = Literal["user", "assistant"]
|
||||
|
||||
|
||||
class MessageContent(BaseModel):
|
||||
type: Literal["text"]
|
||||
text: str
|
||||
|
||||
|
||||
class Message(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
|
||||
object: Literal["thread.message"] = "thread.message"
|
||||
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
||||
thread_id: str
|
||||
role: Role
|
||||
content: list[MessageContent]
|
||||
file_ids: list[str] = []
|
||||
assistant_id: Optional[str] = None
|
||||
run_id: Optional[str] = None
|
||||
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
|
||||
|
||||
|
||||
class CreateMessageRequest(BaseModel):
|
||||
role: Role
|
||||
content: str
|
||||
file_ids: list[str] = []
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
class ListMessagesResponse(BaseModel):
|
||||
object: Literal["list"] = "list"
|
||||
data: list[Message]
|
||||
first_id: str
|
||||
last_id: str
|
||||
has_more: bool
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages")
|
||||
def create_message(
|
||||
thread_id: str,
|
||||
message: CreateMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message=message.content,
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=check_number_of_tokens(message.content),
|
||||
message_type=(
|
||||
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
|
||||
),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return Message(
|
||||
id=str(new_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user",
|
||||
content=[MessageContent(type="text", text=message.content)],
|
||||
file_ids=message.file_ids,
|
||||
metadata=message.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages")
|
||||
def list_messages(
|
||||
thread_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ListMessagesResponse:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=uuid.UUID(thread_id),
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Chat session not found")
|
||||
|
||||
messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply filtering based on after and before
|
||||
if after:
|
||||
messages = [m for m in messages if str(m.id) >= after]
|
||||
if before:
|
||||
messages = [m for m in messages if str(m.id) <= before]
|
||||
|
||||
# Apply ordering
|
||||
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
|
||||
|
||||
# Apply limit
|
||||
messages = messages[:limit]
|
||||
|
||||
data = [
|
||||
Message(
|
||||
id=str(m.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if m.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=m.message)],
|
||||
created_at=int(m.time_sent.timestamp()),
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
|
||||
return ListMessagesResponse(
|
||||
data=data,
|
||||
first_id=str(data[0].id) if data else "",
|
||||
last_id=str(data[-1].id) if data else "",
|
||||
has_more=len(messages) == limit,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/messages/{message_id}")
|
||||
def retrieve_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
)
|
||||
|
||||
|
||||
class ModifyMessageRequest(BaseModel):
|
||||
metadata: dict
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/messages/{message_id}")
|
||||
def modify_message(
|
||||
thread_id: str,
|
||||
message_id: int,
|
||||
request: ModifyMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Message:
|
||||
user_id = user.id if user else None
|
||||
|
||||
try:
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=message_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Message not found")
|
||||
|
||||
# Update metadata
|
||||
# TODO: Uncomment this once we have metadata in the chat message
|
||||
# chat_message.metadata = request.metadata
|
||||
# db_session.commit()
|
||||
|
||||
return Message(
|
||||
id=str(chat_message.id),
|
||||
thread_id=thread_id,
|
||||
role="user" if chat_message.message_type == "user" else "assistant",
|
||||
content=[MessageContent(type="text", text=chat_message.message)],
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
344
backend/danswer/server/openai_assistants_api/runs_api.py
Normal file
344
backend/danswer/server/openai_assistants_api/runs_api.py
Normal file
@ -0,0 +1,344 @@
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class RunRequest(BaseModel):
|
||||
assistant_id: int
|
||||
model: Optional[str] = None
|
||||
instructions: Optional[str] = None
|
||||
additional_instructions: Optional[str] = None
|
||||
tools: Optional[list[dict]] = None
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
RunStatus = Literal[
|
||||
"queued",
|
||||
"in_progress",
|
||||
"requires_action",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
"failed",
|
||||
"completed",
|
||||
"expired",
|
||||
]
|
||||
|
||||
|
||||
class RunResponse(BaseModel):
|
||||
id: str
|
||||
object: Literal["thread.run"]
|
||||
created_at: int
|
||||
assistant_id: int
|
||||
thread_id: UUID
|
||||
status: RunStatus
|
||||
started_at: Optional[int] = None
|
||||
expires_at: Optional[int] = None
|
||||
cancelled_at: Optional[int] = None
|
||||
failed_at: Optional[int] = None
|
||||
completed_at: Optional[int] = None
|
||||
last_error: Optional[dict] = None
|
||||
model: str
|
||||
instructions: str
|
||||
tools: list[dict]
|
||||
file_ids: list[str]
|
||||
metadata: Optional[dict] = None
|
||||
|
||||
|
||||
def process_run_in_background(
|
||||
message_id: int,
|
||||
parent_message_id: int,
|
||||
chat_session_id: UUID,
|
||||
assistant_id: int,
|
||||
instructions: str,
|
||||
tools: list[dict],
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Get the latest message in the chat session
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
search_tool_retrieval_details = RetrievalDetails()
|
||||
for tool in tools:
|
||||
if tool["type"] == SearchTool.__name__ and (
|
||||
retrieval_details := tool.get("retrieval_details")
|
||||
):
|
||||
search_tool_retrieval_details = RetrievalDetails.model_validate(
|
||||
retrieval_details
|
||||
)
|
||||
break
|
||||
|
||||
new_msg_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=int(parent_message_id) if parent_message_id else None,
|
||||
message=instructions,
|
||||
file_descriptors=[],
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
prompt_override=None,
|
||||
alternate_assistant_id=assistant_id,
|
||||
use_existing_user_message=True,
|
||||
existing_assistant_message_id=message_id,
|
||||
)
|
||||
|
||||
run_message = get_chat_message(message_id, user.id if user else None, db_session)
|
||||
try:
|
||||
for packet in stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
):
|
||||
if isinstance(packet, ChatMessageDetail):
|
||||
# Update the run status and message content
|
||||
run_message = get_chat_message(
|
||||
message_id, user.id if user else None, db_session
|
||||
)
|
||||
if run_message:
|
||||
# this handles cancelling
|
||||
if run_message.error:
|
||||
return
|
||||
|
||||
run_message.message = packet.message
|
||||
run_message.message_type = MessageType.ASSISTANT
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception("Error processing run in background")
|
||||
run_message.error = str(e)
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
db_session.refresh(run_message)
|
||||
if run_message.token_count == 0:
|
||||
run_message.error = "No tokens generated"
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs")
|
||||
def create_run(
|
||||
thread_id: UUID,
|
||||
run_request: RunRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
latest_message = (
|
||||
chat_messages[-1]
|
||||
if chat_messages
|
||||
else get_or_create_root_message(chat_session.id, db_session)
|
||||
)
|
||||
|
||||
# Create a new "run" (chat message) in the session
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=latest_message,
|
||||
message="",
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
token_count=0,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
db_session.flush()
|
||||
latest_message.latest_child_message = new_message.id
|
||||
db_session.commit()
|
||||
|
||||
# Schedule the background task
|
||||
background_tasks.add_task(
|
||||
process_run_in_background,
|
||||
new_message.id,
|
||||
latest_message.id,
|
||||
chat_session.id,
|
||||
run_request.assistant_id,
|
||||
run_request.instructions or "",
|
||||
run_request.tools or [],
|
||||
user,
|
||||
db_session,
|
||||
)
|
||||
|
||||
return RunResponse(
|
||||
id=str(new_message.id),
|
||||
object="thread.run",
|
||||
created_at=int(new_message.time_sent.timestamp()),
|
||||
assistant_id=run_request.assistant_id,
|
||||
thread_id=chat_session.id,
|
||||
status="queued",
|
||||
model=run_request.model or "default_model",
|
||||
instructions=run_request.instructions or "",
|
||||
tools=run_request.tools or [],
|
||||
file_ids=[],
|
||||
metadata=run_request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}")
|
||||
def retrieve_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# Retrieve the chat message (which represents a "run" in DAnswer)
|
||||
chat_message = get_chat_message(
|
||||
chat_message_id=int(run_id), # Convert string run_id to int
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_session = chat_message.chat_session
|
||||
|
||||
# Map DAnswer status to OpenAI status
|
||||
run_status: RunStatus = "queued"
|
||||
if chat_message.message:
|
||||
run_status = "in_progress"
|
||||
if chat_message.token_count != 0:
|
||||
run_status = "completed"
|
||||
if chat_message.error:
|
||||
run_status = "cancelled"
|
||||
|
||||
return RunResponse(
|
||||
id=run_id,
|
||||
object="thread.run",
|
||||
created_at=int(chat_message.time_sent.timestamp()),
|
||||
assistant_id=chat_session.persona_id or 0,
|
||||
thread_id=chat_session.id,
|
||||
status=run_status,
|
||||
started_at=int(chat_message.time_sent.timestamp()),
|
||||
completed_at=(
|
||||
int(chat_message.time_sent.timestamp()) if chat_message.message else None
|
||||
),
|
||||
model=chat_session.current_alternate_model or "default_model",
|
||||
instructions="", # DAnswer doesn't store per-message instructions
|
||||
tools=[], # DAnswer doesn't have a direct equivalent for tools
|
||||
file_ids=(
|
||||
[file["id"] for file in chat_message.files] if chat_message.files else []
|
||||
),
|
||||
metadata=None, # DAnswer doesn't store metadata for individual messages
|
||||
)
|
||||
|
||||
|
||||
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
|
||||
def cancel_run(
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RunResponse:
|
||||
# In DAnswer, we don't have a direct equivalent to cancelling a run
|
||||
# We'll simulate it by marking the message as "cancelled"
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
|
||||
)
|
||||
if not chat_message:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
chat_message.error = "Cancelled"
|
||||
db_session.commit()
|
||||
|
||||
return retrieve_run(thread_id, run_id, user, db_session)
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs")
|
||||
def list_runs(
|
||||
thread_id: UUID,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[RunResponse]:
|
||||
# In DAnswer, we'll treat each message in a chat session as a "run"
|
||||
chat_messages = get_chat_messages_by_session(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
if after:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
|
||||
if before:
|
||||
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
|
||||
|
||||
# Apply ordering
|
||||
chat_messages = sorted(
|
||||
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
|
||||
)
|
||||
|
||||
# Apply limit
|
||||
chat_messages = chat_messages[:limit]
|
||||
|
||||
return [
|
||||
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
|
||||
]
|
||||
|
||||
|
||||
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
|
||||
def list_run_steps(
|
||||
run_id: str,
|
||||
limit: int = 20,
|
||||
order: Literal["asc", "desc"] = "desc",
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[dict]: # You may want to create a specific model for run steps
|
||||
# DAnswer doesn't have an equivalent to run steps
|
||||
# We'll return an empty list to maintain API compatibility
|
||||
return []
|
||||
|
||||
|
||||
# Additional helper functions can be added here if needed
|
156
backend/danswer/server/openai_assistants_api/threads_api.py
Normal file
156
backend/danswer/server/openai_assistants_api/threads_api.py
Normal file
@ -0,0 +1,156 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import get_chat_session_by_id
|
||||
from danswer.db.chat import get_chat_sessions_by_user
|
||||
from danswer.db.chat import update_chat_session
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter(prefix="/threads")
|
||||
|
||||
|
||||
# Models
|
||||
class Thread(BaseModel):
|
||||
id: UUID
|
||||
object: str = "thread"
|
||||
created_at: int
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class CreateThreadRequest(BaseModel):
|
||||
messages: Optional[list[dict]] = None
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
class ModifyThreadRequest(BaseModel):
|
||||
metadata: Optional[dict[str, str]] = None
|
||||
|
||||
|
||||
# API Endpoints
|
||||
@router.post("")
|
||||
def create_thread(
|
||||
request: CreateThreadRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # Leave the naming till later to prevent delay
|
||||
user_id=user_id,
|
||||
persona_id=0,
|
||||
)
|
||||
|
||||
return Thread(
|
||||
id=new_chat_session.id,
|
||||
created_at=int(new_chat_session.time_created.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}")
|
||||
def retrieve_thread(
|
||||
thread_id: UUID,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=thread_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return Thread(
|
||||
id=chat_session.id,
|
||||
created_at=int(chat_session.time_created.timestamp()),
|
||||
metadata=None, # Assuming we don't store metadata in our current implementation
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{thread_id}")
|
||||
def modify_thread(
|
||||
thread_id: UUID,
|
||||
request: ModifyThreadRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Thread:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
chat_session = update_chat_session(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
chat_session_id=thread_id,
|
||||
description=None, # Not updating description
|
||||
sharing_status=None, # Not updating sharing status
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return Thread(
|
||||
id=chat_session.id,
|
||||
created_at=int(chat_session.time_created.timestamp()),
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{thread_id}")
|
||||
def delete_thread(
|
||||
thread_id: UUID,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> dict:
|
||||
user_id = user.id if user else None
|
||||
try:
|
||||
delete_chat_session(
|
||||
user_id=user_id,
|
||||
chat_session_id=thread_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=404, detail="Thread not found")
|
||||
|
||||
return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
|
||||
|
||||
|
||||
@router.get("")
|
||||
def list_threads(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id if user else None
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
id=chat.id,
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
)
|
@ -347,7 +347,6 @@ def handle_new_chat_message(
|
||||
for packet in stream_chat_message(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
use_existing_user_message=chat_message_req.use_existing_user_message,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
|
@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
# used for "OpenAI Assistants API"
|
||||
existing_assistant_message_id: int | None = None
|
||||
|
||||
# forces the LLM to return a structured response, see
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
255
backend/danswer/tools/tool_constructor.py
Normal file
255
backend/danswer/tools/tool_constructor.py
Normal file
@ -0,0 +1,255 @@
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.models import InferenceSection
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.models import DynamicSchemaInfo
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.headers import header_dict_to_header_list
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
"""Helper function to get image generation LLM config based on available providers"""
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
return LLMConfig(
|
||||
model_provider=llm.config.model_provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm.config.api_key,
|
||||
api_base=llm.config.api_base,
|
||||
api_version=llm.config.api_version,
|
||||
)
|
||||
|
||||
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
|
||||
return LLMConfig(
|
||||
model_provider="azure",
|
||||
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=AZURE_DALLE_API_KEY,
|
||||
api_base=AZURE_DALLE_API_BASE,
|
||||
api_version=AZURE_DALLE_API_VERSION,
|
||||
)
|
||||
|
||||
# Fallback to checking for OpenAI provider in database
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError("Image generation tool requires an OpenAI API key")
|
||||
|
||||
return LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name="dall-e-3",
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
|
||||
|
||||
class SearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
|
||||
)
|
||||
document_pruning_config: DocumentPruningConfig = Field(
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
answer_style_config: AnswerStyleConfig = Field(
|
||||
default_factory=lambda: AnswerStyleConfig(
|
||||
citation_config=CitationConfig(all_docs_useful=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class ImageGenerationToolConfig(BaseModel):
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
class CustomToolConfig(BaseModel):
|
||||
chat_session_id: UUID | None = None
|
||||
message_id: int | None = None
|
||||
additional_headers: dict[str, str] | None = None
|
||||
|
||||
|
||||
def construct_tools(
|
||||
persona: Persona,
|
||||
prompt_config: PromptConfig,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
internet_search_tool_config: InternetSearchToolConfig | None = None,
|
||||
image_generation_tool_config: ImageGenerationToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
|
||||
# Handle Search Tool
|
||||
if tool_cls.__name__ == SearchTool.__name__:
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
if not image_generation_tool_config:
|
||||
image_generation_tool_config = ImageGenerationToolConfig()
|
||||
|
||||
img_generation_llm_config = _get_image_generation_config(
|
||||
llm, db_session
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=image_generation_tool_config.additional_headers,
|
||||
model=img_generation_llm_config.model_name,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle Internet Search Tool
|
||||
elif tool_cls.__name__ == InternetSearchTool.__name__:
|
||||
if not internet_search_tool_config:
|
||||
internet_search_tool_config = InternetSearchToolConfig()
|
||||
|
||||
if not BING_API_KEY:
|
||||
raise ValueError(
|
||||
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
InternetSearchTool(
|
||||
api_key=BING_API_KEY,
|
||||
answer_style_config=internet_search_tool_config.answer_style_config,
|
||||
prompt_config=prompt_config,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
custom_tool_config = CustomToolConfig()
|
||||
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=custom_tool_config.chat_session_id,
|
||||
message_id=custom_tool_config.message_id,
|
||||
),
|
||||
custom_headers=(db_tool_model.custom_headers or [])
|
||||
+ (
|
||||
header_dict_to_header_list(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
if search_tool_config:
|
||||
search_tool_config.document_pruning_config.tool_num_tokens = (
|
||||
compute_all_tool_tokens(
|
||||
tools,
|
||||
get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
),
|
||||
)
|
||||
)
|
||||
search_tool_config.document_pruning_config.using_tool_message = (
|
||||
explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
)
|
||||
|
||||
return tool_dict
|
@ -13,6 +13,14 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
DOMAIN = "test.com"
|
||||
DEFAULT_PASSWORD = "test"
|
||||
|
||||
|
||||
def build_email(name: str) -> str:
|
||||
return f"{name}@test.com"
|
||||
|
||||
|
||||
class UserManager:
|
||||
@staticmethod
|
||||
def create(
|
||||
@ -23,9 +31,9 @@ class UserManager:
|
||||
name = f"test{str(uuid4())}"
|
||||
|
||||
if email is None:
|
||||
email = f"{name}@test.com"
|
||||
email = build_email(name)
|
||||
|
||||
password = "test"
|
||||
password = DEFAULT_PASSWORD
|
||||
|
||||
body = {
|
||||
"email": email,
|
||||
|
55
backend/tests/integration/openai_assistants_api/conftest.py
Normal file
55
backend/tests/integration/openai_assistants_api/conftest.py
Normal file
@ -0,0 +1,55 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser | None:
|
||||
try:
|
||||
return UserManager.create("admin_user")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
return UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email("admin_user"),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
|
||||
# Create a thread to use in the tests
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads", # Updated endpoint path
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return UUID(response.json()["id"])
|
@ -0,0 +1,151 @@
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
|
||||
|
||||
|
||||
def test_create_assistant(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={
|
||||
"model": "gpt-3.5-turbo",
|
||||
"name": "Test Assistant",
|
||||
"description": "A test assistant",
|
||||
"instructions": "You are a helpful assistant.",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["name"] == "Test Assistant"
|
||||
assert data["description"] == "A test assistant"
|
||||
assert data["model"] == "gpt-3.5-turbo"
|
||||
assert data["instructions"] == "You are a helpful assistant."
|
||||
|
||||
|
||||
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the assistant
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Retrieve Test"
|
||||
|
||||
|
||||
def test_modify_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the assistant
|
||||
response = requests.post(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
json={"name": "Modified Assistant", "instructions": "New instructions"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["name"] == "Modified Assistant"
|
||||
assert data["instructions"] == "New instructions"
|
||||
|
||||
|
||||
def test_delete_assistant(admin_user: DATestUser | None) -> None:
|
||||
# First, create an assistant
|
||||
create_response = requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
assistant_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the assistant
|
||||
response = requests.delete(
|
||||
f"{ASSISTANTS_URL}/{assistant_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == assistant_id
|
||||
assert data["deleted"] is True
|
||||
|
||||
|
||||
def test_list_assistants(admin_user: DATestUser | None) -> None:
|
||||
# Create multiple assistants
|
||||
for i in range(3):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the assistants
|
||||
response = requests.get(
|
||||
ASSISTANTS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["object"] == "list"
|
||||
assert len(data["data"]) >= 3 # At least the 3 we just created
|
||||
assert all(assistant["object"] == "assistant" for assistant in data["data"])
|
||||
|
||||
|
||||
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
|
||||
# Create 5 assistants
|
||||
for i in range(5):
|
||||
requests.post(
|
||||
ASSISTANTS_URL,
|
||||
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# List assistants with limit
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
assert data["has_more"] is True
|
||||
|
||||
# Get next page
|
||||
before = data["last_id"]
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}?limit=2&before={before}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data["data"]) == 2
|
||||
|
||||
|
||||
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
|
||||
non_existent_id = -99
|
||||
response = requests.get(
|
||||
f"{ASSISTANTS_URL}/{non_existent_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
133
backend/tests/integration/openai_assistants_api/test_messages.py
Normal file
133
backend/tests/integration/openai_assistants_api/test_messages.py
Normal file
@ -0,0 +1,133 @@
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def thread_id(admin_user: Optional[DATestUser]) -> str:
|
||||
response = requests.post(
|
||||
BASE_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
|
||||
json={
|
||||
"role": "user",
|
||||
"content": "Hello, world!",
|
||||
"file_ids": [],
|
||||
"metadata": {"key": "value"},
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the messages
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["object"] == "list"
|
||||
assert isinstance(response_json["data"], list)
|
||||
assert len(response_json["data"]) > 0
|
||||
assert "first_id" in response_json
|
||||
assert "last_id" in response_json
|
||||
assert "has_more" in response_json
|
||||
|
||||
|
||||
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["role"] == "user"
|
||||
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
|
||||
|
||||
|
||||
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
|
||||
# Create a message first
|
||||
create_response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
message_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the message
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{thread_id}/messages/{message_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json["id"] == message_id
|
||||
assert response_json["thread_id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
|
||||
non_existent_thread_id = str(uuid.uuid4())
|
||||
non_existent_message_id = -99
|
||||
|
||||
# Test with non-existent thread
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages",
|
||||
json={"role": "user", "content": "Test message"},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test with non-existent message
|
||||
response = requests.get(
|
||||
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 404
|
137
backend/tests/integration/openai_assistants_api/test_runs.py
Normal file
137
backend/tests/integration/openai_assistants_api/test_runs.py
Normal file
@ -0,0 +1,137 @@
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
|
||||
"""Create a run and return its ID."""
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
return response.json()["id"]
|
||||
|
||||
|
||||
def test_create_run(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
"model": "gpt-3.5-turbo",
|
||||
"instructions": "Test instructions",
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["assistant_id"] == 0
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
assert response_json["status"] == "queued"
|
||||
assert response_json["model"] == "gpt-3.5-turbo"
|
||||
assert response_json["instructions"] == "Test instructions"
|
||||
|
||||
|
||||
def test_retrieve_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
retrieve_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["object"] == "thread.run"
|
||||
assert "created_at" in response_json
|
||||
assert UUID(response_json["thread_id"]) == thread_id
|
||||
|
||||
|
||||
def test_cancel_run(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
cancel_response = requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert cancel_response.status_code == 200
|
||||
|
||||
response_json = cancel_response.json()
|
||||
assert response_json["id"] == run_id
|
||||
assert response_json["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_list_runs(
|
||||
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
# Create a few runs
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
json={
|
||||
"assistant_id": 0,
|
||||
},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the runs
|
||||
list_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
assert len(response_json) >= 3
|
||||
|
||||
for run in response_json:
|
||||
assert "id" in run
|
||||
assert run["object"] == "thread.run"
|
||||
assert "created_at" in run
|
||||
assert UUID(run["thread_id"]) == thread_id
|
||||
assert "status" in run
|
||||
assert "model" in run
|
||||
|
||||
|
||||
def test_list_run_steps(
|
||||
admin_user: DATestUser | None,
|
||||
thread_id: UUID,
|
||||
run_id: str,
|
||||
llm_provider: DATestLLMProvider,
|
||||
) -> None:
|
||||
steps_response = requests.get(
|
||||
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert steps_response.status_code == 200
|
||||
|
||||
response_json = steps_response.json()
|
||||
assert isinstance(response_json, list)
|
||||
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
|
||||
assert len(response_json) == 0
|
132
backend/tests/integration/openai_assistants_api/test_threads.py
Normal file
132
backend/tests/integration/openai_assistants_api/test_threads.py
Normal file
@ -0,0 +1,132 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.db.models import ChatSessionSharedStatus
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
|
||||
|
||||
|
||||
def test_create_thread(admin_user: DATestUser | None) -> None:
|
||||
response = requests.post(
|
||||
THREADS_URL,
|
||||
json={"messages": None, "metadata": {"key": "value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
response_json = response.json()
|
||||
assert "id" in response_json
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
assert response_json["metadata"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, retrieve the thread
|
||||
retrieve_response = requests.get(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert retrieve_response.status_code == 200
|
||||
|
||||
response_json = retrieve_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread"
|
||||
assert "created_at" in response_json
|
||||
|
||||
|
||||
def test_modify_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, modify the thread
|
||||
modify_response = requests.post(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
json={"metadata": {"new_key": "new_value"}},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert modify_response.status_code == 200
|
||||
|
||||
response_json = modify_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["metadata"] == {"new_key": "new_value"}
|
||||
|
||||
|
||||
def test_delete_thread(admin_user: DATestUser | None) -> None:
|
||||
# First, create a thread
|
||||
create_response = requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert create_response.status_code == 200
|
||||
thread_id = create_response.json()["id"]
|
||||
|
||||
# Now, delete the thread
|
||||
delete_response = requests.delete(
|
||||
f"{THREADS_URL}/{thread_id}",
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert delete_response.status_code == 200
|
||||
|
||||
response_json = delete_response.json()
|
||||
assert response_json["id"] == thread_id
|
||||
assert response_json["object"] == "thread.deleted"
|
||||
assert response_json["deleted"] is True
|
||||
|
||||
|
||||
def test_list_threads(admin_user: DATestUser | None) -> None:
|
||||
# Create a few threads
|
||||
for _ in range(3):
|
||||
requests.post(
|
||||
THREADS_URL,
|
||||
json={},
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
|
||||
# Now, list the threads
|
||||
list_response = requests.get(
|
||||
THREADS_URL,
|
||||
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
|
||||
)
|
||||
assert list_response.status_code == 200
|
||||
|
||||
response_json = list_response.json()
|
||||
assert "sessions" in response_json
|
||||
assert len(response_json["sessions"]) >= 3
|
||||
|
||||
for session in response_json["sessions"]:
|
||||
assert "id" in session
|
||||
assert "name" in session
|
||||
assert "persona_id" in session
|
||||
assert "time_created" in session
|
||||
assert "shared_status" in session
|
||||
assert "folder_id" in session
|
||||
assert "current_alternate_model" in session
|
||||
|
||||
# Validate UUID
|
||||
UUID(session["id"])
|
||||
|
||||
# Validate shared_status
|
||||
assert session["shared_status"] in [
|
||||
status.value for status in ChatSessionSharedStatus
|
||||
]
|
125
examples/assistants-api/topics_analyzer.py
Normal file
125
examples/assistants-api/topics_analyzer.py
Normal file
@ -0,0 +1,125 @@
|
||||
import argparse
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
ASSISTANT_NAME = "Topic Analyzer"
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that analyzes topics by searching through available \
|
||||
documents and providing insights. These available documents come from common \
|
||||
workplace tools like Slack, emails, Confluence, Google Drive, etc.
|
||||
|
||||
When analyzing a topic:
|
||||
1. Search for relevant information using the search tool
|
||||
2. Synthesize the findings into clear insights
|
||||
3. Highlight key trends, patterns, or notable developments
|
||||
4. Maintain objectivity and cite sources where relevant
|
||||
"""
|
||||
USER_PROMPT = """
|
||||
Please analyze and provide insights about this topic: {topic}.
|
||||
|
||||
IMPORTANT: do not mention things that are not relevant to the specified topic. \
|
||||
If there is no relevant information, just say "No relevant information found."
|
||||
"""
|
||||
|
||||
|
||||
def wait_on_run(client: OpenAI, run, thread):
|
||||
while run.status == "queued" or run.status == "in_progress":
|
||||
run = client.beta.threads.runs.retrieve(
|
||||
thread_id=thread.id,
|
||||
run_id=run.id,
|
||||
)
|
||||
time.sleep(0.5)
|
||||
return run
|
||||
|
||||
|
||||
def show_response(messages) -> None:
|
||||
# Get only the assistant's response text
|
||||
for message in messages.data[::-1]:
|
||||
if message.role == "assistant":
|
||||
for content in message.content:
|
||||
if content.type == "text":
|
||||
print(content.text)
|
||||
break
|
||||
|
||||
|
||||
def analyze_topics(topics: list[str]) -> None:
|
||||
openai_api_key = os.environ.get(
|
||||
"OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
|
||||
)
|
||||
danswer_api_key = os.environ.get(
|
||||
"DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
|
||||
)
|
||||
client = OpenAI(
|
||||
api_key=openai_api_key,
|
||||
base_url="http://localhost:8080/openai-assistants",
|
||||
default_headers={
|
||||
"Authorization": f"Bearer {danswer_api_key}",
|
||||
},
|
||||
)
|
||||
|
||||
# Create an assistant if it doesn't exist
|
||||
try:
|
||||
assistants = client.beta.assistants.list(limit=100)
|
||||
# Find the Topic Analyzer assistant if it exists
|
||||
assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME))
|
||||
client.beta.assistants.delete(assistant.id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
assistant = client.beta.assistants.create(
|
||||
name=ASSISTANT_NAME,
|
||||
instructions=SYSTEM_PROMPT,
|
||||
tools=[{"type": "SearchTool"}], # type: ignore
|
||||
model="gpt-4o",
|
||||
)
|
||||
|
||||
# Process each topic individually
|
||||
for topic in topics:
|
||||
thread = client.beta.threads.create()
|
||||
message = client.beta.threads.messages.create(
|
||||
thread_id=thread.id,
|
||||
role="user",
|
||||
content=USER_PROMPT.format(topic=topic),
|
||||
)
|
||||
|
||||
run = client.beta.threads.runs.create(
|
||||
thread_id=thread.id,
|
||||
assistant_id=assistant.id,
|
||||
tools=[
|
||||
{ # type: ignore
|
||||
"type": "SearchTool",
|
||||
"retrieval_details": {
|
||||
"run_search": "always",
|
||||
"filters": {
|
||||
"time_cutoff": str(
|
||||
datetime.now(timezone.utc) - timedelta(days=7)
|
||||
)
|
||||
},
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
run = wait_on_run(client, run, thread)
|
||||
messages = client.beta.threads.messages.list(
|
||||
thread_id=thread.id, order="asc", after=message.id
|
||||
)
|
||||
print(f"\nAnalysis for topic: {topic}")
|
||||
print("-" * 40)
|
||||
show_response(messages)
|
||||
print()
|
||||
|
||||
|
||||
# Example usage
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Analyze specific topics")
|
||||
parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)")
|
||||
|
||||
args = parser.parse_args()
|
||||
analyze_topics(args.topics)
|
Loading…
x
Reference in New Issue
Block a user