Quote loading UI + adding back period to end of answer + adding custom logo (#55)

* Logo

* Add spinners + some small housekeeping on the backend
This commit is contained in:
Chris Weaver
2023-05-16 20:14:06 -07:00
committed by GitHub
parent 821df50fa9
commit 494514dc68
9 changed files with 117 additions and 56 deletions

View File

@@ -1,4 +1,5 @@
import abc
from collections.abc import Generator
from typing import Any
from danswer.chunking.models import InferenceChunk
@@ -18,5 +19,5 @@ class QAModel:
self,
query: str,
context_docs: list[InferenceChunk],
) -> Any:
) -> Generator[dict[str, Any] | None, None, None]:
raise NotImplementedError

View File

@@ -4,6 +4,7 @@ import re
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Dict
from typing import Optional
from typing import Tuple
@@ -241,12 +242,12 @@ class OpenAICompletionQA(QAModel):
stream=True,
)
model_output = ""
model_output: str = ""
found_answer_start = False
found_answer_end = False
# iterate through the stream of events
for event in response:
event_text = event["choices"][0]["text"]
event_text = cast(str, event["choices"][0]["text"])
model_previous = model_output
model_output += event_text
@@ -259,6 +260,7 @@ class OpenAICompletionQA(QAModel):
if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text):
found_answer_end = True
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}
@@ -343,11 +345,11 @@ class OpenAIChatCompletionQA(QAModel):
stream=True,
)
model_output = ""
model_output: str = ""
found_answer_start = False
found_answer_end = False
for event in response:
event_dict = event["choices"][0]["delta"]
event_dict = cast(str, event["choices"][0]["delta"])
if (
"content" not in event_dict
): # could be a role message or empty termination
@@ -365,6 +367,7 @@ class OpenAIChatCompletionQA(QAModel):
if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text):
found_answer_end = True
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}

View File

@@ -16,7 +16,6 @@
# Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models
# The original authors can be found at https://www.sbert.net/
import json
from typing import List
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
@@ -65,8 +64,8 @@ def warm_up_models() -> None:
@log_function_time()
def semantic_reranking(
query: str,
chunks: List[InferenceChunk],
) -> List[InferenceChunk]:
chunks: list[InferenceChunk],
) -> list[InferenceChunk]:
cross_encoder = get_default_reranking_model()
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
scored_results = list(zip(sim_scores, chunks))
@@ -84,7 +83,7 @@ def retrieve_ranked_documents(
filters: list[DatastoreFilter] | None,
datastore: Datastore,
num_hits: int = NUM_RETURNED_HITS,
) -> List[InferenceChunk] | None:
) -> list[InferenceChunk] | None:
top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")

View File

@@ -1,15 +1,18 @@
import time
from collections.abc import Callable
from typing import Any
from typing import TypeVar
from danswer.utils.logging import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
def log_function_time(
func_name: str | None = None,
) -> Callable[[Callable], Callable]:
) -> Callable[[F], F]:
"""Build a timing wrapper for a function. Logs how long the function took to run.
Use like: