From 4a47e9a841805a14b4fc9178faea1933fc1e2933 Mon Sep 17 00:00:00 2001
From: Chris Weaver <25087905+Weves@users.noreply.github.com>
Date: Thu, 24 Oct 2024 22:38:46 -0700
Subject: [PATCH] Add strict json mode (#2917)

---
 backend/danswer/chat/process_message.py       |   1 +
 backend/danswer/llm/answering/models.py       |   4 +
 backend/danswer/llm/chat_llm.py               |  19 ++-
 backend/danswer/llm/custom_llm.py             |   2 +
 backend/danswer/llm/interfaces.py             |  12 +-
 .../danswer/server/query_and_chat/models.py   |   4 +
 .../server/query_and_chat/chat_backend.py     |   4 +-
 .../danswer/server/query_and_chat/models.py   |   6 +
 .../scripts/add_connector_creation_script.py  | 148 ++++++++++++++++++
 backend/tests/integration/conftest.py         |  10 ++
 .../tests/dev_apis/test_simple_chat_api.py    |  87 ++++++++++
 11 files changed, 291 insertions(+), 6 deletions(-)
 create mode 100644 backend/scripts/add_connector_creation_script.py

diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index ea4e7be93d..f58a34c324 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -672,6 +672,7 @@ def stream_chat_message_objects(
                 all_docs_useful=selected_db_search_docs is not None
             ),
             document_pruning_config=document_pruning_config,
+            structured_response_format=new_msg_req.structured_response_format,
         ),
         prompt_config=prompt_config,
         llm=(
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
index fb5fa9c313..87c1297fe9 100644
--- a/backend/danswer/llm/answering/models.py
+++ b/backend/danswer/llm/answering/models.py
@@ -116,6 +116,10 @@ class AnswerStyleConfig(BaseModel):
     document_pruning_config: DocumentPruningConfig = Field(
         default_factory=DocumentPruningConfig
     )
+    # forces the LLM to return a structured response, see
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    # right now, only used by the simple chat API
+    structured_response_format: dict | None = None
 
     @model_validator(mode="after")
     def check_quotes_and_citation(self) -> "AnswerStyleConfig":
diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py
index d50f825318..d450fff0a6 100644
--- a/backend/danswer/llm/chat_llm.py
+++ b/backend/danswer/llm/chat_llm.py
@@ -280,6 +280,7 @@ class DefaultMultiLLM(LLM):
         tools: list[dict] | None,
         tool_choice: ToolChoiceOptions | None,
         stream: bool,
+        structured_response_format: dict | None = None,
     ) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
         if isinstance(prompt, list):
             prompt = [
@@ -313,6 +314,11 @@ class DefaultMultiLLM(LLM):
                 # NOTE: we can't pass this in if tools are not specified
                 # or else OpenAI throws an error
                 **({"parallel_tool_calls": False} if tools else {}),
+                **(
+                    {"response_format": structured_response_format}
+                    if structured_response_format
+                    else {}
+                ),
                 **self._model_kwargs,
             )
         except Exception as e:
@@ -336,12 +342,16 @@ class DefaultMultiLLM(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         response = cast(
-            litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False)
+            litellm.ModelResponse,
+            self._completion(
+                prompt, tools, tool_choice, False, structured_response_format
+            ),
         )
         choice = response.choices[0]
         if hasattr(choice, "message"):
@@ -354,18 +364,21 @@
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         if LOG_DANSWER_MODEL_INTERACTIONS:
             self.log_model_configs()
 
         if DISABLE_LITELLM_STREAMING:
-            yield self.invoke(prompt)
+            yield self.invoke(prompt, tools, tool_choice, structured_response_format)
             return
 
         output = None
         response = cast(
             litellm.CustomStreamWrapper,
-            self._completion(prompt, tools, tool_choice, True),
+            self._completion(
+                prompt, tools, tool_choice, True, structured_response_format
+            ),
         )
         try:
             for part in response:
diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py
index 4a5ba7857c..6b80406cf2 100644
--- a/backend/danswer/llm/custom_llm.py
+++ b/backend/danswer/llm/custom_llm.py
@@ -80,6 +80,7 @@ class CustomModelServer(LLM):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         return self._execute(prompt)
 
@@ -88,5 +89,6 @@
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         yield self._execute(prompt)
diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py
index 6cb58e46c6..7deee11dfa 100644
--- a/backend/danswer/llm/interfaces.py
+++ b/backend/danswer/llm/interfaces.py
@@ -88,11 +88,14 @@ class LLM(abc.ABC):
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
-        return self._invoke_implementation(prompt, tools, tool_choice)
+        return self._invoke_implementation(
+            prompt, tools, tool_choice, structured_response_format
+        )
 
     @abc.abstractmethod
     def _invoke_implementation(
@@ -100,6 +103,7 @@
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> BaseMessage:
         raise NotImplementedError
 
@@ -108,11 +112,14 @@
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         self._precall(prompt)
         # TODO add a postcall to log model outputs independent of concrete class
         # implementation
-        return self._stream_implementation(prompt, tools, tool_choice)
+        return self._stream_implementation(
+            prompt, tools, tool_choice, structured_response_format
+        )
 
     @abc.abstractmethod
     def _stream_implementation(
@@ -120,5 +127,6 @@
         prompt: LanguageModelInput,
         tools: list[dict] | None = None,
         tool_choice: ToolChoiceOptions | None = None,
+        structured_response_format: dict | None = None,
     ) -> Iterator[BaseMessage]:
         raise NotImplementedError
diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py
index 42f4100a24..1ca14f9283 100644
--- a/backend/danswer/server/query_and_chat/models.py
+++ b/backend/danswer/server/query_and_chat/models.py
@@ -108,6 +108,10 @@ class CreateChatMessageRequest(ChunkContext):
     # used for seeded chats to kick off the generation of an AI answer
     use_existing_user_message: bool = False
 
+    # forces the LLM to return a structured response, see
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
+
     @model_validator(mode="after")
     def check_search_doc_ids_or_retrieval_options(self) -> "CreateChatMessageRequest":
         if self.search_doc_ids is None and self.retrieval_options is None:
diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py
index dd637dcf08..b25ed8357d 100644
--- a/backend/ee/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py
@@ -176,6 +176,7 @@ def handle_simplified_chat_message(
         chunks_above=0,
         chunks_below=0,
         full_doc=chat_message_req.full_doc,
+        structured_response_format=chat_message_req.structured_response_format,
     )
 
     packets = stream_chat_message_objects(
@@ -202,7 +203,7 @@ def handle_send_message_simple_with_history(
         raise HTTPException(status_code=400, detail="Messages cannot be zero length")
 
     # This is a sanity check to make sure the chat history is valid
-    # It must start with a user message and alternate beteen user and assistant
+    # It must start with a user message and alternate between user and assistant
     expected_role = MessageType.USER
     for msg in req.messages:
         if not msg.message:
@@ -296,6 +297,7 @@
         chunks_above=0,
         chunks_below=0,
         full_doc=req.full_doc,
+        structured_response_format=req.structured_response_format,
     )
 
     packets = stream_chat_message_objects(
diff --git a/backend/ee/danswer/server/query_and_chat/models.py b/backend/ee/danswer/server/query_and_chat/models.py
index 052be683e9..4baf17ac8c 100644
--- a/backend/ee/danswer/server/query_and_chat/models.py
+++ b/backend/ee/danswer/server/query_and_chat/models.py
@@ -48,6 +48,9 @@ class BasicCreateChatMessageRequest(ChunkContext):
     query_override: str | None = None
     # If search_doc_ids provided, then retrieval options are unused
     search_doc_ids: list[int] | None = None
+    # only works if using an OpenAI model. See the following for more details:
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
 
 
 class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
@@ -60,6 +63,9 @@ class BasicCreateChatMessageWithHistoryRequest(ChunkContext):
     skip_rerank: bool | None = None
     # If search_doc_ids provided, then retrieval options are unused
     search_doc_ids: list[int] | None = None
+    # only works if using an OpenAI model. See the following for more details:
+    # https://platform.openai.com/docs/guides/structured-outputs/introduction
+    structured_response_format: dict | None = None
 
 
 class SimpleDoc(BaseModel):
diff --git a/backend/scripts/add_connector_creation_script.py b/backend/scripts/add_connector_creation_script.py
new file mode 100644
index 0000000000..9a1944080c
--- /dev/null
+++ b/backend/scripts/add_connector_creation_script.py
@@ -0,0 +1,148 @@
+from typing import Any
+from typing import Dict
+
+import requests
+
+API_SERVER_URL = "http://localhost:3000"  # Adjust this to your Danswer server URL
+HEADERS = {"Content-Type": "application/json"}
+API_KEY = "danswer-api-key"  # API key here, if auth is enabled
+
+
+def create_connector(
+    name: str,
+    source: str,
+    input_type: str,
+    connector_specific_config: Dict[str, Any],
+    is_public: bool = True,
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    connector_update_request = {
+        "name": name,
+        "source": source,
+        "input_type": input_type,
+        "connector_specific_config": connector_specific_config,
+        "is_public": is_public,
+        "groups": groups or [],
+    }
+
+    response = requests.post(
+        url=f"{API_SERVER_URL}/api/manage/admin/connector",
+        json=connector_update_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def create_credential(
+    name: str,
+    source: str,
+    credential_json: Dict[str, Any],
+    is_public: bool = True,
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    credential_request = {
+        "name": name,
+        "source": source,
+        "credential_json": credential_json,
+        "admin_public": is_public,
+        "groups": groups or [],
+    }
+
+    response = requests.post(
+        url=f"{API_SERVER_URL}/api/manage/credential",
+        json=credential_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def create_cc_pair(
+    connector_id: int,
+    credential_id: int,
+    name: str,
+    access_type: str = "public",
+    groups: list[int] | None = None,
+) -> Dict[str, Any]:
+    cc_pair_request = {
+        "name": name,
+        "access_type": access_type,
+        "groups": groups or [],
+    }
+
+    response = requests.put(
+        url=f"{API_SERVER_URL}/api/manage/connector/{connector_id}/credential/{credential_id}",
+        json=cc_pair_request,
+        headers=HEADERS,
+    )
+    response.raise_for_status()
+    return response.json()
+
+
+def main() -> None:
+    # Create a Web connector
+    web_connector = create_connector(
+        name="Example Web Connector",
+        source="web",
+        input_type="load_state",
+        connector_specific_config={
+            "base_url": "https://example.com",
+            "web_connector_type": "recursive",
+        },
+    )
+    print(f"Created Web Connector: {web_connector}")
+
+    # Create a credential for the Web connector
+    web_credential = create_credential(
+        name="Example Web Credential",
+        source="web",
+        credential_json={},  # Web connectors typically don't need credentials
+        is_public=True,
+    )
+    print(f"Created Web Credential: {web_credential}")
+
+    # Create CC pair for Web connector
+    web_cc_pair = create_cc_pair(
+        connector_id=web_connector["id"],
+        credential_id=web_credential["id"],
+        name="Example Web CC Pair",
+        access_type="public",
+    )
+    print(f"Created Web CC Pair: {web_cc_pair}")
+
+    # Create a GitHub connector
+    github_connector = create_connector(
+        name="Example GitHub Connector",
+        source="github",
+        input_type="poll",
+        connector_specific_config={
+            "repo_owner": "example-owner",
+            "repo_name": "example-repo",
+            "include_prs": True,
+            "include_issues": True,
+        },
+    )
+    print(f"Created GitHub Connector: {github_connector}")
+
+    # Create a credential for the GitHub connector
+    github_credential = create_credential(
+        name="Example GitHub Credential",
+        source="github",
+        credential_json={"github_access_token": "your_github_access_token_here"},
+        is_public=True,
+    )
+    print(f"Created GitHub Credential: {github_credential}")
+
+    # Create CC pair for GitHub connector
+    github_cc_pair = create_cc_pair(
+        connector_id=github_connector["id"],
+        credential_id=github_credential["id"],
+        name="Example GitHub CC Pair",
+        access_type="public",
+    )
+    print(f"Created GitHub CC Pair: {github_cc_pair}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/backend/tests/integration/conftest.py b/backend/tests/integration/conftest.py
index 77d9e0e702..7b4d55cf3d 100644
--- a/backend/tests/integration/conftest.py
+++ b/backend/tests/integration/conftest.py
@@ -6,7 +6,9 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session_context_manager
 from danswer.db.search_settings import get_current_search_settings
+from tests.integration.common_utils.managers.user import UserManager
 from tests.integration.common_utils.reset import reset_all
+from tests.integration.common_utils.test_models import DATestUser
 from tests.integration.common_utils.vespa import vespa_fixture
 
 
@@ -44,3 +46,11 @@ def vespa_client(db_session: Session) -> vespa_fixture:
 @pytest.fixture
 def reset() -> None:
     reset_all()
+
+
+@pytest.fixture
+def new_admin_user() -> DATestUser | None:
+    try:
+        return UserManager.create(name="admin_user")
+    except Exception:
+        return None
diff --git a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
index 0a4e7b40b5..e773414b3a 100644
--- a/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
+++ b/backend/tests/integration/tests/dev_apis/test_simple_chat_api.py
@@ -1,7 +1,10 @@
+import json
+
 import requests
 
 from danswer.configs.constants import MessageType
 from tests.integration.common_utils.constants import API_SERVER_URL
+from tests.integration.common_utils.constants import GENERAL_HEADERS
 from tests.integration.common_utils.constants import NUM_DOCS
 from tests.integration.common_utils.managers.api_key import APIKeyManager
 from tests.integration.common_utils.managers.cc_pair import CCPairManager
@@ -145,3 +148,87 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
     # This ensures the document we think we are referencing when we send the search_doc_ids in the second
     # message is the document that we expect it to be
     assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
+
+
+def test_send_message_simple_with_history_strict_json(
+    reset: None,
+    new_admin_user: DATestUser | None,
+) -> None:
+    # create connectors
+    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
+        user_performing_action=new_admin_user,
+    )
+    api_key: DATestAPIKey = APIKeyManager.create(
+        user_performing_action=new_admin_user,
+    )
+    LLMProviderManager.create(user_performing_action=new_admin_user)
+    cc_pair_1.documents = DocumentManager.seed_dummy_docs(
+        cc_pair=cc_pair_1,
+        num_docs=NUM_DOCS,
+        api_key=api_key,
+    )
+
+    response = requests.post(
+        f"{API_SERVER_URL}/chat/send-message-simple-with-history",
+        json={
+            "messages": [
+                {
+                    "message": "List the names of the first three US presidents in JSON format",
+                    "role": MessageType.USER.value,
+                }
+            ],
+            "persona_id": 0,
+            "prompt_id": 0,
+            "structured_response_format": {
+                "type": "json_object",
+                "schema": {
+                    "type": "object",
+                    "properties": {
+                        "presidents": {
+                            "type": "array",
+                            "items": {"type": "string"},
+                            "description": "List of the first three US presidents",
+                        }
+                    },
+                    "required": ["presidents"],
+                },
+            },
+        },
+        headers=new_admin_user.headers if new_admin_user else GENERAL_HEADERS,
+    )
+    assert response.status_code == 200
+
+    response_json = response.json()
+
+    # Check that the answer is present
+    assert "answer" in response_json
+    assert response_json["answer"] is not None
+
+    # helper
+    def clean_json_string(json_string: str) -> str:
+        return json_string.strip().removeprefix("```json").removesuffix("```").strip()
+
+    # Attempt to parse the answer as JSON
+    try:
+        clean_answer = clean_json_string(response_json["answer"])
+        parsed_answer = json.loads(clean_answer)
+        assert isinstance(parsed_answer, dict)
+        assert "presidents" in parsed_answer
+        assert isinstance(parsed_answer["presidents"], list)
+        assert len(parsed_answer["presidents"]) == 3
+        for president in parsed_answer["presidents"]:
+            assert isinstance(president, str)
+    except json.JSONDecodeError:
+        assert False, "The answer is not a valid JSON object"
+
+    # Check that the answer_citationless is also valid JSON
+    assert "answer_citationless" in response_json
+    assert response_json["answer_citationless"] is not None
+    try:
+        clean_answer_citationless = clean_json_string(
+            response_json["answer_citationless"]
+        )
+        parsed_answer_citationless = json.loads(clean_answer_citationless)
+        assert isinstance(parsed_answer_citationless, dict)
+    except json.JSONDecodeError:
+        assert False, "The answer_citationless is not a valid JSON object"