From 20b25e322f67fb608c7fb142b8a5a6c7a428548d Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 11 May 2023 22:49:26 -0700 Subject: [PATCH] DAN-23 Stream model output (#30) --- backend/danswer/chunking/models.py | 7 +- backend/danswer/configs/model_configs.py | 11 +++ backend/danswer/datastores/qdrant/indexing.py | 2 +- backend/danswer/direct_qa/interfaces.py | 14 +++- backend/danswer/direct_qa/question_answer.py | 74 +++++++++++++++++ .../semantic_search/semantic_search.py | 12 +-- backend/danswer/server/search_backend.py | 61 ++++++++++++-- backend/scripts/simulate_frontend.py | 80 ++++++++++++------- 8 files changed, 210 insertions(+), 51 deletions(-) diff --git a/backend/danswer/chunking/models.py b/backend/danswer/chunking/models.py index 9f2b39154..4bfc1e685 100644 --- a/backend/danswer/chunking/models.py +++ b/backend/danswer/chunking/models.py @@ -1,6 +1,5 @@ import inspect from dataclasses import dataclass -from typing import Optional from danswer.connectors.models import Document @@ -10,9 +9,9 @@ class BaseChunk: chunk_id: int blurb: str # The first sentence(s) of the first Section of the chunk content: str - source_links: Optional[ - dict[int, str] - ] # Holds the link and the offsets into the raw Chunk text + source_links: dict[ + int, str + ] | None # Holds the link and the offsets into the raw Chunk text section_continuation: bool # True if this Chunk's start is not at the start of a Section diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index efee7c510..b2e0169dd 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -1,5 +1,16 @@ import os +# Important considerations when choosing models +# Max tokens count needs to be high considering use case (at least 512) +# Models used must be MIT or Apache license +# Inference/Indexing speed + +# Bi/Cross-Encoder Model Configs +# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding) +DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1" +CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" +DOC_EMBEDDING_DIM = 768 # Depends on the document encoder model + QUERY_EMBEDDING_CONTEXT_SIZE = 256 DOC_EMBEDDING_CONTEXT_SIZE = 512 CROSS_EMBED_CONTEXT_SIZE = 512 diff --git a/backend/danswer/datastores/qdrant/indexing.py b/backend/danswer/datastores/qdrant/indexing.py index 34a07039f..0bb0f474e 100644 --- a/backend/danswer/datastores/qdrant/indexing.py +++ b/backend/danswer/datastores/qdrant/indexing.py @@ -11,7 +11,7 @@ from danswer.configs.constants import SECTION_CONTINUATION from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE -from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM +from danswer.configs.model_configs import DOC_EMBEDDING_DIM from danswer.utils.clients import get_qdrant_client from danswer.utils.logging import setup_logger from qdrant_client import QdrantClient diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index d1d9b939b..6622a55d3 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -1,5 +1,5 @@ import abc -from typing import * +from typing import Any from danswer.chunking.models import InferenceChunk @@ -7,6 +7,16 @@ from danswer.chunking.models import InferenceChunk class QAModel: @abc.abstractmethod def answer_question( - self, query: str, context_docs: list[InferenceChunk] + self, + query: str, + context_docs: list[InferenceChunk], ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]: raise NotImplementedError + + @abc.abstractmethod + def answer_question_stream( + self, + query: str, + context_docs: list[InferenceChunk], + ) -> Any: + raise NotImplementedError diff --git a/backend/danswer/direct_qa/question_answer.py b/backend/danswer/direct_qa/question_answer.py index 5e9ea64bf..65ded226c 100644 --- a/backend/danswer/direct_qa/question_answer.py +++ b/backend/danswer/direct_qa/question_answer.py @@ -2,6 +2,8 @@ import json import math import re from collections.abc import Callable +from collections.abc import Generator +from typing import Any from typing import Dict from typing import Optional from typing import Tuple @@ -37,6 +39,10 @@ logger = setup_logger() openai.api_key = OPENAI_API_KEY +def yield_json_line(json_dict): + return json.dumps(json_dict) + "\n" + + def extract_answer_quotes_freeform( answer_raw: str, ) -> Tuple[Optional[str], Optional[list[str]]]: @@ -165,6 +171,15 @@ def process_answer( return answer, quotes_dict +def stream_answer_end(answer_so_far: str, next_token: str) -> bool: + next_token = next_token.replace('\\"', "") + if answer_so_far and answer_so_far[-1] != "\\": + next_token = next_token[1:] + if '"' in next_token: + return True + return False + + class OpenAICompletionQA(QAModel): def __init__( self, @@ -207,6 +222,57 @@ class OpenAICompletionQA(QAModel): answer, quotes_dict = process_answer(model_output, context_docs) return answer, quotes_dict + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + top_contents = [ranked_chunk.content for ranked_chunk in context_docs] + filled_prompt = self.prompt_processor(query, top_contents) + logger.debug(filled_prompt) + + try: + response = openai.Completion.create( + prompt=filled_prompt, + temperature=0, + top_p=1, + frequency_penalty=0, + presence_penalty=0, + model=self.model_version, + max_tokens=self.max_output_tokens, + stream=True, + ) + + model_output = "" + found_answer_start = False + found_answer_end = False + # iterate through the stream of events + for event in response: + event_text = event["choices"][0]["text"] + model_previous = model_output + model_output += event_text + + if not found_answer_start and '{"answer":"' in model_output.replace( + " ", "" + ).replace("\n", ""): + found_answer_start = True + continue + + if found_answer_start and not found_answer_end: + if stream_answer_end(model_previous, event_text): + found_answer_end = True + continue + yield {"answer data": event_text} + + except Exception as e: + logger.exception(e) + model_output = "Model Failure" + + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + logger.info(answer) + + yield quotes_dict + class OpenAIChatCompletionQA(QAModel): def __init__( @@ -257,3 +323,11 @@ class OpenAIChatCompletionQA(QAModel): answer, quotes_dict = process_answer(model_output, context_docs) return answer, quotes_dict + + @log_function_time() + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Any: + raise NotImplementedError( + "Danswer with chat completion does not support streaming yet" + ) diff --git a/backend/danswer/semantic_search/semantic_search.py b/backend/danswer/semantic_search/semantic_search.py index 134986b3e..daaf0ee0e 100644 --- a/backend/danswer/semantic_search/semantic_search.py +++ b/backend/danswer/semantic_search/semantic_search.py @@ -22,7 +22,9 @@ from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE +from danswer.configs.model_configs import CROSS_ENCODER_MODEL from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.datastores.interfaces import Datastore from danswer.datastores.interfaces import DatastoreFilter from danswer.utils.logging import setup_logger @@ -33,16 +35,6 @@ from sentence_transformers import SentenceTransformer # type: ignore logger = setup_logger() -# Important considerations when choosing models -# Max tokens count needs to be high considering use case (at least 512) -# Models used must be MIT or Apache license -# Inference/Indexing speed - -# Bi/Cross-Encoder Model Configs -# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding) -DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1" -DOC_EMBEDDING_DIM = 768 # Depends on the document encoder model -CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2" _EMBED_MODEL: None | SentenceTransformer = None _RERANK_MODEL: None | CrossEncoder = None diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 0eab6b5ae..ae61749c3 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,3 +1,4 @@ +import json import time from http import HTTPStatus @@ -11,6 +12,7 @@ from danswer.datastores import create_datastore from danswer.db.engine import build_async_engine from danswer.db.models import User from danswer.direct_qa import get_default_backend_qa_model +from danswer.direct_qa.question_answer import yield_json_line from danswer.semantic_search.semantic_search import retrieve_ranked_documents from danswer.server.models import KeywordResponse from danswer.server.models import QAQuestion @@ -24,6 +26,7 @@ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Request +from fastapi.responses import StreamingResponse from fastapi_users.db import SQLAlchemyUserDatabase from sqlalchemy.ext.asyncio import AsyncSession @@ -78,29 +81,73 @@ async def promote_admin( return -@router.post("/direct-qa", response_model=QAResponse) +@router.get("/direct-qa", response_model=QAResponse) def direct_qa(question: QAQuestion): start_time = time.time() - qa_model = get_default_backend_qa_model() + query = question.query collection = question.collection filters = question.filters - - datastore = create_datastore(collection) - logger.info(f"Received semantic query: {query}") - ranked_chunks = retrieve_ranked_documents(query, filters, datastore) + ranked_chunks = retrieve_ranked_documents( + query, filters, create_datastore(collection) + ) if not ranked_chunks: return {"answer": None, "quotes": None} + qa_model = get_default_backend_qa_model() answer, quotes = qa_model.answer_question(query, ranked_chunks) + logger.info(f"Total QA took {time.time() - start_time} seconds") return QAResponse(answer=answer, quotes=quotes) -@router.post("/keyword-search", response_model=KeywordResponse) +@router.get("/stream-direct-qa") +def stream_direct_qa(question: QAQuestion): + top_documents_key = "top_documents" + answer_key = "answer" + quotes_key = "quotes" + + def stream_qa_portions(): + query = question.query + collection = question.collection + filters = question.filters + logger.info(f"Received semantic query: {query}") + + ranked_chunks = retrieve_ranked_documents( + query, filters, create_datastore(collection) + ) + if not ranked_chunks: + return yield_json_line( + {top_documents_key: None, answer_key: None, quotes_key: None} + ) + + linked_chunks = [ + chunk + for chunk in ranked_chunks + if chunk.source_links and "0" in chunk.source_links + ] + top_docs = { + top_documents_key: { + "document section links": [ + chunk.source_links["0"] for chunk in linked_chunks + ], + "blurbs": [chunk.blurb for chunk in linked_chunks], + } + } + yield yield_json_line(top_docs) + + qa_model = get_default_backend_qa_model() + for response_dict in qa_model.answer_question_stream(query, ranked_chunks): + logger.debug(response_dict) + yield yield_json_line(response_dict) + + return StreamingResponse(stream_qa_portions(), media_type="application/json") + + +@router.get("/keyword-search", response_model=KeywordResponse) def keyword_search(question: QAQuestion): ts_client = TSClient.get_instance() query = question.query diff --git a/backend/scripts/simulate_frontend.py b/backend/scripts/simulate_frontend.py index efd07b721..1afd503a6 100644 --- a/backend/scripts/simulate_frontend.py +++ b/backend/scripts/simulate_frontend.py @@ -22,19 +22,27 @@ if __name__ == "__main__": ) parser.add_argument( - "-s", + "-t", "--source-types", type=str, help="Comma separated list of source types to filter by", ) + parser.add_argument( + "-s", + "--stream", + action="store_true", + help="Enable streaming response", + ) + parser.add_argument("query", nargs="*", help="The query to process") while True: try: user_input = input( "\n\nAsk any question:\n" - " - prefix with -s to add a filter by source(s)\n" + " - prefix with -t to add a filter by source type(s)\n" + " - prefix with -s to stream answer\n" " - input an empty string to rerun last query\n\t" ) @@ -55,40 +63,58 @@ if __name__ == "__main__": source_types = source_types[0] query = " ".join(args.query) - endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa" + endpoint = ( + f"http://127.0.0.1:{APP_PORT}/direct-qa" + if not args.stream + else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa" + ) if args.keyword_search: endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search" + raise NotImplementedError("keyword search is not supported for now") query_json = { "query": query, "collection": QDRANT_DEFAULT_COLLECTION, "filters": [{SOURCE_TYPE: source_types}], } - - response = requests.post(endpoint, json=query_json) - contents = json.loads(response.content) - if keyword_search: - if contents["results"]: - for link in contents["results"]: - print(link) + if not args.stream: + response = requests.get(endpoint, json=query_json) + contents = json.loads(response.content) + if keyword_search: + if contents["results"]: + for link in contents["results"]: + print(link) + else: + print("No matches found") else: - print("No matches found") + answer = contents.get("answer") + if answer: + print("Answer: " + answer) + else: + print("Answer: ?") + if contents.get("quotes"): + for ind, (quote, quote_info) in enumerate( + contents["quotes"].items() + ): + print(f"Quote {str(ind + 1)}:\n{quote}") + print( + f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}" + ) + print(f"Blurb: {quote_info[BLURB]}") + print(f"Link: {quote_info[SOURCE_LINK]}") + print(f"Source: {quote_info[SOURCE_TYPE]}") + else: + print("No quotes found") else: - answer = contents.get("answer") - if answer: - print("Answer: " + answer) - else: - print("Answer: ?") - if contents.get("quotes"): - for ind, (quote, quote_info) in enumerate( - contents["quotes"].items() - ): - print(f"Quote {str(ind + 1)}:\n{quote}") - print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}") - print(f"Blurb: {quote_info[BLURB]}") - print(f"Link: {quote_info[SOURCE_LINK]}") - print(f"Source: {quote_info[SOURCE_TYPE]}") - else: - print("No quotes found") + answer = "" + with requests.get(endpoint, json=query_json, stream=True) as r: + for json_response in r.iter_lines(): + response_dict = json.loads(json_response.decode()) + if "answer data" not in response_dict: + print(response_dict) + else: + answer += response_dict["answer data"] + print(answer) + except Exception as e: print(f"Failed due to {e}, retrying")