DAN-23 Stream model output (#30)

Yuhong Sun 2023-05-11 22:49:26 -07:00 committed by GitHub
parent c6a0baed13
commit 20b25e322f
8 changed files with 210 additions and 51 deletions

View File

@@ -1,6 +1,5 @@
 import inspect
 from dataclasses import dataclass
-from typing import Optional
 
 from danswer.connectors.models import Document
@@ -10,9 +9,9 @@ class BaseChunk:
     chunk_id: int
     blurb: str  # The first sentence(s) of the first Section of the chunk
     content: str
-    source_links: Optional[
-        dict[int, str]
-    ]  # Holds the link and the offsets into the raw Chunk text
+    source_links: dict[
+        int, str
+    ] | None  # Holds the link and the offsets into the raw Chunk text
     section_continuation: bool  # True if this Chunk's start is not at the start of a Section
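For orientation (not part of the diff): source_links maps a character offset into the chunk's raw text to the link that applies from that offset onward, and the annotation change above only swaps typing.Optional for the equivalent PEP 604 union syntax, which requires Python 3.10+. A minimal sketch with hypothetical values:

    # Hypothetical source_links payload for one chunk: keys are character
    # offsets into the chunk text, values are the links covering the text
    # that starts at each offset (the URLs here are made up).
    source_links: dict[int, str] | None = {
        0: "https://wiki.example.com/page#intro",
        512: "https://wiki.example.com/page#details",
    }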

View File

@@ -1,5 +1,16 @@
 import os
 
+# Important considerations when choosing models
+# Max tokens count needs to be high considering use case (at least 512)
+# Models used must be MIT or Apache license
+# Inference/Indexing speed
+
+# Bi/Cross-Encoder Model Configs
+# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
+DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
+CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
+
 QUERY_EMBEDDING_CONTEXT_SIZE = 256
 DOC_EMBEDDING_CONTEXT_SIZE = 512
 CROSS_EMBED_CONTEXT_SIZE = 512
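As a sanity check on these values, a sketch (not code from the commit) using only public sentence-transformers APIs:

    from sentence_transformers import CrossEncoder, SentenceTransformer

    # The bi-encoder produces document/query embeddings; its output width
    # must agree with DOC_EMBEDDING_DIM (768 for all-distilroberta-v1).
    embed_model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
    assert embed_model.get_sentence_embedding_dimension() == 768

    # Cap the tokens the bi-encoder sees, mirroring DOC_EMBEDDING_CONTEXT_SIZE.
    embed_model.max_seq_length = 512

    # The cross-encoder scores (query, passage) pairs for reranking;
    # max_length mirrors CROSS_EMBED_CONTEXT_SIZE.
    rerank_model = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512)
    scores = rerank_model.predict(
        [("what does danswer do?", "Danswer answers questions over your docs.")]
    )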

View File

@@ -11,7 +11,7 @@ from danswer.configs.constants import SECTION_CONTINUATION
 from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.configs.constants import SOURCE_TYPE
-from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM
+from danswer.configs.model_configs import DOC_EMBEDDING_DIM
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
 from qdrant_client import QdrantClient

View File

@@ -1,5 +1,5 @@
 import abc
-from typing import *
+from typing import Any
 
 from danswer.chunking.models import InferenceChunk
@@ -7,6 +7,16 @@ from danswer.chunking.models import InferenceChunk
 class QAModel:
     @abc.abstractmethod
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
     ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
         raise NotImplementedError
+
+    @abc.abstractmethod
+    def answer_question_stream(
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+    ) -> Any:
+        raise NotImplementedError
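Concrete backends implement both entry points. The streaming variant is typed as Any here, but the OpenAI completion implementation below returns a generator that yields partial-answer dicts and finally the quotes payload. A minimal conforming stub (the EchoQA class is a hypothetical illustration, not part of the commit):

    from collections.abc import Generator
    from typing import Any

    class EchoQA(QAModel):
        def answer_question(
            self,
            query: str,
            context_docs: list[InferenceChunk],
        ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
            # Non-streaming: return the whole answer and quote metadata at once.
            return f"echo: {query}", None

        def answer_question_stream(
            self,
            query: str,
            context_docs: list[InferenceChunk],
        ) -> Generator[dict[str, Any] | None, None, None]:
            # Streaming: yield one dict per partial-answer token, then the
            # quotes payload (here None) as the final item.
            for token in f"echo: {query}".split():
                yield {"answer data": token}
            yield None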

View File

@@ -2,6 +2,8 @@ import json
 import math
 import re
 from collections.abc import Callable
+from collections.abc import Generator
+from typing import Any
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -37,6 +39,10 @@ logger = setup_logger()
 openai.api_key = OPENAI_API_KEY
 
+def yield_json_line(json_dict):
+    return json.dumps(json_dict) + "\n"
+
+
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -165,6 +171,15 @@ def process_answer(
     return answer, quotes_dict
 
+def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
+    next_token = next_token.replace('\\"', "")
+    if answer_so_far and answer_so_far[-1] != "\\":
+        next_token = next_token[1:]
+    if '"' in next_token:
+        return True
+    return False
+
+
 class OpenAICompletionQA(QAModel):
     def __init__(
         self,
@@ -207,6 +222,57 @@ class OpenAICompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
 
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        filled_prompt = self.prompt_processor(query, top_contents)
+        logger.debug(filled_prompt)
+
+        try:
+            response = openai.Completion.create(
+                prompt=filled_prompt,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+                stream=True,
+            )
+
+            model_output = ""
+            found_answer_start = False
+            found_answer_end = False
+            # iterate through the stream of events
+            for event in response:
+                event_text = event["choices"][0]["text"]
+                model_previous = model_output
+                model_output += event_text
+
+                if not found_answer_start and '{"answer":"' in model_output.replace(
+                    " ", ""
+                ).replace("\n", ""):
+                    found_answer_start = True
+                    continue
+
+                if found_answer_start and not found_answer_end:
+                    if stream_answer_end(model_previous, event_text):
+                        found_answer_end = True
+                        continue
+                    yield {"answer data": event_text}
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        logger.info(answer)
+        yield quotes_dict
 class OpenAIChatCompletionQA(QAModel):
     def __init__(
@@ -257,3 +323,11 @@ class OpenAIChatCompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
+
+    @log_function_time()
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Any:
+        raise NotImplementedError(
+            "Danswer with chat completion does not support streaming yet"
+        )
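The streaming loop above carves the answer out of the prompted JSON format ({"answer": "...", "quotes": ...}): tokens are withheld until the accumulated output contains the opening {"answer":" marker, then raw tokens are yielded until stream_answer_end spots the first unescaped closing quote. The token carrying that quote is swallowed, and the quotes are recovered afterwards by process_answer over the full output. A small offline sketch of the same state machine over canned tokens (the token split is made up):

    def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
        # Copied from the diff above.
        next_token = next_token.replace('\\"', "")
        if answer_so_far and answer_so_far[-1] != "\\":
            next_token = next_token[1:]
        if '"' in next_token:
            return True
        return False

    tokens = ['{"', 'answer', '":"', 'Paris', ' is', ' the', ' capital', '."', ', "quotes": []}']

    model_output = ""
    found_answer_start = found_answer_end = False
    streamed = []
    for event_text in tokens:
        model_previous = model_output
        model_output += event_text
        if not found_answer_start and '{"answer":"' in model_output.replace(
            " ", ""
        ).replace("\n", ""):
            found_answer_start = True  # the marker token itself is not streamed
            continue
        if found_answer_start and not found_answer_end:
            if stream_answer_end(model_previous, event_text):
                found_answer_end = True  # closing-quote token is swallowed
                continue
            streamed.append(event_text)

    print("".join(streamed))  # -> Paris is the capital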

View File

@@ -22,7 +22,9 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
+from danswer.configs.model_configs import CROSS_ENCODER_MODEL
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
@@ -33,16 +35,6 @@ from sentence_transformers import SentenceTransformer  # type: ignore
 logger = setup_logger()
 
-# Important considerations when choosing models
-# Max tokens count needs to be high considering use case (at least 512)
-# Models used must be MIT or Apache license
-# Inference/Indexing speed
-
-# Bi/Cross-Encoder Model Configs
-# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
-DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
-DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
-CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
 
 _EMBED_MODEL: None | SentenceTransformer = None
 _RERANK_MODEL: None | CrossEncoder = None
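The module-level _EMBED_MODEL / _RERANK_MODEL globals support lazy, load-once model initialization, now driven by the imported config constants. A sketch of that accessor pattern (the function name get_embed_model is an assumption, not necessarily what this file defines):

    def get_embed_model() -> SentenceTransformer:
        global _EMBED_MODEL
        if _EMBED_MODEL is None:
            # Load on first use so importing the module stays cheap.
            _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
            _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
        return _EMBED_MODEL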

View File

@@ -1,3 +1,4 @@
+import json
 import time
 from http import HTTPStatus
@@ -11,6 +12,7 @@ from danswer.datastores import create_datastore
 from danswer.db.engine import build_async_engine
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.question_answer import yield_json_line
 from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
@@ -24,6 +26,7 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Request
+from fastapi.responses import StreamingResponse
 from fastapi_users.db import SQLAlchemyUserDatabase
 from sqlalchemy.ext.asyncio import AsyncSession
@@ -78,29 +81,73 @@ async def promote_admin(
     return
 
-@router.post("/direct-qa", response_model=QAResponse)
+@router.get("/direct-qa", response_model=QAResponse)
 def direct_qa(question: QAQuestion):
     start_time = time.time()
-    qa_model = get_default_backend_qa_model()
     query = question.query
     collection = question.collection
     filters = question.filters
-    datastore = create_datastore(collection)
 
     logger.info(f"Received semantic query: {query}")
-    ranked_chunks = retrieve_ranked_documents(query, filters, datastore)
+    ranked_chunks = retrieve_ranked_documents(
+        query, filters, create_datastore(collection)
+    )
     if not ranked_chunks:
         return {"answer": None, "quotes": None}
 
+    qa_model = get_default_backend_qa_model()
     answer, quotes = qa_model.answer_question(query, ranked_chunks)
     logger.info(f"Total QA took {time.time() - start_time} seconds")
     return QAResponse(answer=answer, quotes=quotes)
 
-@router.post("/keyword-search", response_model=KeywordResponse)
+@router.get("/stream-direct-qa")
+def stream_direct_qa(question: QAQuestion):
+    top_documents_key = "top_documents"
+    answer_key = "answer"
+    quotes_key = "quotes"
+
+    def stream_qa_portions():
+        query = question.query
+        collection = question.collection
+        filters = question.filters
+        logger.info(f"Received semantic query: {query}")
+        ranked_chunks = retrieve_ranked_documents(
+            query, filters, create_datastore(collection)
+        )
+        if not ranked_chunks:
+            return yield_json_line(
+                {top_documents_key: None, answer_key: None, quotes_key: None}
+            )
+
+        linked_chunks = [
+            chunk
+            for chunk in ranked_chunks
+            if chunk.source_links and "0" in chunk.source_links
+        ]
+        top_docs = {
+            top_documents_key: {
+                "document section links": [
+                    chunk.source_links["0"] for chunk in linked_chunks
+                ],
+                "blurbs": [chunk.blurb for chunk in linked_chunks],
+            }
+        }
+        yield yield_json_line(top_docs)
+
+        qa_model = get_default_backend_qa_model()
+        for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+            logger.debug(response_dict)
+            yield yield_json_line(response_dict)
+
+    return StreamingResponse(stream_qa_portions(), media_type="application/json")
+
+
+@router.get("/keyword-search", response_model=KeywordResponse)
 def keyword_search(question: QAQuestion):
     ts_client = TSClient.get_instance()
     query = question.query
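The new /stream-direct-qa endpoint emits newline-delimited JSON: one top_documents object first, then one {"answer data": token} object per streamed token, then the quotes dict from the QA model. (The endpoint reads chunk.source_links with the string key "0", presumably because the integer offsets become string keys once chunks round-trip through JSON storage.) A minimal client sketch; host, port, collection name, and the empty filter list below are assumptions for illustration:

    import json
    import requests

    with requests.get(
        "http://127.0.0.1:8080/stream-direct-qa",  # host/port assumed (the repo's APP_PORT)
        json={"query": "What is Danswer?", "collection": "danswer_index", "filters": []},
        stream=True,
    ) as response:
        for line in response.iter_lines():
            if not line:
                continue
            piece = json.loads(line.decode())
            if "answer data" in piece:
                # Partial answer tokens arrive one JSON line at a time.
                print(piece["answer data"], end="", flush=True)
            else:
                # First line: top documents; last line: quotes dict.
                print(piece)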

View File

@@ -22,19 +22,27 @@ if __name__ == "__main__":
     )
     parser.add_argument(
-        "-s",
+        "-t",
         "--source-types",
         type=str,
         help="Comma separated list of source types to filter by",
     )
+    parser.add_argument(
+        "-s",
+        "--stream",
+        action="store_true",
+        help="Enable streaming response",
+    )
     parser.add_argument("query", nargs="*", help="The query to process")
 
     while True:
         try:
             user_input = input(
                 "\n\nAsk any question:\n"
-                " - prefix with -s to add a filter by source(s)\n"
+                " - prefix with -t to add a filter by source type(s)\n"
+                " - prefix with -s to stream answer\n"
                 " - input an empty string to rerun last query\n\t"
             )
@@ -55,40 +63,58 @@ if __name__ == "__main__":
                 source_types = source_types[0]
 
             query = " ".join(args.query)
-            endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
+            endpoint = (
+                f"http://127.0.0.1:{APP_PORT}/direct-qa"
+                if not args.stream
+                else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
+            )
             if args.keyword_search:
                 endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
+                raise NotImplementedError("keyword search is not supported for now")
 
             query_json = {
                 "query": query,
                 "collection": QDRANT_DEFAULT_COLLECTION,
                 "filters": [{SOURCE_TYPE: source_types}],
             }
-            response = requests.post(endpoint, json=query_json)
-            contents = json.loads(response.content)
-            if keyword_search:
-                if contents["results"]:
-                    for link in contents["results"]:
-                        print(link)
-                else:
-                    print("No matches found")
-            else:
-                answer = contents.get("answer")
-                if answer:
-                    print("Answer: " + answer)
-                else:
-                    print("Answer: ?")
-                if contents.get("quotes"):
-                    for ind, (quote, quote_info) in enumerate(
-                        contents["quotes"].items()
-                    ):
-                        print(f"Quote {str(ind + 1)}:\n{quote}")
-                        print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
-                        print(f"Blurb: {quote_info[BLURB]}")
-                        print(f"Link: {quote_info[SOURCE_LINK]}")
-                        print(f"Source: {quote_info[SOURCE_TYPE]}")
-                else:
-                    print("No quotes found")
+            if not args.stream:
+                response = requests.get(endpoint, json=query_json)
+                contents = json.loads(response.content)
+                if keyword_search:
+                    if contents["results"]:
+                        for link in contents["results"]:
+                            print(link)
+                    else:
+                        print("No matches found")
+                else:
+                    answer = contents.get("answer")
+                    if answer:
+                        print("Answer: " + answer)
+                    else:
+                        print("Answer: ?")
+
+                    if contents.get("quotes"):
+                        for ind, (quote, quote_info) in enumerate(
+                            contents["quotes"].items()
+                        ):
+                            print(f"Quote {str(ind + 1)}:\n{quote}")
+                            print(
+                                f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
+                            )
+                            print(f"Blurb: {quote_info[BLURB]}")
+                            print(f"Link: {quote_info[SOURCE_LINK]}")
+                            print(f"Source: {quote_info[SOURCE_TYPE]}")
+                    else:
+                        print("No quotes found")
+            else:
+                answer = ""
+                with requests.get(endpoint, json=query_json, stream=True) as r:
+                    for json_response in r.iter_lines():
+                        response_dict = json.loads(json_response.decode())
+                        if "answer data" not in response_dict:
+                            print(response_dict)
+                        else:
+                            answer += response_dict["answer data"]
+                print(answer)
         except Exception as e:
             print(f"Failed due to {e}, retrying")