Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-30 09:40:50 +02:00)

Add strict json mode (#2917)

@@ -672,6 +672,7 @@ def stream_chat_message_objects(
                 all_docs_useful=selected_db_search_docs is not None
             ),
             document_pruning_config=document_pruning_config,
+            structured_response_format=new_msg_req.structured_response_format,
         ),
         prompt_config=prompt_config,
         llm=(

@@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
     document_pruning_config: DocumentPruningConfig = Field(
         default_factory=DocumentPruningConfig
     )
+    # forces the LLM to return a structured response, see
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    # right now, only used by the simple chat API
+    structured_response_format: dict | None = None
 
     @model_validator(mode="after")
     def check_quotes_and_citation(self) -> "AnswerStyleConfig":

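For reference, this is the shape of value the new field carries. The sketch below mirrors the payload the integration test at the end of this diff sends; the schema contents are illustrative, not part of the commit:

# Illustrative structured_response_format value (schema contents are made up);
# it is passed through verbatim as OpenAI's response_format parameter.
structured_response_format = {
    "type": "json_object",
    "schema": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
}
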
@@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None,
         tool_choice: ToolChoiceOptions | None,
         stream: bool,
+        structured_response_format: dict | None = None,
     ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
         if isinstance(prompt, list):
             prompt = [

@@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
                 # NOTE: we can't pass this in if tools are not specified
                 # or else OpenAI throws an error
                 **({"parallel_tool_calls": False} if tools else {}),
+                **(
+                    {"response_format": structured_response_format}
+                    if structured_response_format
+                    else {}
+                ),
                 **self._model_kwargs,
             )
         except Exception as e:

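The conditional splat means response_format is only forwarded when set. Downstream this is simply litellm's pass-through of the OpenAI parameter; a minimal sketch of what the call above reduces to (model name and message content are placeholders):

import litellm

# Sketch only: the effective litellm call when a structured format is supplied.
response = litellm.completion(
    model="gpt-4o",  # placeholder; any OpenAI model supporting JSON mode
    messages=[{"role": "user", "content": "Answer in JSON."}],
    response_format={"type": "json_object"},  # the forwarded structured_response_format
)
print(response.choices[0].message.content)
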
@@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         response = cast(
-            litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
+            litellm.ModelResponse,
+            self._completion(
+                prompt, tools, tool_choice, False, structured_response_format
+            ),
         )
         choice = response.choices[0]
         if hasattr(choice, "message"):

@@ -354,18 +364,21 @@ class DefaultMultiLLM(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         if DISABLE_LITELLM_STREAMING:
-            yield self.invoke(prompt)
+            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
             return
 
         output = None
         response = cast(
             litellm.CustomStreamWrapper,
-            self._completion(prompt, tools, tool_choice, True),
+            self._completion(
+                prompt, tools, tool_choice, True, structured_response_format
+            ),
         )
         try:
             for part in response:

@@ -80,6 +80,7 @@ class CustomModelServer(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         return self._execute(prompt)
 

@@ -88,5 +89,6 @@ class CustomModelServer(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         yield self._execute(prompt)

@@ -88,11 +88,14 @@ class LLM(abc.ABC):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
-        return self._invoke_implementation(prompt, tools, tool_choice)
+        return self._invoke_implementation(
+            prompt, tools, tool_choice, structured_response_format
+        )
 
     @abc.abstractmethod
     def _invoke_implementation(

@@ -100,6 +103,7 @@ class LLM(abc.ABC):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         raise NotImplementedError
 

@@ -108,11 +112,14 @@ class LLM(abc.ABC):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
-        return self._stream_implementation(prompt, tools, tool_choice)
+        return self._stream_implementation(
+            prompt, tools, tool_choice, structured_response_format
+        )
 
     @abc.abstractmethod
     def _stream_implementation(

@@ -120,5 +127,6 @@ class LLM(abc.ABC):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         raise NotImplementedError

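Taken together, the base-class changes just thread the new keyword through to the abstract hooks. Below is a minimal sketch of a concrete subclass under the widened interface; the class name, its canned reply, and the langchain import path are assumptions, and any other members LLM requires are omitted:

from collections.abc import Iterator

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage


class CannedLLM(LLM):  # hypothetical subclass, for illustration only
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        # A real backend would forward structured_response_format to its API call.
        return AIMessage(content='{"answer": "stub"}')

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        # Degenerate stream: yield the single non-streamed message.
        yield self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format
        )
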
@@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
     # used for seeded chats to kick off the generation of an AI answer
     use_existing_user_message: bool = False
 
+    # forces the LLM to return a structured response, see
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
+
     @model_validator(mode="after")
     def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
         if self.search_doc_ids is None and self.retrieval_options is None:

@@ -176,6 +176,7 @@ def handle_simplified_chat_message(
         chunks_above=0,
         chunks_below=0,
         full_doc=chat_message_req.full_doc,
+        structured_response_format=chat_message_req.structured_response_format,
     )
 
     packets = stream_chat_message_objects(

@@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
         raise HTTPException(status_code=400, detail="Messages cannot be zero length")
 
     # This is a sanity check to make sure the chat history is valid
-    # It must start with a user message and alternate between user and assistant
+    # It must start with a user message and alternate beteen user and assistant
     expected_role = MessageType.USER
     for msg in req.messages:
         if not msg.message:

@@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
         chunks_above=0,
         chunks_below=0,
         full_doc=req.full_doc,
+        structured_response_format=req.structured_response_format,
     )
 
     packets = stream_chat_message_objects(

@@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
     query_override: str | None = None
     # If search_doc_ids provided, then retrieval options are unused
     search_doc_ids: list[int] | None = None
+    # only works if using an OpenAI model. See the following for more details:
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
 
 
 class BasicCreateChatMessageWithHistoryRequest(ChunkContext):

@@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     skip_rerank: bool | None = None
     # If search_doc_ids provided, then retrieval options are unused
     search_doc_ids: list[int] | None = None
+    # only works if using an OpenAI model. See the following for more details:
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
 
 
 class SimpleDoc(BaseModel):

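End to end, a client opts in per message. A condensed sketch of a call against the simple API follows; the server URL and the persona/prompt ids are environment-specific assumptions, while the payload shape matches the integration test added below:

import requests

resp = requests.post(
    "http://localhost:8080/chat/send-message-simple-with-history",  # adjust host
    json={
        "messages": [
            {"message": "List the first three US presidents", "role": "user"}
        ],
        "persona_id": 0,
        "prompt_id": 0,
        # forwarded all the way down to the OpenAI response_format parameter
        "structured_response_format": {"type": "json_object"},
    },
)
resp.raise_for_status()
print(resp.json()["answer"])
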
backend/scripts/add_connector_creation_script.py (new file, 148 lines)
@@ -0,0 +1,148 @@
+from typing import Any
+from typing import Dict
+
+import requests
+
+API_SERVER_URL = "http://localhost:3000"  # Adjust this to your Danswer server URL
+HEADERS = {"Content-Type": "application/json"}
+API_KEY = "danswer-api-key"  # API key here, if auth is enabled
+
+
+def create_connector(
+    name: str,
+    source: str,
+    input_type: str,
+    connector_specific_config: Dict[str, Any],
+    is_public: bool = True,
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    connector_update_request = {
+        "name": name,
+        "source": source,
+        "input_type": input_type,
+        "connector_specific_config": connector_specific_config,
+        "is_public": is_public,
+        "groups": groups or [],
+    }
+
+    response = requests.post(
+        url=f"{API_SERVER_URL}/api/manage/admin/connector",
+        json=connector_update_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def create_credential(
+    name: str,
+    source: str,
+    credential_json: Dict[str, Any],
+    is_public: bool = True,
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    credential_request = {
+        "name": name,
+        "source": source,
+        "credential_json": credential_json,
+        "admin_public": is_public,
+        "groups": groups or [],
+    }
+
+    response = requests.post(
+        url=f"{API_SERVER_URL}/api/manage/credential",
+        json=credential_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def create_cc_pair(
+    connector_id: int,
+    credential_id: int,
+    name: str,
+    access_type: str = "public",
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    cc_pair_request = {
+        "name": name,
+        "access_type": access_type,
+        "groups": groups or [],
+    }
+
+    response = requests.put(
+        url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
+        json=cc_pair_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def main() -> None:
+    # Create a Web connector
+    web_connector = create_connector(
+        name="Example Web Connector",
+        source="web",
+        input_type="load_state",
+        connector_specific_config={
+            "base_url": "https://example.com",
+            "web_connector_type": "recursive",
+        },
+    )
+    print(f"Created Web Connector: {web_connector}")
+
+    # Create a credential for the Web connector
+    web_credential = create_credential(
+        name="Example Web Credential",
+        source="web",
+        credential_json={},  # Web connectors typically don't need credentials
+        is_public=True,
+    )
+    print(f"Created Web Credential: {web_credential}")
+
+    # Create CC pair for Web connector
+    web_cc_pair = create_cc_pair(
+        connector_id=web_connector["id"],
+        credential_id=web_credential["id"],
+        name="Example Web CC Pair",
+        access_type="public",
+    )
+    print(f"Created Web CC Pair: {web_cc_pair}")
+
+    # Create a GitHub connector
+    github_connector = create_connector(
+        name="Example GitHub Connector",
+        source="github",
+        input_type="poll",
+        connector_specific_config={
+            "repo_owner": "example-owner",
+            "repo_name": "example-repo",
+            "include_prs": True,
+            "include_issues": True,
+        },
+    )
+    print(f"Created GitHub Connector: {github_connector}")
+
+    # Create a credential for the GitHub connector
+    github_credential = create_credential(
+        name="Example GitHub Credential",
+        source="github",
+        credential_json={"github_access_token": "your_github_access_token_here"},
+        is_public=True,
+    )
+    print(f"Created GitHub Credential: {github_credential}")
+
+    # Create CC pair for GitHub connector
+    github_cc_pair = create_cc_pair(
+        connector_id=github_connector["id"],
+        credential_id=github_credential["id"],
+        name="Example GitHub CC Pair",
+        access_type="public",
+    )
+    print(f"Created GitHub CC Pair: {github_cc_pair}")
+
+
+if __name__ == "__main__":
+    main()

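The script runs directly (python backend/scripts/add_connector_creation_script.py) once API_SERVER_URL points at a running instance. Note that it declares API_KEY but never attaches it to HEADERS, so as written it only works with auth disabled; if auth is enabled, something like the line below would be needed (the Bearer scheme is an assumption, not confirmed by this commit):

HEADERS["Authorization"] = f"Bearer {API_KEY}"  # assumption: Bearer auth scheme
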
@@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session_context_manager
 from danswer.db.search_settings import get_current_search_settings
+from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.reset import reset_all
+from tests.integration.common_utils.test_models import DATestUser
 from tests.integration.common_utils.vespa import vespa_fixture
 
 

@@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
 @pytest.fixture
 def reset() -> None:
     reset_all()
+
+
+@pytest.fixture
+def new_admin_user() -> DATestUser | None:
+    try:
+        return UserManager.create(name="admin_user")
+    except Exception:
+        return None

@@ -1,7 +1,10 @@
+import json
+
 import requests
 
+from danswer.configs.constants import MessageType
 from tests.integration.common_utils.constants import API_SERVER_URL
 from tests.integration.common_utils.constants import GENERAL_HEADERS
 from tests.integration.common_utils.constants import NUM_DOCS
 from tests.integration.common_utils.managers.api_key import APIKeyManager
 from tests.integration.common_utils.managers.cc_pair import CCPairManager

@@ -145,3 +148,87 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
     # This ensures the the document we think we are referencing when we send the search_doc_ids in the second
     # message is the document that we expect it to be
     assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
+
+
+def test_send_message_simple_with_history_strict_json(
+    reset: None,
+    new_admin_user: DATestUser | None,
+) -> None:
+    # create connectors
+    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
+        user_performing_action=new_admin_user,
+    )
+    api_key: DATestAPIKey = APIKeyManager.create(
+        user_performing_action=new_admin_user,
+    )
+    LLMProviderManager.create(user_performing_action=new_admin_user)
+    cc_pair_1.documents = DocumentManager.seed_dummy_docs(
+        cc_pair=cc_pair_1,
+        num_docs=NUM_DOCS,
+        api_key=api_key,
+    )
+
+    response = requests.post(
+        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
+        json={
+            "messages": [
+                {
+                    "message": "List the names of the first three US presidents in JSON format",
+                    "role": MessageType.USER.value,
+                }
+            ],
+            "persona_id": 0,
+            "prompt_id": 0,
+            "structured_response_format": {
+                "type": "json_object",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "presidents": {
+                            "type": "array",
+                            "items": {"type": "string"},
+                            "description": "List of the first three US presidents",
+                        }
+                    },
+                    "required": ["presidents"],
+                },
+            },
+        },
+        headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
+    )
+    assert response.status_code == 200
+
+    response_json = response.json()
+
+    # Check that the answer is present
+    assert "answer" in response_json
+    assert response_json["answer"] is not None
+
+    # helper
+    def clean_json_string(json_string: str) -> str:
+        return json_string.strip().removeprefix("```json").removesuffix("```").strip()
+
+    # Attempt to parse the answer as JSON
+    try:
+        clean_answer = clean_json_string(response_json["answer"])
+        parsed_answer = json.loads(clean_answer)
+        assert isinstance(parsed_answer, dict)
+        assert "presidents" in parsed_answer
+        assert isinstance(parsed_answer["presidents"], list)
+        assert len(parsed_answer["presidents"]) == 3
+        for president in parsed_answer["presidents"]:
+            assert isinstance(president, str)
+    except json.JSONDecodeError:
+        assert False, "The answer is not a valid JSON object"
+
+    # Check that the answer_citationless is also valid JSON
+    assert "answer_citationless" in response_json
+    assert response_json["answer_citationless"] is not None
+    try:
+        clean_answer_citationless = clean_json_string(
+            response_json["answer_citationless"]
+        )
+        parsed_answer_citationless = json.loads(clean_answer_citationless)
+        assert isinstance(parsed_answer_citationless, dict)
+    except json.JSONDecodeError:
+        assert False, "The answer_citationless is not a valid JSON object"

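One detail worth noting in the test: even with JSON mode requested, models often wrap their answers in markdown fences, which is why clean_json_string strips them before json.loads. A quick illustration of the helper's behavior (the raw string is a made-up example):

raw = '```json\n{"presidents": ["Washington", "Adams", "Jefferson"]}\n```'
clean = raw.strip().removeprefix("```json").removesuffix("```").strip()
# clean == '{"presidents": ["Washington", "Adams", "Jefferson"]}'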