DAN-23 Stream model output (#30)

Yuhong Sun 2023-05-11 22:49:26 -07:00 committed by GitHub
parent c6a0baed13
commit 20b25e322f
8 changed files with 210 additions and 51 deletions

View File

@ -1,6 +1,5 @@
import inspect
from dataclasses import dataclass
from typing import Optional
from danswer.connectors.models import Document
@ -10,9 +9,9 @@ class BaseChunk:
chunk_id: int
blurb: str # The first sentence(s) of the first Section of the chunk
content: str
source_links: Optional[
dict[int, str]
] # Holds the link and the offsets into the raw Chunk text
source_links: dict[
int, str
] | None # Holds the link and the offsets into the raw Chunk text
section_continuation: bool # True if this Chunk's start is not at the start of a Section

View File

@ -1,5 +1,16 @@
import os
# Important considerations when choosing models
# Max token count needs to be high enough for the use case (at least 512)
# Models used must be MIT or Apache licensed
# Inference/Indexing speed
# Bi/Cross-Encoder Model Configs
# Use 'multi-qa-MiniLM-L6-cos-v1' if an acceptable license is added, as it is 3x faster (384-dimensional embedding)
DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DOC_EMBEDDING_DIM = 768 # Depends on the document encoder model
QUERY_EMBEDDING_CONTEXT_SIZE = 256
DOC_EMBEDDING_CONTEXT_SIZE = 512
CROSS_EMBED_CONTEXT_SIZE = 512
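For context, a hedged sketch of how these constants plug into sentence-transformers (the loading code below is illustrative, not this commit's):

from sentence_transformers import CrossEncoder, SentenceTransformer

embed_model = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
embed_model.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE  # cap document inputs

cross_encoder = CrossEncoder(CROSS_ENCODER_MODEL, max_length=CROSS_EMBED_CONTEXT_SIZE)

# DOC_EMBEDDING_DIM must agree with the loaded bi-encoder
assert embed_model.get_sentence_embedding_dimension() == DOC_EMBEDDING_DIM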

View File

@ -11,7 +11,7 @@ from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logging import setup_logger
from qdrant_client import QdrantClient

View File

@ -1,5 +1,5 @@
import abc
from typing import *
from typing import Any
from danswer.chunking.models import InferenceChunk
@ -7,6 +7,16 @@ from danswer.chunking.models import InferenceChunk
class QAModel:
@abc.abstractmethod
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
raise NotImplementedError
@abc.abstractmethod
def answer_question_stream(
self,
query: str,
context_docs: list[InferenceChunk],
) -> Any:
raise NotImplementedError
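For illustration, a minimal hypothetical implementation of this contract (not code from the repo), showing what each method is expected to produce:

from collections.abc import Generator


class EchoQA(QAModel):
    """Toy model that echoes the query; for illustrating the interface only."""

    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        return f"You asked: {query}", None

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        # Stream the answer token by token, then end with a quotes payload
        for token in query.split():
            yield {"answer data": token + " "}
        yield None  # no quotes in this toy example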

View File

@ -2,6 +2,8 @@ import json
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Dict
from typing import Optional
from typing import Tuple
@ -37,6 +39,10 @@ logger = setup_logger()
openai.api_key = OPENAI_API_KEY
def yield_json_line(json_dict):
return json.dumps(json_dict) + "\n"
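Each dict is serialized as one newline-terminated JSON line (NDJSON-style), so streaming clients can parse the response line by line:

assert yield_json_line({"answer data": "Hel"}) == '{"answer data": "Hel"}\n'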
def extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
@ -165,6 +171,15 @@ def process_answer(
return answer, quotes_dict
def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
    # Drop escaped quotes; they are part of the answer text, not its end
    next_token = next_token.replace('\\"', "")
    # If the answer so far ends in a backslash, the first character of
    # next_token is escaped and cannot be the closing quote
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False
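A few illustrative calls (inputs invented for the example):

# An unescaped quote in the new token closes the JSON answer string
assert stream_answer_end('{"answer":"Hello', ' world"') is True

# Escaped quotes inside the token do not end the answer
assert stream_answer_end('{"answer":"He said ', '\\"hi\\"') is False

# A quote that completes a backslash from the previous token is skipped
assert stream_answer_end('{"answer":"trailing \\', '" more') is False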
class OpenAICompletionQA(QAModel):
def __init__(
self,
@ -207,6 +222,57 @@ class OpenAICompletionQA(QAModel):
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
filled_prompt = self.prompt_processor(query, top_contents)
logger.debug(filled_prompt)
try:
response = openai.Completion.create(
prompt=filled_prompt,
temperature=0,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
model=self.model_version,
max_tokens=self.max_output_tokens,
stream=True,
)
model_output = ""
found_answer_start = False
found_answer_end = False
# iterate through the stream of events
for event in response:
event_text = event["choices"][0]["text"]
model_previous = model_output
model_output += event_text
if not found_answer_start and '{"answer":"' in model_output.replace(
" ", ""
).replace("\n", ""):
found_answer_start = True
continue
if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text):
found_answer_end = True
continue
yield {"answer data": event_text}
except Exception as e:
logger.exception(e)
model_output = "Model Failure"
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
logger.info(answer)
yield quotes_dict
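Net effect: the generator yields one {"answer data": token} dict per streamed answer token, then finishes with the quotes dict (possibly None). A hedged consumer sketch (variable names are illustrative):

streamed_answer = ""
quotes = None
for portion in qa_model.answer_question_stream(query, ranked_chunks):
    if portion is not None and "answer data" in portion:
        streamed_answer += portion["answer data"]
    else:
        quotes = portion  # the final yield carries the quotes dict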
class OpenAIChatCompletionQA(QAModel):
def __init__(
@ -257,3 +323,11 @@ class OpenAIChatCompletionQA(QAModel):
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
@log_function_time()
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Any:
raise NotImplementedError(
"Danswer with chat completion does not support streaming yet"
)
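For reference, OpenAI's chat completion API does support streaming; a hedged sketch of what a future implementation might look like (not this commit's code; assumes the openai-python 0.x chunk format, where streamed deltas carry an optional "content" field):

response = openai.ChatCompletion.create(
    model=self.model_version,
    messages=[{"role": "user", "content": filled_prompt}],
    max_tokens=self.max_output_tokens,
    stream=True,
)
for chunk in response:
    delta = chunk["choices"][0]["delta"]
    if "content" in delta:
        yield {"answer data": delta["content"]}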

View File

@ -22,7 +22,9 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.datastores.interfaces import Datastore
from danswer.datastores.interfaces import DatastoreFilter
from danswer.utils.logging import setup_logger
@ -33,16 +35,6 @@ from sentence_transformers import SentenceTransformer # type: ignore
logger = setup_logger()
# Important considerations when choosing models
# Max token count needs to be high enough for the use case (at least 512)
# Models used must be MIT or Apache licensed
# Inference/Indexing speed
# Bi/Cross-Encoder Model Configs
# Use 'multi-qa-MiniLM-L6-cos-v1' if an acceptable license is added, as it is 3x faster (384-dimensional embedding)
DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
DOC_EMBEDDING_DIM = 768 # Depends on the document encoder model
CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
_EMBED_MODEL: None | SentenceTransformer = None
_RERANK_MODEL: None | CrossEncoder = None
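These two module-level globals support lazy singletons: each model is loaded on first use rather than at import time. A minimal sketch of the accessor pattern (function names are assumptions, not necessarily the repo's):

def get_default_embedding_model() -> SentenceTransformer:
    global _EMBED_MODEL
    if _EMBED_MODEL is None:  # load once, reuse afterwards
        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
    return _EMBED_MODEL


def get_default_reranking_model() -> CrossEncoder:
    global _RERANK_MODEL
    if _RERANK_MODEL is None:
        _RERANK_MODEL = CrossEncoder(
            CROSS_ENCODER_MODEL, max_length=CROSS_EMBED_CONTEXT_SIZE
        )
    return _RERANK_MODEL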

View File

@ -1,3 +1,4 @@
import json
import time
from http import HTTPStatus
@ -11,6 +12,7 @@ from danswer.datastores import create_datastore
from danswer.db.engine import build_async_engine
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.question_answer import yield_json_line
from danswer.semantic_search.semantic_search import retrieve_ranked_documents
from danswer.server.models import KeywordResponse
from danswer.server.models import QAQuestion
@ -24,6 +26,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi.responses import StreamingResponse
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.ext.asyncio import AsyncSession
@ -78,29 +81,73 @@ async def promote_admin(
return
@router.post("/direct-qa", response_model=QAResponse)
@router.get("/direct-qa", response_model=QAResponse)
def direct_qa(question: QAQuestion):
start_time = time.time()
qa_model = get_default_backend_qa_model()
query = question.query
collection = question.collection
filters = question.filters
datastore = create_datastore(collection)
logger.info(f"Received semantic query: {query}")
ranked_chunks = retrieve_ranked_documents(query, filters, datastore)
ranked_chunks = retrieve_ranked_documents(
query, filters, create_datastore(collection)
)
if not ranked_chunks:
return {"answer": None, "quotes": None}
qa_model = get_default_backend_qa_model()
answer, quotes = qa_model.answer_question(query, ranked_chunks)
logger.info(f"Total QA took {time.time() - start_time} seconds")
return QAResponse(answer=answer, quotes=quotes)
@router.post("/keyword-search", response_model=KeywordResponse)
@router.get("/stream-direct-qa")
def stream_direct_qa(question: QAQuestion):
top_documents_key = "top_documents"
answer_key = "answer"
quotes_key = "quotes"
def stream_qa_portions():
query = question.query
collection = question.collection
filters = question.filters
logger.info(f"Received semantic query: {query}")
ranked_chunks = retrieve_ranked_documents(
query, filters, create_datastore(collection)
)
if not ranked_chunks:
    # A bare `return` ends the generator; the empty payload must be yielded
    yield yield_json_line(
        {top_documents_key: None, answer_key: None, quotes_key: None}
    )
    return
linked_chunks = [
chunk
for chunk in ranked_chunks
if chunk.source_links and "0" in chunk.source_links
]
top_docs = {
top_documents_key: {
"document section links": [
chunk.source_links["0"] for chunk in linked_chunks
],
"blurbs": [chunk.blurb for chunk in linked_chunks],
}
}
yield yield_json_line(top_docs)
qa_model = get_default_backend_qa_model()
for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
logger.debug(response_dict)
yield yield_json_line(response_dict)
return StreamingResponse(stream_qa_portions(), media_type="application/json")
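The stream therefore arrives in three phases: one top-documents line, a run of "answer data" lines, then a final quotes line. A hedged client sketch (port and payload values are placeholders):

import json

import requests

with requests.get(
    "http://127.0.0.1:8080/stream-direct-qa",  # placeholder port
    json={"query": "What is Danswer?", "collection": "my_collection", "filters": []},
    stream=True,
) as resp:
    for raw_line in resp.iter_lines():
        portion = json.loads(raw_line.decode())
        if "answer data" in portion:
            print(portion["answer data"], end="", flush=True)
        else:
            print(portion)  # top documents first, quotes last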
@router.get("/keyword-search", response_model=KeywordResponse)
def keyword_search(question: QAQuestion):
ts_client = TSClient.get_instance()
query = question.query

View File

@ -22,19 +22,27 @@ if __name__ == "__main__":
)
parser.add_argument(
"-s",
"-t",
"--source-types",
type=str,
help="Comma separated list of source types to filter by",
)
parser.add_argument(
"-s",
"--stream",
action="store_true",
help="Enable streaming response",
)
parser.add_argument("query", nargs="*", help="The query to process")
while True:
try:
user_input = input(
"\n\nAsk any question:\n"
" - prefix with -s to add a filter by source(s)\n"
" - prefix with -t to add a filter by source type(s)\n"
" - prefix with -s to stream answer\n"
" - input an empty string to rerun last query\n\t"
)
@ -55,40 +63,58 @@ if __name__ == "__main__":
source_types = source_types[0]
query = " ".join(args.query)
endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
endpoint = (
f"http://127.0.0.1:{APP_PORT}/direct-qa"
if not args.stream
else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
)
if args.keyword_search:
endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
raise NotImplementedError("keyword search is not supported for now")
query_json = {
"query": query,
"collection": QDRANT_DEFAULT_COLLECTION,
"filters": [{SOURCE_TYPE: source_types}],
}
response = requests.post(endpoint, json=query_json)
contents = json.loads(response.content)
if keyword_search:
if contents["results"]:
for link in contents["results"]:
print(link)
if not args.stream:
response = requests.get(endpoint, json=query_json)
contents = json.loads(response.content)
if keyword_search:
if contents["results"]:
for link in contents["results"]:
print(link)
else:
print("No matches found")
else:
print("No matches found")
answer = contents.get("answer")
if answer:
print("Answer: " + answer)
else:
print("Answer: ?")
if contents.get("quotes"):
for ind, (quote, quote_info) in enumerate(
contents["quotes"].items()
):
print(f"Quote {str(ind + 1)}:\n{quote}")
print(
f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
)
print(f"Blurb: {quote_info[BLURB]}")
print(f"Link: {quote_info[SOURCE_LINK]}")
print(f"Source: {quote_info[SOURCE_TYPE]}")
else:
print("No quotes found")
else:
answer = contents.get("answer")
if answer:
print("Answer: " + answer)
else:
print("Answer: ?")
if contents.get("quotes"):
for ind, (quote, quote_info) in enumerate(
contents["quotes"].items()
):
print(f"Quote {str(ind + 1)}:\n{quote}")
print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
print(f"Blurb: {quote_info[BLURB]}")
print(f"Link: {quote_info[SOURCE_LINK]}")
print(f"Source: {quote_info[SOURCE_TYPE]}")
else:
print("No quotes found")
answer = ""
with requests.get(endpoint, json=query_json, stream=True) as r:
for json_response in r.iter_lines():
response_dict = json.loads(json_response.decode())
if "answer data" not in response_dict:
print(response_dict)
else:
answer += response_dict["answer data"]
print(answer)
except Exception as e:
print(f"Failed due to {e}, retrying")