DAN-23 Stream model output (#30)
This commit is contained in: commit 20b25e322f, parent c6a0baed13
@@ -1,6 +1,5 @@
-import inspect
 from dataclasses import dataclass
 from typing import Optional
 
 from danswer.connectors.models import Document
 
@@ -10,9 +9,9 @@ class BaseChunk:
     chunk_id: int
     blurb: str  # The first sentence(s) of the first Section of the chunk
     content: str
-    source_links: Optional[
-        dict[int, str]
-    ]  # Holds the link and the offsets into the raw Chunk text
+    source_links: dict[
+        int, str
+    ] | None  # Holds the link and the offsets into the raw Chunk text
     section_continuation: bool  # True if this Chunk's start is not at the start of a Section
 
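The source_links change above is purely notational: Optional[dict[int, str]] and dict[int, str] | None describe the same type, with the PEP 604 union spelling requiring Python 3.10+ (or deferred annotation evaluation). A minimal sketch of how the field is meant to be read, with invented values (keys are character offsets into the chunk text, values are the source links):

from dataclasses import dataclass

@dataclass
class ChunkSketch:
    content: str
    source_links: dict[int, str] | None  # offset into content -> source link

chunk = ChunkSketch(
    content="First section text. Second section text.",
    source_links={0: "https://example.com/doc#s1", 20: "https://example.com/doc#s2"},
)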
@@ -1,5 +1,16 @@
 import os
 
+# Important considerations when choosing models
+# Max tokens count needs to be high considering use case (at least 512)
+# Models used must be MIT or Apache license
+# Inference/Indexing speed
+
+# Bi/Cross-Encoder Model Configs
+# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
+DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
+CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
+
 QUERY_EMBEDDING_CONTEXT_SIZE = 256
 DOC_EMBEDDING_CONTEXT_SIZE = 512
 CROSS_EMBED_CONTEXT_SIZE = 512
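DOC_EMBEDDING_DIM is not a free parameter: it must equal the output dimension of DOCUMENT_ENCODER_MODEL, since the vector store collection is created with this dimension (see the DOC_EMBEDDING_DIM import next to the Qdrant client below). A quick sanity check, assuming the sentence-transformers package is available (the assert is illustrative, not part of the commit):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
# all-distilroberta-v1 embeds into 768 dimensions, matching DOC_EMBEDDING_DIM
assert model.get_sentence_embedding_dimension() == 768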
@@ -11,7 +11,7 @@ from danswer.configs.constants import SECTION_CONTINUATION
 from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.configs.constants import SOURCE_TYPE
-from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM
+from danswer.configs.model_configs import DOC_EMBEDDING_DIM
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
 from qdrant_client import QdrantClient
@@ -1,5 +1,5 @@
 import abc
-from typing import *
+from typing import Any
 
 from danswer.chunking.models import InferenceChunk
 
@@ -7,6 +7,16 @@ from danswer.chunking.models import InferenceChunk
 class QAModel:
     @abc.abstractmethod
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
     ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
         raise NotImplementedError
 
+    @abc.abstractmethod
+    def answer_question_stream(
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+    ) -> Any:
+        raise NotImplementedError
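Every QA backend now has to provide both entry points: the blocking answer_question and the new answer_question_stream. A minimal conforming subclass, sketched only to show the contract (the echo behavior is invented; real backends yield {"answer data": ...} token dicts and finish with a quotes payload, as the OpenAI implementation below does):

from collections.abc import Generator
from typing import Any

from danswer.chunking.models import InferenceChunk


class EchoQA(QAModel):
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict | None]:
        return f"You asked: {query}", None

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        # Emit the answer word by word, then a final None in place of quotes
        for word in f"You asked: {query}".split():
            yield {"answer data": word + " "}
        yield None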
@@ -2,6 +2,8 @@ import json
 import math
 import re
 from collections.abc import Callable
+from collections.abc import Generator
+from typing import Any
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -37,6 +39,10 @@ logger = setup_logger()
 openai.api_key = OPENAI_API_KEY
 
 
+def yield_json_line(json_dict):
+    return json.dumps(json_dict) + "\n"
+
+
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
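yield_json_line frames each streamed payload as one newline-delimited JSON (NDJSON) record, so the client can recover message boundaries by splitting the byte stream on newlines:

>>> yield_json_line({"answer data": "Hello"})
'{"answer data": "Hello"}\n'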
@@ -165,6 +171,15 @@ def process_answer(
     return answer, quotes_dict
 
 
+def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
+    next_token = next_token.replace('\\"', "")
+    # If the last character seen so far is a backslash, the first character
+    # of next_token is escaped and cannot be the closing quote
+    if answer_so_far and answer_so_far[-1] == "\\":
+        next_token = next_token[1:]
+    if '"' in next_token:
+        return True
+    return False
+
+
 class OpenAICompletionQA(QAModel):
     def __init__(
         self,
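This helper is only consulted once the answer field has been opened, and it answers one question: does this token contain the unescaped double quote that closes the JSON answer string? A small drive-through with invented token boundaries:

so_far = '{"answer":"'
for token in ["Stream", "ing wor", 'ks \\"fine\\"', '", "quotes": {']:
    if stream_answer_end(so_far, token):
        print("answer closed in:", repr(token))  # fires on the last token only
    so_far += token

The escaped quotes inside the third token are stripped by the replace call and do not end the answer; the bare quote opening the fourth token does.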
@@ -207,6 +222,57 @@ class OpenAICompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
 
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        filled_prompt = self.prompt_processor(query, top_contents)
+        logger.debug(filled_prompt)
+
+        try:
+            response = openai.Completion.create(
+                prompt=filled_prompt,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+                stream=True,
+            )
+
+            model_output = ""
+            found_answer_start = False
+            found_answer_end = False
+            # iterate through the stream of events
+            for event in response:
+                event_text = event["choices"][0]["text"]
+                model_previous = model_output
+                model_output += event_text
+
+                if not found_answer_start and '{"answer":"' in model_output.replace(
+                    " ", ""
+                ).replace("\n", ""):
+                    found_answer_start = True
+                    continue
+
+                if found_answer_start and not found_answer_end:
+                    if stream_answer_end(model_previous, event_text):
+                        found_answer_end = True
+                        continue
+                    yield {"answer data": event_text}
+
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        logger.info(answer)
+
+        yield quotes_dict
+
 
 class OpenAIChatCompletionQA(QAModel):
     def __init__(
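The streaming method changes the return contract: instead of one (answer, quotes) tuple, the caller gets a generator that yields {"answer data": token} dicts while the model is inside the JSON answer field, then the quotes dict (or None) as its final item. A sketch of a consumer, assuming ranked_chunks came from retrieve_ranked_documents as in the server code below:

qa_model = OpenAICompletionQA()
answer = ""
quotes = None
for item in qa_model.answer_question_stream("What is Danswer?", ranked_chunks):
    if item is not None and "answer data" in item:
        answer += item["answer data"]  # partial answer, token by token
    else:
        quotes = item  # the final yielded item is the quotes payload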
@@ -257,3 +323,11 @@ class OpenAIChatCompletionQA(QAModel):
 
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
+
+    @log_function_time()
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Any:
+        raise NotImplementedError(
+            "Danswer with chat completion does not support streaming yet"
+        )
@@ -22,7 +22,9 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
+from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
+from danswer.configs.model_configs import CROSS_ENCODER_MODEL
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
@@ -33,16 +35,6 @@ from sentence_transformers import SentenceTransformer  # type: ignore
 
 logger = setup_logger()
 
-# Important considerations when choosing models
-# Max tokens count needs to be high considering use case (at least 512)
-# Models used must be MIT or Apache license
-# Inference/Indexing speed
-
-# Bi/Cross-Encoder Model Configs
-# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
-DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
-DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
-CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
-
 _EMBED_MODEL: None | SentenceTransformer = None
 _RERANK_MODEL: None | CrossEncoder = None
@@ -1,3 +1,4 @@
+import json
 import time
 from http import HTTPStatus
 
@@ -11,6 +12,7 @@ from danswer.datastores import create_datastore
 from danswer.db.engine import build_async_engine
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.question_answer import yield_json_line
 from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
@@ -24,6 +26,7 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Request
+from fastapi.responses import StreamingResponse
 from fastapi_users.db import SQLAlchemyUserDatabase
 from sqlalchemy.ext.asyncio import AsyncSession
|
||||
@ -78,29 +81,73 @@ async def promote_admin(
|
||||
return
|
||||
|
||||
|
||||
@router.post("/direct-qa", response_model=QAResponse)
|
||||
@router.get("/direct-qa", response_model=QAResponse)
|
||||
def direct_qa(question: QAQuestion):
|
||||
start_time = time.time()
|
||||
qa_model = get_default_backend_qa_model()
|
||||
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
filters = question.filters
|
||||
|
||||
datastore = create_datastore(collection)
|
||||
|
||||
logger.info(f"Received semantic query: {query}")
|
||||
|
||||
ranked_chunks = retrieve_ranked_documents(query, filters, datastore)
|
||||
ranked_chunks = retrieve_ranked_documents(
|
||||
query, filters, create_datastore(collection)
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return {"answer": None, "quotes": None}
|
||||
|
||||
qa_model = get_default_backend_qa_model()
|
||||
answer, quotes = qa_model.answer_question(query, ranked_chunks)
|
||||
|
||||
logger.info(f"Total QA took {time.time() - start_time} seconds")
|
||||
|
||||
return QAResponse(answer=answer, quotes=quotes)
|
||||
|
||||
|
||||
@router.post("/keyword-search", response_model=KeywordResponse)
|
||||
@router.get("/stream-direct-qa")
|
||||
def stream_direct_qa(question: QAQuestion):
|
||||
top_documents_key = "top_documents"
|
||||
answer_key = "answer"
|
||||
quotes_key = "quotes"
|
||||
|
||||
def stream_qa_portions():
|
||||
query = question.query
|
||||
collection = question.collection
|
||||
filters = question.filters
|
||||
logger.info(f"Received semantic query: {query}")
|
||||
|
||||
ranked_chunks = retrieve_ranked_documents(
|
||||
query, filters, create_datastore(collection)
|
||||
)
|
||||
if not ranked_chunks:
|
||||
return yield_json_line(
|
||||
{top_documents_key: None, answer_key: None, quotes_key: None}
|
||||
)
|
||||
|
||||
linked_chunks = [
|
||||
chunk
|
||||
for chunk in ranked_chunks
|
||||
if chunk.source_links and "0" in chunk.source_links
|
||||
]
|
||||
top_docs = {
|
||||
top_documents_key: {
|
||||
"document section links": [
|
||||
chunk.source_links["0"] for chunk in linked_chunks
|
||||
],
|
||||
"blurbs": [chunk.blurb for chunk in linked_chunks],
|
||||
}
|
||||
}
|
||||
yield yield_json_line(top_docs)
|
||||
|
||||
qa_model = get_default_backend_qa_model()
|
||||
for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
|
||||
logger.debug(response_dict)
|
||||
yield yield_json_line(response_dict)
|
||||
|
||||
return StreamingResponse(stream_qa_portions(), media_type="application/json")
|
||||
|
||||
|
||||
@router.get("/keyword-search", response_model=KeywordResponse)
|
||||
def keyword_search(question: QAQuestion):
|
||||
ts_client = TSClient.get_instance()
|
||||
query = question.query
|
||||
|
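On the wire, a /stream-direct-qa response is newline-delimited JSON: the top_documents payload first, then one line per answer token, then the quotes payload produced by process_answer. Roughly, with invented contents and illustrative field names:

{"top_documents": {"document section links": ["https://example.com/doc"], "blurbs": ["First sentence of the doc."]}}
{"answer data": "Danswer"}
{"answer data": " streams answers."}
{"Danswer streams answers.": {"semantic_identifier": "...", "blurb": "...", "link": "...", "source_type": "..."}}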
@@ -22,19 +22,27 @@ if __name__ == "__main__":
     )
 
     parser.add_argument(
-        "-s",
+        "-t",
         "--source-types",
         type=str,
         help="Comma separated list of source types to filter by",
     )
 
+    parser.add_argument(
+        "-s",
+        "--stream",
+        action="store_true",
+        help="Enable streaming response",
+    )
+
     parser.add_argument("query", nargs="*", help="The query to process")
 
     while True:
         try:
             user_input = input(
                 "\n\nAsk any question:\n"
-                "  - prefix with -s to add a filter by source(s)\n"
+                "  - prefix with -t to add a filter by source type(s)\n"
+                "  - prefix with -s to stream answer\n"
                 "  - input an empty string to rerun last query\n\t"
            )
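The new -s flag rides on the same parser that handles the interactive input (the typed line is presumably split and passed back through argparse in the surrounding, unshown code), so a streamed, type-filtered query at the prompt looks something like this (query text invented):

-s -t slack how does document indexing work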
@@ -55,40 +63,58 @@ if __name__ == "__main__":
                 source_types = source_types[0]
             query = " ".join(args.query)
 
-            endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
+            endpoint = (
+                f"http://127.0.0.1:{APP_PORT}/direct-qa"
+                if not args.stream
+                else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
+            )
             if args.keyword_search:
                 endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
+                raise NotImplementedError("keyword search is not supported for now")
 
             query_json = {
                 "query": query,
                 "collection": QDRANT_DEFAULT_COLLECTION,
                 "filters": [{SOURCE_TYPE: source_types}],
             }
 
-            response = requests.post(endpoint, json=query_json)
-            contents = json.loads(response.content)
-            if keyword_search:
-                if contents["results"]:
-                    for link in contents["results"]:
-                        print(link)
-                else:
-                    print("No matches found")
-            else:
-                answer = contents.get("answer")
-                if answer:
-                    print("Answer: " + answer)
-                else:
-                    print("Answer: ?")
-                if contents.get("quotes"):
-                    for ind, (quote, quote_info) in enumerate(
-                        contents["quotes"].items()
-                    ):
-                        print(f"Quote {str(ind + 1)}:\n{quote}")
-                        print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
-                        print(f"Blurb: {quote_info[BLURB]}")
-                        print(f"Link: {quote_info[SOURCE_LINK]}")
-                        print(f"Source: {quote_info[SOURCE_TYPE]}")
-                else:
-                    print("No quotes found")
+            if not args.stream:
+                response = requests.get(endpoint, json=query_json)
+                contents = json.loads(response.content)
+                if keyword_search:
+                    if contents["results"]:
+                        for link in contents["results"]:
+                            print(link)
+                    else:
+                        print("No matches found")
+                else:
+                    answer = contents.get("answer")
+                    if answer:
+                        print("Answer: " + answer)
+                    else:
+                        print("Answer: ?")
+                    if contents.get("quotes"):
+                        for ind, (quote, quote_info) in enumerate(
+                            contents["quotes"].items()
+                        ):
+                            print(f"Quote {str(ind + 1)}:\n{quote}")
+                            print(
+                                f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
+                            )
+                            print(f"Blurb: {quote_info[BLURB]}")
+                            print(f"Link: {quote_info[SOURCE_LINK]}")
+                            print(f"Source: {quote_info[SOURCE_TYPE]}")
+                    else:
+                        print("No quotes found")
+            else:
+                answer = ""
+                with requests.get(endpoint, json=query_json, stream=True) as r:
+                    for json_response in r.iter_lines():
+                        response_dict = json.loads(json_response.decode())
+                        if "answer data" not in response_dict:
+                            print(response_dict)
+                        else:
+                            answer += response_dict["answer data"]
+                            print(answer)
 
         except Exception as e:
             print(f"Failed due to {e}, retrying")