Add strict json mode (#2917)

Chris Weaver
2024-10-24 22:38:46 -07:00
committed by GitHub
parent d7a30b01d2
commit 4a47e9a841
11 changed files with 291 additions and 6 deletions

View File

@@ -672,6 +672,7 @@ def stream_chat_message_objects(
all_docs_useful=selected_db_search_docs is not None
),
document_pruning_config=document_pruning_config,
structured_response_format=new_msg_req.structured_response_format,
),
prompt_config=prompt_config,
llm=(

View File

@@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
# right now, only used by the simple chat API
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_quotes_and_citation(self) -> "AnswerStyleConfig":
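For illustration, a sketch of a value this field can carry, mirroring the shape exercised by the integration test at the end of this commit; the schema contents below are placeholders, not part of the diff:

```python
# illustrative only: the dict is passed through to the LLM layer verbatim
structured_response_format = {
    "type": "json_object",
    "schema": {
        "type": "object",
        "properties": {"answer": {"type": "string"}},
        "required": ["answer"],
    },
}
```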

View File

@@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
tools: list[dict] | None,
tool_choice: ToolChoiceOptions | None,
stream: bool,
structured_response_format: dict | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
if isinstance(prompt, list):
prompt = [
@@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**(
{"response_format": structured_response_format}
if structured_response_format
else {}
),
**self._model_kwargs,
)
except Exception as e:
@@ -336,12 +342,16 @@
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
response = cast(
litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
litellm.ModelResponse,
self._completion(
prompt, tools, tool_choice, False, structured_response_format
),
)
choice = response.choices[0]
if hasattr(choice, "message"):
@@ -354,18 +364,21 @@
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
if DISABLE_LITELLM_STREAMING:
yield self.invoke(prompt)
yield self.invoke(prompt, tools, tool_choice, structured_response_format)
return
output = None
response = cast(
litellm.CustomStreamWrapper,
self._completion(prompt, tools, tool_choice, True),
self._completion(
prompt, tools, tool_choice, True, structured_response_format
),
)
try:
for part in response:
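For reference, a standalone sketch (not part of this diff) of what the conditional kwarg merge above amounts to at the litellm layer; the model name and prompt are placeholders:

```python
import litellm

# when structured_response_format is set, it is forwarded as litellm's
# response_format kwarg; when unset, no such kwarg is sent at all
response = litellm.completion(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Name three colors as JSON."}],
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)
```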

View File

@@ -80,6 +80,7 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -88,5 +89,6 @@ class CustomModelServer(LLM):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -88,11 +88,14 @@ class LLM(abc.ABC):
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(prompt, tools, tool_choice)
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _invoke_implementation(
@@ -100,6 +103,7 @@
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -108,11 +112,14 @@
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._stream_implementation(prompt, tools, tool_choice)
return self._stream_implementation(
prompt, tools, tool_choice, structured_response_format
)
@abc.abstractmethod
def _stream_implementation(
@@ -120,5 +127,6 @@
prompt: LanguageModelInput,
tools: list[dict] | None = None,
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError
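Because the abstract hooks gained a parameter, every concrete subclass now has to accept it even if it ignores it, as CustomModelServer above does. A hypothetical minimal implementation, assuming the names already imported in this module (other abstract members of LLM are omitted):

```python
from collections.abc import Iterator

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage


class EchoLLM(LLM):
    def _invoke_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> BaseMessage:
        # a real backend would forward structured_response_format to its API
        return AIMessage(content=str(prompt))

    def _stream_implementation(
        self,
        prompt: LanguageModelInput,
        tools: list[dict] | None = None,
        tool_choice: ToolChoiceOptions | None = None,
        structured_response_format: dict | None = None,
    ) -> Iterator[BaseMessage]:
        yield self._invoke_implementation(
            prompt, tools, tool_choice, structured_response_format
        )
```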

View File

@@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
@model_validator(mode="after")
def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
if self.search_doc_ids is None and self.retrieval_options is None:
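On the wire this is one more optional field on the send-message payload; an abbreviated sketch, where the elided field names are assumptions rather than taken from this diff:

```python
# fragment of a CreateChatMessageRequest body; the model's other
# required fields are omitted here for brevity
payload = {
    "message": "Summarize the document as a JSON object",
    "structured_response_format": {"type": "json_object"},
    # ... remaining required fields ...
}
```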

View File

@@ -176,6 +176,7 @@ def handle_simplified_chat_message(
chunks_above=0,
chunks_below=0,
full_doc=chat_message_req.full_doc,
structured_response_format=chat_message_req.structured_response_format,
)
packets = stream_chat_message_objects(
@@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
raise HTTPException(status_code=400, detail="Messages cannot be zero length")
# This is a sanity check to make sure the chat history is valid
# It must start with a user message and alternate between user and assistant
expected_role = MessageType.USER
for msg in req.messages:
if not msg.message:
@@ -296,6 +297,7 @@ def handle_send_message_simple_with_history(
chunks_above=0,
chunks_below=0,
full_doc=req.full_doc,
structured_response_format=req.structured_response_format,
)
packets = stream_chat_message_objects(

View File

@@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
query_override: str | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
@@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
skip_rerank: bool | None = None
# If search_doc_ids provided, then retrieval options are unused
search_doc_ids: list[int] | None = None
# only works if using an OpenAI model. See the following for more details:
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
class SimpleDoc(BaseModel):
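A hedged end-to-end sketch for the no-history variant; the route path, port, and field values are assumptions based on the surrounding codebase, not shown in this diff:

```python
import requests

# hypothetical request; adjust the URL and auth headers for your deployment
resp = requests.post(
    "http://localhost:8080/chat/send-message-simple-api",  # assumed route
    json={
        "persona_id": 0,
        "message": "Return the answer as a JSON object",
        "structured_response_format": {"type": "json_object"},
    },
    headers={"Content-Type": "application/json"},
)
print(resp.json()["answer"])
```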

View File

@@ -0,0 +1,148 @@
from typing import Any
from typing import Dict
import requests
API_SERVER_URL = "http://localhost:3000" # Adjust this to your Danswer server URL
HEADERS = {"Content-Type": "application/json"}
API_KEY = "danswer-api-key" # API key here, if auth is enabled; unused as-is, it must also be added to HEADERS (e.g. as a Bearer Authorization header) to take effect
def create_connector(
name: str,
source: str,
input_type: str,
connector_specific_config: Dict[str, Any],
is_public: bool = True,
groups: list[int] | None = None,
) -> Dict[str, Any]:
connector_update_request = {
"name": name,
"source": source,
"input_type": input_type,
"connector_specific_config": connector_specific_config,
"is_public": is_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/admin/connector",
json=connector_update_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def create_credential(
name: str,
source: str,
credential_json: Dict[str, Any],
is_public: bool = True,
groups: list[int] | None = None,
) -> Dict[str, Any]:
credential_request = {
"name": name,
"source": source,
"credential_json": credential_json,
"admin_public": is_public,
"groups": groups or [],
}
response = requests.post(
url=f"{API_SERVER_URL}/api/manage/credential",
json=credential_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def create_cc_pair(
connector_id: int,
credential_id: int,
name: str,
access_type: str = "public",
groups: list[int] | None = None,
) -> Dict[str, Any]:
cc_pair_request = {
"name": name,
"access_type": access_type,
"groups": groups or [],
}
response = requests.put(
url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
json=cc_pair_request,
headers=HEADERS,
)
response.raise_for_status()
return response.json()
def main() -> None:
# Create a Web connector
web_connector = create_connector(
name="Example Web Connector",
source="web",
input_type="load_state",
connector_specific_config={
"base_url": "https://example.com",
"web_connector_type": "recursive",
},
)
print(f"Created Web Connector: {web_connector}")
# Create a credential for the Web connector
web_credential = create_credential(
name="Example Web Credential",
source="web",
credential_json={}, # Web connectors typically don't need credentials
is_public=True,
)
print(f"Created Web Credential: {web_credential}")
# Create CC pair for Web connector
web_cc_pair = create_cc_pair(
connector_id=web_connector["id"],
credential_id=web_credential["id"],
name="Example Web CC Pair",
access_type="public",
)
print(f"Created Web CC Pair: {web_cc_pair}")
# Create a GitHub connector
github_connector = create_connector(
name="Example GitHub Connector",
source="github",
input_type="poll",
connector_specific_config={
"repo_owner": "example-owner",
"repo_name": "example-repo",
"include_prs": True,
"include_issues": True,
},
)
print(f"Created GitHub Connector: {github_connector}")
# Create a credential for the GitHub connector
github_credential = create_credential(
name="Example GitHub Credential",
source="github",
credential_json={"github_access_token": "your_github_access_token_here"},
is_public=True,
)
print(f"Created GitHub Credential: {github_credential}")
# Create CC pair for GitHub connector
github_cc_pair = create_cc_pair(
connector_id=github_connector["id"],
credential_id=github_credential["id"],
name="Example GitHub CC Pair",
access_type="public",
)
print(f"Created GitHub CC Pair: {github_cc_pair}")
if __name__ == "__main__":
main()
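Assuming a local Danswer deployment with auth disabled, the script above runs as-is once API_SERVER_URL is adjusted; each helper prints the created object so its returned id can feed the next call, exactly as main() chains connector, credential, and CC pair.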

View File

@@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
@@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
@pytest.fixture
def reset() -> None:
reset_all()
@pytest.fixture
def new_admin_user() -> DATestUser | None:
try:
return UserManager.create(name="admin_user")
except Exception:
return None

View File

@@ -1,7 +1,10 @@
import json
import requests
from danswer.configs.constants import MessageType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
@@ -145,3 +148,87 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
# This ensures that the document we think we are referencing when we send the search_doc_ids in the second
# message is the document that we expect it to be
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
def test_send_message_simple_with_history_strict_json(
reset: None,
new_admin_user: DATestUser | None,
) -> None:
# create connectors
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=new_admin_user,
)
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=new_admin_user,
)
LLMProviderManager.create(user_performing_action=new_admin_user)
cc_pair_1.documents = DocumentManager.seed_dummy_docs(
cc_pair=cc_pair_1,
num_docs=NUM_DOCS,
api_key=api_key,
)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [
{
"message": "List the names of the first three US presidents in JSON format",
"role": MessageType.USER.value,
}
],
"persona_id": 0,
"prompt_id": 0,
"structured_response_format": {
"type": "json_object",
"schema": {
"type": "object",
"properties": {
"presidents": {
"type": "array",
"items": {"type": "string"},
"description": "List of the first three US presidents",
}
},
"required": ["presidents"],
},
},
},
headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
# Check that the answer is present
assert "answer" in response_json
assert response_json["answer"] is not None
# helper to strip the markdown code fences models often wrap around JSON
def clean_json_string(json_string: str) -> str:
return json_string.strip().removeprefix("```json").removesuffix("```").strip()
# Attempt to parse the answer as JSON
try:
clean_answer = clean_json_string(response_json["answer"])
parsed_answer = json.loads(clean_answer)
assert isinstance(parsed_answer, dict)
assert "presidents" in parsed_answer
assert isinstance(parsed_answer["presidents"], list)
assert len(parsed_answer["presidents"]) == 3
for president in parsed_answer["presidents"]:
assert isinstance(president, str)
except json.JSONDecodeError:
assert False, "The answer is not a valid JSON object"
# Check that the answer_citationless is also valid JSON
assert "answer_citationless" in response_json
assert response_json["answer_citationless"] is not None
try:
clean_answer_citationless = clean_json_string(
response_json["answer_citationless"]
)
parsed_answer_citationless = json.loads(clean_answer_citationless)
assert isinstance(parsed_answer_citationless, dict)
except json.JSONDecodeError:
assert False, "The answer_citationless is not a valid JSON object"