New assistants api (#3097)

This commit is contained in:
Chris Weaver 2024-11-11 07:55:23 -08:00 committed by GitHub
parent 9d57f34c34
commit ba805f766f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 2179 additions and 177 deletions

View File

@ -288,6 +288,15 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@ -23,6 +23,56 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@ -295,7 +281,6 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@ -307,6 +292,9 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@ -428,12 +416,20 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@ -504,13 +500,19 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@ -525,7 +527,13 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@ -537,6 +545,7 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@ -560,142 +569,39 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@ -871,7 +777,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@ -879,9 +784,11 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
@ -915,7 +822,6 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@ -925,7 +831,6 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@ -37,7 +44,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
custom_headers=[header.model_dump() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@ -74,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@ -270,6 +273,9 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@ -0,0 +1,273 @@
from typing import Any
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/assistants")
# Base models
class AssistantObject(BaseModel):
id: int
object: str = "assistant"
created_at: int
name: Optional[str] = None
description: Optional[str] = None
model: str
instructions: Optional[str] = None
tools: list[dict[str, Any]]
file_ids: list[str]
metadata: Optional[dict[str, Any]] = None
class CreateAssistantRequest(BaseModel):
model: str
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class ModifyAssistantRequest(BaseModel):
model: Optional[str] = None
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class DeleteAssistantResponse(BaseModel):
id: int
object: str = "assistant.deleted"
deleted: bool
class ListAssistantsResponse(BaseModel):
object: str = "list"
data: list[AssistantObject]
first_id: Optional[int] = None
last_id: Optional[int] = None
has_more: bool
def persona_to_assistant(persona: Persona) -> AssistantObject:
return AssistantObject(
id=persona.id,
created_at=0,
name=persona.name,
description=persona.description,
model=persona.llm_model_version_override or "gpt-3.5-turbo",
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
tools=[
{
"type": tool.display_name,
"function": {
"name": tool.name,
"description": tool.description,
"schema": tool.openapi_schema,
},
}
for tool in persona.tools
],
file_ids=[], # Assuming no file support for now
metadata={}, # Assuming no metadata for now
)
# API endpoints
@router.post("")
def create_assistant(
request: CreateAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
prompt = None
if request.instructions:
prompt = upsert_prompt(
user=user,
name=f"Prompt for {request.name or 'New Assistant'}",
description="Auto-generated prompt",
system_prompt=request.instructions,
task_prompt="",
include_citations=True,
datetime_aware=True,
personas=[],
db_session=db_session,
)
tool_ids = []
for tool in request.tools or []:
tool_type = tool.get("type")
if not tool_type:
continue
try:
tool_db = get_tool_by_name(tool_type, db_session)
tool_ids.append(tool_db.id)
except ValueError:
# Skip tools that don't exist in the database
logger.error(f"Tool {tool_type} not found in database")
raise HTTPException(
status_code=404, detail=f"Tool {tool_type} not found in database"
)
persona = upsert_persona(
user=user,
name=request.name or f"Assistant-{uuid4()}",
description=request.description or "",
num_chunks=25,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=None,
llm_model_version_override=request.model,
starter_messages=None,
is_public=False,
db_session=db_session,
prompt_ids=[prompt.id] if prompt else [0],
document_set_ids=[],
tool_ids=tool_ids,
icon_color=None,
icon_shape=None,
is_visible=True,
)
if prompt:
prompt.personas = [persona]
db_session.commit()
return persona_to_assistant(persona)
""
@router.get("/{assistant_id}")
def retrieve_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
try:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
except ValueError:
persona = None
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
return persona_to_assistant(persona)
@router.post("/{assistant_id}")
def modify_assistant(
assistant_id: int,
request: ModifyAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=True,
)
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = request.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(persona, key, value)
if "instructions" in update_data and persona.prompts:
persona.prompts[0].system_prompt = update_data["instructions"]
db_session.commit()
return persona_to_assistant(persona)
@router.delete("/{assistant_id}")
def delete_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
try:
mark_persona_as_deleted(
persona_id=int(assistant_id),
user=user,
db_session=db_session,
)
return DeleteAssistantResponse(id=assistant_id, deleted=True)
except ValueError:
raise HTTPException(status_code=404, detail="Assistant not found")
@router.get("")
def list_assistants(
limit: int = Query(20, le=100),
order: str = Query("desc", regex="^(asc|desc)$"),
after: Optional[int] = None,
before: Optional[int] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
personas = list(
get_personas(
user=user,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
)
# Apply filtering based on after and before
if after:
personas = [p for p in personas if p.id > int(after)]
if before:
personas = [p for p in personas if p.id < int(before)]
# Apply ordering
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
# Apply limit
personas = personas[:limit]
assistants = [persona_to_assistant(p) for p in personas]
return ListAssistantsResponse(
data=assistants,
first_id=assistants[0].id if assistants else None,
last_id=assistants[-1].id if assistants else None,
has_more=len(personas) == limit,
)

View File

@ -0,0 +1,19 @@
from fastapi import APIRouter
from danswer.server.openai_assistants_api.asssistants_api import (
router as assistants_router,
)
from danswer.server.openai_assistants_api.messages_api import router as messages_router
from danswer.server.openai_assistants_api.runs_api import router as runs_router
from danswer.server.openai_assistants_api.threads_api import router as threads_router
def get_full_openai_assistants_api_router() -> APIRouter:
router = APIRouter(prefix="/openai-assistants")
router.include_router(assistants_router)
router.include_router(runs_router)
router.include_router(threads_router)
router.include_router(messages_router)
return router

View File

@ -0,0 +1,235 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import check_number_of_tokens
router = APIRouter(prefix="")
Role = Literal["user", "assistant"]
class MessageContent(BaseModel):
type: Literal["text"]
text: str
class Message(BaseModel):
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
object: Literal["thread.message"] = "thread.message"
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
thread_id: str
role: Role
content: list[MessageContent]
file_ids: list[str] = []
assistant_id: Optional[str] = None
run_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
class CreateMessageRequest(BaseModel):
role: Role
content: str
file_ids: list[str] = []
metadata: Optional[dict] = None
class ListMessagesResponse(BaseModel):
object: Literal["list"] = "list"
data: list[Message]
first_id: str
last_id: str
has_more: bool
@router.post("/threads/{thread_id}/messages")
def create_message(
thread_id: str,
message: CreateMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message=message.content,
prompt_id=chat_session.persona.prompts[0].id,
token_count=check_number_of_tokens(message.content),
message_type=(
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
),
db_session=db_session,
)
return Message(
id=str(new_message.id),
thread_id=thread_id,
role="user",
content=[MessageContent(type="text", text=message.content)],
file_ids=message.file_ids,
metadata=message.metadata,
)
@router.get("/threads/{thread_id}/messages")
def list_messages(
thread_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
)
# Apply filtering based on after and before
if after:
messages = [m for m in messages if str(m.id) >= after]
if before:
messages = [m for m in messages if str(m.id) <= before]
# Apply ordering
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
# Apply limit
messages = messages[:limit]
data = [
Message(
id=str(m.id),
thread_id=thread_id,
role="user" if m.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=m.message)],
created_at=int(m.time_sent.timestamp()),
)
for m in messages
]
return ListMessagesResponse(
data=data,
first_id=str(data[0].id) if data else "",
last_id=str(data[-1].id) if data else "",
has_more=len(messages) == limit,
)
@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
thread_id: str,
message_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
)
class ModifyMessageRequest(BaseModel):
metadata: dict
@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
thread_id: str,
message_id: int,
request: ModifyMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
# Update metadata
# TODO: Uncomment this once we have metadata in the chat message
# chat_message.metadata = request.metadata
# db_session.commit()
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
metadata=request.metadata,
)

View File

@ -0,0 +1,344 @@
from typing import Literal
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter()
class RunRequest(BaseModel):
assistant_id: int
model: Optional[str] = None
instructions: Optional[str] = None
additional_instructions: Optional[str] = None
tools: Optional[list[dict]] = None
metadata: Optional[dict] = None
RunStatus = Literal[
"queued",
"in_progress",
"requires_action",
"cancelling",
"cancelled",
"failed",
"completed",
"expired",
]
class RunResponse(BaseModel):
id: str
object: Literal["thread.run"]
created_at: int
assistant_id: int
thread_id: UUID
status: RunStatus
started_at: Optional[int] = None
expires_at: Optional[int] = None
cancelled_at: Optional[int] = None
failed_at: Optional[int] = None
completed_at: Optional[int] = None
last_error: Optional[dict] = None
model: str
instructions: str
tools: list[dict]
file_ids: list[str]
metadata: Optional[dict] = None
def process_run_in_background(
message_id: int,
parent_message_id: int,
chat_session_id: UUID,
assistant_id: int,
instructions: str,
tools: list[dict],
user: User | None,
db_session: Session,
) -> None:
# Get the latest message in the chat session
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=user.id if user else None,
db_session=db_session,
)
search_tool_retrieval_details = RetrievalDetails()
for tool in tools:
if tool["type"] == SearchTool.__name__ and (
retrieval_details := tool.get("retrieval_details")
):
search_tool_retrieval_details = RetrievalDetails.model_validate(
retrieval_details
)
break
new_msg_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=int(parent_message_id) if parent_message_id else None,
message=instructions,
file_descriptors=[],
prompt_id=chat_session.persona.prompts[0].id,
search_doc_ids=None,
retrieval_options=search_tool_retrieval_details, # Adjust as needed
query_override=None,
regenerate=None,
llm_override=None,
prompt_override=None,
alternate_assistant_id=assistant_id,
use_existing_user_message=True,
existing_assistant_message_id=message_id,
)
run_message = get_chat_message(message_id, user.id if user else None, db_session)
try:
for packet in stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
):
if isinstance(packet, ChatMessageDetail):
# Update the run status and message content
run_message = get_chat_message(
message_id, user.id if user else None, db_session
)
if run_message:
# this handles cancelling
if run_message.error:
return
run_message.message = packet.message
run_message.message_type = MessageType.ASSISTANT
db_session.commit()
except Exception as e:
logger.exception("Error processing run in background")
run_message.error = str(e)
db_session.commit()
return
db_session.refresh(run_message)
if run_message.token_count == 0:
run_message.error = "No tokens generated"
db_session.commit()
@router.post("/threads/{thread_id}/runs")
def create_run(
thread_id: UUID,
run_request: RunRequest,
background_tasks: BackgroundTasks,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
# Create a new "run" (chat message) in the session
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message="",
prompt_id=chat_session.persona.prompts[0].id,
token_count=0,
message_type=MessageType.ASSISTANT,
db_session=db_session,
commit=False,
)
db_session.flush()
latest_message.latest_child_message = new_message.id
db_session.commit()
# Schedule the background task
background_tasks.add_task(
process_run_in_background,
new_message.id,
latest_message.id,
chat_session.id,
run_request.assistant_id,
run_request.instructions or "",
run_request.tools or [],
user,
db_session,
)
return RunResponse(
id=str(new_message.id),
object="thread.run",
created_at=int(new_message.time_sent.timestamp()),
assistant_id=run_request.assistant_id,
thread_id=chat_session.id,
status="queued",
model=run_request.model or "default_model",
instructions=run_request.instructions or "",
tools=run_request.tools or [],
file_ids=[],
metadata=run_request.metadata,
)
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# Retrieve the chat message (which represents a "run" in DAnswer)
chat_message = get_chat_message(
chat_message_id=int(run_id), # Convert string run_id to int
user_id=user.id if user else None,
db_session=db_session,
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_session = chat_message.chat_session
# Map DAnswer status to OpenAI status
run_status: RunStatus = "queued"
if chat_message.message:
run_status = "in_progress"
if chat_message.token_count != 0:
run_status = "completed"
if chat_message.error:
run_status = "cancelled"
return RunResponse(
id=run_id,
object="thread.run",
created_at=int(chat_message.time_sent.timestamp()),
assistant_id=chat_session.persona_id or 0,
thread_id=chat_session.id,
status=run_status,
started_at=int(chat_message.time_sent.timestamp()),
completed_at=(
int(chat_message.time_sent.timestamp()) if chat_message.message else None
),
model=chat_session.current_alternate_model or "default_model",
instructions="", # DAnswer doesn't store per-message instructions
tools=[], # DAnswer doesn't have a direct equivalent for tools
file_ids=(
[file["id"] for file in chat_message.files] if chat_message.files else []
),
metadata=None, # DAnswer doesn't store metadata for individual messages
)
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# In DAnswer, we don't have a direct equivalent to cancelling a run
# We'll simulate it by marking the message as "cancelled"
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_message.error = "Cancelled"
db_session.commit()
return retrieve_run(thread_id, run_id, user, db_session)
@router.get("/threads/{thread_id}/runs")
def list_runs(
thread_id: UUID,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[RunResponse]:
# In DAnswer, we'll treat each message in a chat session as a "run"
chat_messages = get_chat_messages_by_session(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
# Apply pagination
if after:
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
if before:
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
# Apply ordering
chat_messages = sorted(
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
)
# Apply limit
chat_messages = chat_messages[:limit]
return [
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
]
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
run_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[dict]: # You may want to create a specific model for run steps
# DAnswer doesn't have an equivalent to run steps
# We'll return an empty list to maintain API compatibility
return []
# Additional helper functions can be added here if needed

View File

@ -0,0 +1,156 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.chat import create_chat_session
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import update_chat_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter(prefix="/threads")
# Models
class Thread(BaseModel):
id: UUID
object: str = "thread"
created_at: int
metadata: Optional[dict[str, str]] = None
class CreateThreadRequest(BaseModel):
messages: Optional[list[dict]] = None
metadata: Optional[dict[str, str]] = None
class ModifyThreadRequest(BaseModel):
metadata: Optional[dict[str, str]] = None
# API Endpoints
@router.post("")
def create_thread(
request: CreateThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave the naming till later to prevent delay
user_id=user_id,
persona_id=0,
)
return Thread(
id=new_chat_session.id,
created_at=int(new_chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.get("/{thread_id}")
def retrieve_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=None, # Assuming we don't store metadata in our current implementation
)
@router.post("/{thread_id}")
def modify_thread(
thread_id: UUID,
request: ModifyThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = update_chat_session(
db_session=db_session,
user_id=user_id,
chat_session_id=thread_id,
description=None, # Not updating description
sharing_status=None, # Not updating sharing status
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.delete("/{thread_id}")
def delete_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict:
user_id = user.id if user else None
try:
delete_chat_session(
user_id=user_id,
chat_session_id=thread_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
@router.get("")
def list_threads(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
user_id = user.id if user else None
chat_sessions = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
)
return ChatSessionsResponse(
sessions=[
ChatSessionDetails(
id=chat.id,
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
)
for chat in chat_sessions
]
)

View File

@ -347,7 +347,6 @@ def handle_new_chat_message(
for packet in stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),

View File

@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# used for "OpenAI Assistants API"
existing_assistant_message_id: int | None = None
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None

View File

@ -0,0 +1,255 @@
from typing import cast
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
"""Helper function to get image generation LLM config based on available providers"""
if llm and llm.config.api_key and llm.config.model_provider == "openai":
return LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
return LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
# Fallback to checking for OpenAI provider in database
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError("Image generation tool requires an OpenAI API key")
return LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
class SearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
)
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
selected_sections: list[InferenceSection] | None = None
chunks_above: int = 0
chunks_below: int = 0
full_doc: bool = False
latest_query_files: list[InMemoryChatFile] | None = None
class InternetSearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(
citation_config=CitationConfig(all_docs_useful=True)
)
)
class ImageGenerationToolConfig(BaseModel):
additional_headers: dict[str, str] | None = None
class CustomToolConfig(BaseModel):
chat_session_id: UUID | None = None
message_id: int | None = None
additional_headers: dict[str, str] | None = None
def construct_tools(
persona: Persona,
prompt_config: PromptConfig,
db_session: Session,
user: User | None,
llm: LLM,
fast_llm: LLM,
search_tool_config: SearchToolConfig | None = None,
internet_search_tool_config: InternetSearchToolConfig | None = None,
image_generation_tool_config: ImageGenerationToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
# Handle Search Tool
if tool_cls.__name__ == SearchTool.__name__:
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
if not image_generation_tool_config:
image_generation_tool_config = ImageGenerationToolConfig()
img_generation_llm_config = _get_image_generation_config(
llm, db_session
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=image_generation_tool_config.additional_headers,
model=img_generation_llm_config.model_name,
)
]
# Handle Internet Search Tool
elif tool_cls.__name__ == InternetSearchTool.__name__:
if not internet_search_tool_config:
internet_search_tool_config = InternetSearchToolConfig()
if not BING_API_KEY:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=BING_API_KEY,
answer_style_config=internet_search_tool_config.answer_style_config,
prompt_config=prompt_config,
)
]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:
custom_tool_config = CustomToolConfig()
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=custom_tool_config.chat_session_id,
message_id=custom_tool_config.message_id,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_config.additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
if search_tool_config:
search_tool_config.document_pruning_config.tool_num_tokens = (
compute_all_tool_tokens(
tools,
get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
),
)
)
search_tool_config.document_pruning_config.using_tool_message = (
explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
)
return tool_dict

View File

@ -13,6 +13,14 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
DOMAIN = "test.com"
DEFAULT_PASSWORD = "test"
def build_email(name: str) -> str:
return f"{name}@test.com"
class UserManager:
@staticmethod
def create(
@ -23,9 +31,9 @@ class UserManager:
name = f"test{str(uuid4())}"
if email is None:
email = f"{name}@test.com"
email = build_email(name)
password = "test"
password = DEFAULT_PASSWORD
body = {
"email": email,

View File

@ -0,0 +1,55 @@
from typing import Optional
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
)
)
except Exception:
pass
return None
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
# Create a thread to use in the tests
response = requests.post(
f"{BASE_URL}/threads", # Updated endpoint path
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return UUID(response.json()["id"])

View File

@ -0,0 +1,151 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
def test_create_assistant(admin_user: DATestUser | None) -> None:
response = requests.post(
ASSISTANTS_URL,
json={
"model": "gpt-3.5-turbo",
"name": "Test Assistant",
"description": "A test assistant",
"instructions": "You are a helpful assistant.",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Assistant"
assert data["description"] == "A test assistant"
assert data["model"] == "gpt-3.5-turbo"
assert data["instructions"] == "You are a helpful assistant."
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, retrieve the assistant
response = requests.get(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Retrieve Test"
def test_modify_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, modify the assistant
response = requests.post(
f"{ASSISTANTS_URL}/{assistant_id}",
json={"name": "Modified Assistant", "instructions": "New instructions"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Modified Assistant"
assert data["instructions"] == "New instructions"
def test_delete_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, delete the assistant
response = requests.delete(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["deleted"] is True
def test_list_assistants(admin_user: DATestUser | None) -> None:
# Create multiple assistants
for i in range(3):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the assistants
response = requests.get(
ASSISTANTS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) >= 3 # At least the 3 we just created
assert all(assistant["object"] == "assistant" for assistant in data["data"])
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
# Create 5 assistants
for i in range(5):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# List assistants with limit
response = requests.get(
f"{ASSISTANTS_URL}?limit=2",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
assert data["has_more"] is True
# Get next page
before = data["last_id"]
response = requests.get(
f"{ASSISTANTS_URL}?limit=2&before={before}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
non_existent_id = -99
response = requests.get(
f"{ASSISTANTS_URL}/{non_existent_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@ -0,0 +1,133 @@
import uuid
from typing import Optional
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
response = requests.post(
BASE_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
response = requests.post(
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
json={
"role": "user",
"content": "Hello, world!",
"file_ids": [],
"metadata": {"key": "value"},
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
assert response_json["metadata"] == {"key": "value"}
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the messages
response = requests.get(
f"{BASE_URL}/{thread_id}/messages",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["object"] == "list"
assert isinstance(response_json["data"], list)
assert len(response_json["data"]) > 0
assert "first_id" in response_json
assert "last_id" in response_json
assert "has_more" in response_json
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, retrieve the message
response = requests.get(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, modify the message
response = requests.post(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
non_existent_thread_id = str(uuid.uuid4())
non_existent_message_id = -99
# Test with non-existent thread
response = requests.post(
f"{BASE_URL}/{non_existent_thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404
# Test with non-existent message
response = requests.get(
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@ -0,0 +1,137 @@
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
"""Create a run and return its ID."""
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_run(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
"model": "gpt-3.5-turbo",
"instructions": "Test instructions",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert response_json["assistant_id"] == 0
assert UUID(response_json["thread_id"]) == thread_id
assert response_json["status"] == "queued"
assert response_json["model"] == "gpt-3.5-turbo"
assert response_json["instructions"] == "Test instructions"
def test_retrieve_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
retrieve_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == run_id
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert UUID(response_json["thread_id"]) == thread_id
def test_cancel_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
cancel_response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert cancel_response.status_code == 200
response_json = cancel_response.json()
assert response_json["id"] == run_id
assert response_json["status"] == "cancelled"
def test_list_runs(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
# Create a few runs
for _ in range(3):
requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the runs
list_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert isinstance(response_json, list)
assert len(response_json) >= 3
for run in response_json:
assert "id" in run
assert run["object"] == "thread.run"
assert "created_at" in run
assert UUID(run["thread_id"]) == thread_id
assert "status" in run
assert "model" in run
def test_list_run_steps(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
steps_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert steps_response.status_code == 200
response_json = steps_response.json()
assert isinstance(response_json, list)
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
assert len(response_json) == 0

View File

@ -0,0 +1,132 @@
from uuid import UUID
import requests
from danswer.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
def test_create_thread(admin_user: DATestUser | None) -> None:
response = requests.post(
THREADS_URL,
json={"messages": None, "metadata": {"key": "value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread"
assert "created_at" in response_json
assert response_json["metadata"] == {"key": "value"}
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, retrieve the thread
retrieve_response = requests.get(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread"
assert "created_at" in response_json
def test_modify_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, modify the thread
modify_response = requests.post(
f"{THREADS_URL}/{thread_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert modify_response.status_code == 200
response_json = modify_response.json()
assert response_json["id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_delete_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, delete the thread
delete_response = requests.delete(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert delete_response.status_code == 200
response_json = delete_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread.deleted"
assert response_json["deleted"] is True
def test_list_threads(admin_user: DATestUser | None) -> None:
# Create a few threads
for _ in range(3):
requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the threads
list_response = requests.get(
THREADS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert "sessions" in response_json
assert len(response_json["sessions"]) >= 3
for session in response_json["sessions"]:
assert "id" in session
assert "name" in session
assert "persona_id" in session
assert "time_created" in session
assert "shared_status" in session
assert "folder_id" in session
assert "current_alternate_model" in session
# Validate UUID
UUID(session["id"])
# Validate shared_status
assert session["shared_status"] in [
status.value for status in ChatSessionSharedStatus
]

View File

@ -0,0 +1,125 @@
import argparse
import os
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from openai import OpenAI
ASSISTANT_NAME = "Topic Analyzer"
SYSTEM_PROMPT = """
You are a helpful assistant that analyzes topics by searching through available \
documents and providing insights. These available documents come from common \
workplace tools like Slack, emails, Confluence, Google Drive, etc.
When analyzing a topic:
1. Search for relevant information using the search tool
2. Synthesize the findings into clear insights
3. Highlight key trends, patterns, or notable developments
4. Maintain objectivity and cite sources where relevant
"""
USER_PROMPT = """
Please analyze and provide insights about this topic: {topic}.
IMPORTANT: do not mention things that are not relevant to the specified topic. \
If there is no relevant information, just say "No relevant information found."
"""
def wait_on_run(client: OpenAI, run, thread):
while run.status == "queued" or run.status == "in_progress":
run = client.beta.threads.runs.retrieve(
thread_id=thread.id,
run_id=run.id,
)
time.sleep(0.5)
return run
def show_response(messages) -> None:
# Get only the assistant's response text
for message in messages.data[::-1]:
if message.role == "assistant":
for content in message.content:
if content.type == "text":
print(content.text)
break
def analyze_topics(topics: list[str]) -> None:
openai_api_key = os.environ.get(
"OPENAI_API_KEY", "<your OpenAI API key if not set as env var>"
)
danswer_api_key = os.environ.get(
"DANSWER_API_KEY", "<your Danswer API key if not set as env var>"
)
client = OpenAI(
api_key=openai_api_key,
base_url="http://localhost:8080/openai-assistants",
default_headers={
"Authorization": f"Bearer {danswer_api_key}",
},
)
# Create an assistant if it doesn't exist
try:
assistants = client.beta.assistants.list(limit=100)
# Find the Topic Analyzer assistant if it exists
assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME))
client.beta.assistants.delete(assistant.id)
except Exception:
pass
assistant = client.beta.assistants.create(
name=ASSISTANT_NAME,
instructions=SYSTEM_PROMPT,
tools=[{"type": "SearchTool"}], # type: ignore
model="gpt-4o",
)
# Process each topic individually
for topic in topics:
thread = client.beta.threads.create()
message = client.beta.threads.messages.create(
thread_id=thread.id,
role="user",
content=USER_PROMPT.format(topic=topic),
)
run = client.beta.threads.runs.create(
thread_id=thread.id,
assistant_id=assistant.id,
tools=[
{ # type: ignore
"type": "SearchTool",
"retrieval_details": {
"run_search": "always",
"filters": {
"time_cutoff": str(
datetime.now(timezone.utc) - timedelta(days=7)
)
},
},
}
],
)
run = wait_on_run(client, run, thread)
messages = client.beta.threads.messages.list(
thread_id=thread.id, order="asc", after=message.id
)
print(f"\nAnalysis for topic: {topic}")
print("-" * 40)
show_response(messages)
print()
# Example usage
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Analyze specific topics")
parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)")
args = parser.parse_args()
analyze_topics(args.topics)