Mirror of https://github.com/danswer-ai/danswer.git, synced 2025-05-03 00:10:24 +02:00

DAN-23 Stream model output (#30)

commit 20b25e322f
parent c6a0baed13
@@ -1,6 +1,5 @@
 import inspect
 from dataclasses import dataclass
-from typing import Optional

 from danswer.connectors.models import Document

@@ -10,9 +9,9 @@ class BaseChunk:
     chunk_id: int
     blurb: str  # The first sentence(s) of the first Section of the chunk
     content: str
-    source_links: Optional[
-        dict[int, str]
-    ]  # Holds the link and the offsets into the raw Chunk text
+    source_links: dict[
+        int, str
+    ] | None  # Holds the link and the offsets into the raw Chunk text
     section_continuation: bool  # True if this Chunk's start is not at the start of a Section


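For reference, the new annotation is runtime-equivalent to the old Optional form; it just switches to PEP 604 union syntax, which assumes Python 3.10+ (or deferred annotation evaluation). A minimal sketch with placeholder data, mapping a text offset to its source link:

from dataclasses import dataclass


@dataclass
class BaseChunk:
    chunk_id: int
    blurb: str
    content: str
    source_links: dict[int, str] | None  # offset into the chunk text -> link
    section_continuation: bool


# Placeholder values, purely illustrative
chunk = BaseChunk(
    chunk_id=0,
    blurb="Danswer can stream model output.",
    content="Danswer can stream model output to the client token by token.",
    source_links={0: "https://example.com/docs#streaming"},
    section_continuation=False,
)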
@@ -1,5 +1,16 @@
 import os

+# Important considerations when choosing models
+# Max tokens count needs to be high considering use case (at least 512)
+# Models used must be MIT or Apache license
+# Inference/Indexing speed
+
+# Bi/Cross-Encoder Model Configs
+# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
+DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
+CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
+DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
+
 QUERY_EMBEDDING_CONTEXT_SIZE = 256
 DOC_EMBEDDING_CONTEXT_SIZE = 512
 CROSS_EMBED_CONTEXT_SIZE = 512
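These constants back the bi-encoder/cross-encoder pair used for retrieval and reranking (see the semantic_search changes below, which now import them from this module). A minimal loading sketch with sentence-transformers, assuming the package is installed; the max_seq_length/max_length settings mirror the *_CONTEXT_SIZE values:

from sentence_transformers import CrossEncoder, SentenceTransformer

embed_model = SentenceTransformer("sentence-transformers/all-distilroberta-v1")
embed_model.max_seq_length = 512  # DOC_EMBEDDING_CONTEXT_SIZE

rerank_model = CrossEncoder(
    "cross-encoder/ms-marco-MiniLM-L-6-v2", max_length=512  # CROSS_EMBED_CONTEXT_SIZE
)

embedding = embed_model.encode("What does Danswer stream?")  # 768-dim (DOC_EMBEDDING_DIM)
score = rerank_model.predict([("what is streamed?", "Danswer streams model output.")])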
@@ -11,7 +11,7 @@ from danswer.configs.constants import SECTION_CONTINUATION
 from danswer.configs.constants import SEMANTIC_IDENTIFIER
 from danswer.configs.constants import SOURCE_LINKS
 from danswer.configs.constants import SOURCE_TYPE
-from danswer.semantic_search.semantic_search import DOC_EMBEDDING_DIM
+from danswer.configs.model_configs import DOC_EMBEDDING_DIM
 from danswer.utils.clients import get_qdrant_client
 from danswer.utils.logging import setup_logger
 from qdrant_client import QdrantClient
@@ -1,5 +1,5 @@
 import abc
-from typing import *
+from typing import Any

 from danswer.chunking.models import InferenceChunk

@@ -7,6 +7,16 @@ from danswer.chunking.models import InferenceChunk
 class QAModel:
     @abc.abstractmethod
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
     ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
         raise NotImplementedError
+
+    @abc.abstractmethod
+    def answer_question_stream(
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+    ) -> Any:
+        raise NotImplementedError
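The contract: answer_question blocks and returns one (answer, quotes) pair, while answer_question_stream is expected to yield partial results as they arrive. A hypothetical caller sketch, assuming query and ranked_chunks are already in scope; handle_packet is an assumed consumer, not part of this diff:

qa_model = get_default_backend_qa_model()

# Blocking: a single (answer, quotes) pair
answer, quotes = qa_model.answer_question(query, ranked_chunks)

# Streaming: partial packets, with the quotes dict arriving last
for packet in qa_model.answer_question_stream(query, ranked_chunks):
    handle_packet(packet)  # assumed consumer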
@@ -2,6 +2,8 @@ import json
 import math
 import re
 from collections.abc import Callable
+from collections.abc import Generator
+from typing import Any
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -37,6 +39,10 @@ logger = setup_logger()
 openai.api_key = OPENAI_API_KEY


+def yield_json_line(json_dict):
+    return json.dumps(json_dict) + "\n"
+
+
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
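yield_json_line frames each payload as newline-delimited JSON (one document per line), which is what lets the client decode the stream incrementally with iter_lines. A small round-trip sketch:

import json


def yield_json_line(json_dict):
    return json.dumps(json_dict) + "\n"


stream = yield_json_line({"answer data": "It is"}) + yield_json_line({"answer data": " 42."})
for line in stream.splitlines():
    print(json.loads(line))
# {'answer data': 'It is'}
# {'answer data': ' 42.'}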
@@ -165,6 +171,15 @@ def process_answer(
     return answer, quotes_dict


+def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
+    next_token = next_token.replace('\\"', "")
+    if answer_so_far and answer_so_far[-1] == "\\":
+        next_token = next_token[1:]
+    if '"' in next_token:
+        return True
+    return False
+
+
 class OpenAICompletionQA(QAModel):
     def __init__(
         self,
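The model is prompted to answer as JSON, so the streamed answer ends at the first unescaped double quote: escaped quotes (\") wholly inside a token are stripped first, and if the previous chunk ended on a backslash, the token's first character is the tail of an escape and is ignored. A few illustrative cases:

# Plain closing quote terminates the answer
assert stream_answer_end('{"answer":"It is 42', '"') is True

# Previous chunk ended with a backslash, so this quote is escaped (\")
assert stream_answer_end('{"answer":"He said \\', '"') is False

# Escaped quotes entirely within one token are stripped before checking
assert stream_answer_end('{"answer":"She said', ' \\"hi\\" ') is False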
@@ -207,6 +222,57 @@ class OpenAICompletionQA(QAModel):
         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict

+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Generator[dict[str, Any] | None, None, None]:
+        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
+        filled_prompt = self.prompt_processor(query, top_contents)
+        logger.debug(filled_prompt)
+
+        try:
+            response = openai.Completion.create(
+                prompt=filled_prompt,
+                temperature=0,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0,
+                model=self.model_version,
+                max_tokens=self.max_output_tokens,
+                stream=True,
+            )
+
+            model_output = ""
+            found_answer_start = False
+            found_answer_end = False
+            # iterate through the stream of events
+            for event in response:
+                event_text = event["choices"][0]["text"]
+                model_previous = model_output
+                model_output += event_text
+
+                if not found_answer_start and '{"answer":"' in model_output.replace(
+                    " ", ""
+                ).replace("\n", ""):
+                    found_answer_start = True
+                    continue
+
+                if found_answer_start and not found_answer_end:
+                    if stream_answer_end(model_previous, event_text):
+                        found_answer_end = True
+                        continue
+                    yield {"answer data": event_text}
+
+        except Exception as e:
+            logger.exception(e)
+            model_output = "Model Failure"
+
+        logger.debug(model_output)
+
+        answer, quotes_dict = process_answer(model_output, context_docs)
+        logger.info(answer)
+
+        yield quotes_dict
+

 class OpenAIChatCompletionQA(QAModel):
     def __init__(
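A minimal consumer sketch for this generator, assuming default constructor arguments, OPENAI_API_KEY configured, and query and ranked_chunks in scope: answer tokens arrive as {"answer data": ...} packets and the parsed quotes dict (possibly None) is yielded last.

qa_model = OpenAICompletionQA()

answer = ""
quotes = None
for packet in qa_model.answer_question_stream(query, ranked_chunks):
    if packet is not None and "answer data" in packet:
        answer += packet["answer data"]
    else:
        quotes = packet  # final quotes dict from process_answer, or None
print(answer)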
@@ -257,3 +323,11 @@ class OpenAIChatCompletionQA(QAModel):

         answer, quotes_dict = process_answer(model_output, context_docs)
         return answer, quotes_dict
+
+    @log_function_time()
+    def answer_question_stream(
+        self, query: str, context_docs: list[InferenceChunk]
+    ) -> Any:
+        raise NotImplementedError(
+            "Danswer with chat completion does not support streaming yet"
+        )
@@ -22,7 +22,9 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
+from danswer.configs.model_configs import CROSS_ENCODER_MODEL
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.datastores.interfaces import Datastore
 from danswer.datastores.interfaces import DatastoreFilter
 from danswer.utils.logging import setup_logger
@@ -33,16 +35,6 @@ from sentence_transformers import SentenceTransformer  # type: ignore

 logger = setup_logger()

-# Important considerations when choosing models
-# Max tokens count needs to be high considering use case (at least 512)
-# Models used must be MIT or Apache license
-# Inference/Indexing speed
-
-# Bi/Cross-Encoder Model Configs
-# Use 'multi-qa-MiniLM-L6-cos-v1' if license is added because it is 3x faster (384 dimensional embedding)
-DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-distilroberta-v1"
-DOC_EMBEDDING_DIM = 768  # Depends on the document encoder model
-CROSS_ENCODER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"

 _EMBED_MODEL: None | SentenceTransformer = None
 _RERANK_MODEL: None | CrossEncoder = None
@@ -1,3 +1,4 @@
+import json
 import time
 from http import HTTPStatus

@@ -11,6 +12,7 @@ from danswer.datastores import create_datastore
 from danswer.db.engine import build_async_engine
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.question_answer import yield_json_line
 from danswer.semantic_search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import KeywordResponse
 from danswer.server.models import QAQuestion
@@ -24,6 +26,7 @@ from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
 from fastapi import Request
+from fastapi.responses import StreamingResponse
 from fastapi_users.db import SQLAlchemyUserDatabase
 from sqlalchemy.ext.asyncio import AsyncSession

@@ -78,29 +81,73 @@ async def promote_admin(
     return


-@router.post("/direct-qa", response_model=QAResponse)
+@router.get("/direct-qa", response_model=QAResponse)
 def direct_qa(question: QAQuestion):
     start_time = time.time()
-    qa_model = get_default_backend_qa_model()
     query = question.query
     collection = question.collection
     filters = question.filters
-
-    datastore = create_datastore(collection)
-
     logger.info(f"Received semantic query: {query}")

-    ranked_chunks = retrieve_ranked_documents(query, filters, datastore)
+    ranked_chunks = retrieve_ranked_documents(
+        query, filters, create_datastore(collection)
+    )
     if not ranked_chunks:
         return {"answer": None, "quotes": None}

+    qa_model = get_default_backend_qa_model()
     answer, quotes = qa_model.answer_question(query, ranked_chunks)

     logger.info(f"Total QA took {time.time() - start_time} seconds")

     return QAResponse(answer=answer, quotes=quotes)


-@router.post("/keyword-search", response_model=KeywordResponse)
+@router.get("/stream-direct-qa")
+def stream_direct_qa(question: QAQuestion):
+    top_documents_key = "top_documents"
+    answer_key = "answer"
+    quotes_key = "quotes"
+
+    def stream_qa_portions():
+        query = question.query
+        collection = question.collection
+        filters = question.filters
+        logger.info(f"Received semantic query: {query}")
+
+        ranked_chunks = retrieve_ranked_documents(
+            query, filters, create_datastore(collection)
+        )
+        if not ranked_chunks:
+            yield yield_json_line(
+                {top_documents_key: None, answer_key: None, quotes_key: None}
+            )
+            return
+
+        linked_chunks = [
+            chunk
+            for chunk in ranked_chunks
+            if chunk.source_links and "0" in chunk.source_links
+        ]
+        top_docs = {
+            top_documents_key: {
+                "document section links": [
+                    chunk.source_links["0"] for chunk in linked_chunks
+                ],
+                "blurbs": [chunk.blurb for chunk in linked_chunks],
+            }
+        }
+        yield yield_json_line(top_docs)
+
+        qa_model = get_default_backend_qa_model()
+        for response_dict in qa_model.answer_question_stream(query, ranked_chunks):
+            logger.debug(response_dict)
+            yield yield_json_line(response_dict)
+
+    return StreamingResponse(stream_qa_portions(), media_type="application/json")
+
+
+@router.get("/keyword-search", response_model=KeywordResponse)
 def keyword_search(question: QAQuestion):
     ts_client = TSClient.get_instance()
     query = question.query
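The stream is newline-delimited JSON: the first line carries the top document links and blurbs, then one {"answer data": ...} line per answer token, and finally the quotes dict (or null). A minimal client sketch, assuming the server runs locally on port 8080 and using illustrative collection and filter values:

import json

import requests

query_json = {
    "query": "What does Danswer stream?",
    "collection": "semantic_search",  # illustrative collection name
    "filters": None,  # illustrative; the CLI below sends source-type filters
}
with requests.get(
    "http://127.0.0.1:8080/stream-direct-qa", json=query_json, stream=True
) as response:
    packets = (json.loads(line.decode()) for line in response.iter_lines() if line)
    top_docs = next(packets)  # {"top_documents": {...}}
    for packet in packets:
        if packet and "answer data" in packet:
            print(packet["answer data"], end="")  # partial answer tokens
        else:
            quotes = packet  # final quotes dict, possibly None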
@@ -22,19 +22,27 @@ if __name__ == "__main__":
     )

     parser.add_argument(
-        "-s",
+        "-t",
         "--source-types",
         type=str,
        help="Comma separated list of source types to filter by",
     )

+    parser.add_argument(
+        "-s",
+        "--stream",
+        action="store_true",
+        help="Enable streaming response",
+    )
+
     parser.add_argument("query", nargs="*", help="The query to process")

     while True:
         try:
             user_input = input(
                 "\n\nAsk any question:\n"
-                " - prefix with -s to add a filter by source(s)\n"
+                " - prefix with -t to add a filter by source type(s)\n"
+                " - prefix with -s to stream answer\n"
                 " - input an empty string to rerun last query\n\t"
             )

@@ -55,40 +63,58 @@ if __name__ == "__main__":
                 source_types = source_types[0]
             query = " ".join(args.query)

-            endpoint = f"http://127.0.0.1:{APP_PORT}/direct-qa"
+            endpoint = (
+                f"http://127.0.0.1:{APP_PORT}/direct-qa"
+                if not args.stream
+                else f"http://127.0.0.1:{APP_PORT}/stream-direct-qa"
+            )
             if args.keyword_search:
                 endpoint = f"http://127.0.0.1:{APP_PORT}/keyword-search"
+                raise NotImplementedError("keyword search is not supported for now")

             query_json = {
                 "query": query,
                 "collection": QDRANT_DEFAULT_COLLECTION,
                 "filters": [{SOURCE_TYPE: source_types}],
             }
-            response = requests.post(endpoint, json=query_json)
-            contents = json.loads(response.content)
-            if keyword_search:
-                if contents["results"]:
-                    for link in contents["results"]:
-                        print(link)
+            if not args.stream:
+                response = requests.get(endpoint, json=query_json)
+                contents = json.loads(response.content)
+                if keyword_search:
+                    if contents["results"]:
+                        for link in contents["results"]:
+                            print(link)
+                    else:
+                        print("No matches found")
                 else:
-                    print("No matches found")
+                    answer = contents.get("answer")
+                    if answer:
+                        print("Answer: " + answer)
+                    else:
+                        print("Answer: ?")
+                    if contents.get("quotes"):
+                        for ind, (quote, quote_info) in enumerate(
+                            contents["quotes"].items()
+                        ):
+                            print(f"Quote {str(ind + 1)}:\n{quote}")
+                            print(
+                                f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}"
+                            )
+                            print(f"Blurb: {quote_info[BLURB]}")
+                            print(f"Link: {quote_info[SOURCE_LINK]}")
+                            print(f"Source: {quote_info[SOURCE_TYPE]}")
+                    else:
+                        print("No quotes found")
             else:
-                answer = contents.get("answer")
-                if answer:
-                    print("Answer: " + answer)
-                else:
-                    print("Answer: ?")
-                if contents.get("quotes"):
-                    for ind, (quote, quote_info) in enumerate(
-                        contents["quotes"].items()
-                    ):
-                        print(f"Quote {str(ind + 1)}:\n{quote}")
-                        print(f"Semantic Identifier: {quote_info[SEMANTIC_IDENTIFIER]}")
-                        print(f"Blurb: {quote_info[BLURB]}")
-                        print(f"Link: {quote_info[SOURCE_LINK]}")
-                        print(f"Source: {quote_info[SOURCE_TYPE]}")
-                else:
-                    print("No quotes found")
+                answer = ""
+                with requests.get(endpoint, json=query_json, stream=True) as r:
+                    for json_response in r.iter_lines():
+                        response_dict = json.loads(json_response.decode())
+                        if "answer data" not in response_dict:
+                            print(response_dict)
+                        else:
+                            answer += response_dict["answer data"]
+                print(answer)

         except Exception as e:
             print(f"Failed due to {e}, retrying")
Loading…
x
Reference in New Issue
Block a user