mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
Quote loading UI + adding back period to end of answer + adding custom logo (#55)
* Logo * Add spinners + some small housekeeping on the backend
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
@@ -18,5 +19,5 @@ class QAModel:
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> Any:
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
raise NotImplementedError
|
||||
|
@@ -4,6 +4,7 @@ import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
@@ -241,12 +242,12 @@ class OpenAICompletionQA(QAModel):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
model_output = ""
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
# iterate through the stream of events
|
||||
for event in response:
|
||||
event_text = event["choices"][0]["text"]
|
||||
event_text = cast(str, event["choices"][0]["text"])
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
|
||||
@@ -259,6 +260,7 @@ class OpenAICompletionQA(QAModel):
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
@@ -343,11 +345,11 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
stream=True,
|
||||
)
|
||||
|
||||
model_output = ""
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
for event in response:
|
||||
event_dict = event["choices"][0]["delta"]
|
||||
event_dict = cast(str, event["choices"][0]["delta"])
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
@@ -365,6 +367,7 @@ class OpenAIChatCompletionQA(QAModel):
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
|
@@ -16,7 +16,6 @@
|
||||
# Specifically the sentence-transformers/all-distilroberta-v1 and cross-encoder/ms-marco-MiniLM-L-6-v2 models
|
||||
# The original authors can be found at https://www.sbert.net/
|
||||
import json
|
||||
from typing import List
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
@@ -65,8 +64,8 @@ def warm_up_models() -> None:
|
||||
@log_function_time()
|
||||
def semantic_reranking(
|
||||
query: str,
|
||||
chunks: List[InferenceChunk],
|
||||
) -> List[InferenceChunk]:
|
||||
chunks: list[InferenceChunk],
|
||||
) -> list[InferenceChunk]:
|
||||
cross_encoder = get_default_reranking_model()
|
||||
sim_scores = cross_encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
|
||||
scored_results = list(zip(sim_scores, chunks))
|
||||
@@ -84,7 +83,7 @@ def retrieve_ranked_documents(
|
||||
filters: list[DatastoreFilter] | None,
|
||||
datastore: Datastore,
|
||||
num_hits: int = NUM_RETURNED_HITS,
|
||||
) -> List[InferenceChunk] | None:
|
||||
) -> list[InferenceChunk] | None:
|
||||
top_chunks = datastore.semantic_retrieval(query, filters, num_hits)
|
||||
if not top_chunks:
|
||||
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
|
||||
|
@@ -1,15 +1,18 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.utils.logging import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def log_function_time(
|
||||
func_name: str | None = None,
|
||||
) -> Callable[[Callable], Callable]:
|
||||
) -> Callable[[F], F]:
|
||||
"""Build a timing wrapper for a function. Logs how long the function took to run.
|
||||
Use like:
|
||||
|
||||
|
Reference in New Issue
Block a user