mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 13:53:28 +02:00)

Commit: welcome to onyx
@@ -0,0 +1,98 @@
import abc
from collections.abc import Generator

from langchain_core.messages import BaseMessage

from onyx.chat.llm_response_handler import ResponsePart
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.stream_processing.citation_processing import CitationProcessor
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.utils.logger import setup_logger

logger = setup_logger()


class AnswerResponseHandler(abc.ABC):
    @abc.abstractmethod
    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        raise NotImplementedError


class DummyAnswerResponseHandler(AnswerResponseHandler):
    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        # This is a dummy handler that returns nothing
        yield from []


class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
        self,
        context_docs: list[LlmDoc],
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
    ):
        self.context_docs = context_docs
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.display_doc_order_dict = display_doc_order_dict
        self.citation_processor = CitationProcessor(
            context_docs=self.context_docs,
            doc_id_to_rank_map=self.doc_id_to_rank_map,
            display_doc_order_dict=self.display_doc_order_dict,
        )
        self.processed_text = ""
        self.citations: list[CitationInfo] = []

        # TODO remove this after citation issue is resolved
        logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")

    def handle_response_part(
        self,
        response_item: BaseMessage | None,
        previous_response_items: list[BaseMessage],
    ) -> Generator[ResponsePart, None, None]:
        if response_item is None:
            return

        content = (
            response_item.content if isinstance(response_item.content, str) else ""
        )

        # Process the new content through the citation processor
        yield from self.citation_processor.process_token(content)


# No longer in use, remove later
# class QuotesResponseHandler(AnswerResponseHandler):
#     def __init__(
#         self,
#         context_docs: list[LlmDoc],
#         is_json_prompt: bool = True,
#     ):
#         self.quotes_processor = QuotesProcessor(
#             context_docs=context_docs,
#             is_json_prompt=is_json_prompt,
#         )

#     def handle_response_part(
#         self,
#         response_item: BaseMessage | None,
#         previous_response_items: list[BaseMessage],
#     ) -> Generator[ResponsePart, None, None]:
#         if response_item is None:
#             yield from self.quotes_processor.process_token(None)
#             return

#         content = (
#             response_item.content if isinstance(response_item.content, str) else ""
#         )

#         yield from self.quotes_processor.process_token(content)
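The handlers above all share the same streaming contract: the caller feeds each newly streamed chunk (plus the chunks seen so far) into handle_response_part and consumes whatever ResponsePart objects it yields; a None item appears to be how end-of-stream is signaled (CitationResponseHandler simply returns on None). A minimal driving-loop sketch, assuming the DummyAnswerResponseHandler defined above is in scope and using made-up chunk contents:

from langchain_core.messages import AIMessage, BaseMessage

# Hypothetical streamed chunks; in practice these come from the LLM client.
streamed_chunks: list[BaseMessage] = [AIMessage(content="Hello"), AIMessage(content=" world")]

handler = DummyAnswerResponseHandler()
seen: list[BaseMessage] = []
for chunk in streamed_chunks:
    # Each call may yield ResponsePart objects; the dummy handler yields none.
    for part in handler.handle_response_part(chunk, seen):
        print(part)
    seen.append(chunk)

# End of stream: CitationResponseHandler treats a None item as "nothing left to process".
for part in handler.handle_response_part(None, seen):
    print(part)

CitationResponseHandler plugs into the same loop unchanged; it just forwards each chunk's text content to its CitationProcessor.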
backend/onyx/chat/stream_processing/citation_processing.py (new file, +195 lines)
@@ -0,0 +1,195 @@
import re
from collections.abc import Generator

from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
from onyx.configs.chat_configs import STOP_STREAM_PAT
from onyx.prompts.constants import TRIPLE_BACKTICK
from onyx.utils.logger import setup_logger

logger = setup_logger()


def in_code_block(llm_text: str) -> bool:
    count = llm_text.count(TRIPLE_BACKTICK)
    return count % 2 != 0
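A quick, self-contained illustration of the parity check above: an odd number of triple-backtick fences means a code block was opened but not yet closed. This assumes TRIPLE_BACKTICK is the literal ``` fence, which is defined elsewhere and not shown in this diff:

def in_code_block_demo(llm_text: str, fence: str = "```") -> bool:
    # Odd fence count -> we are currently inside an unclosed code block.
    return llm_text.count(fence) % 2 != 0

print(in_code_block_demo("Here is code:\n```python\nx = 1"))             # True: fence opened, not closed
print(in_code_block_demo("Here is code:\n```python\nx = 1\n```\ndone"))  # False: fence closed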
class CitationProcessor:
    def __init__(
        self,
        context_docs: list[LlmDoc],
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
        stop_stream: str | None = STOP_STREAM_PAT,
    ):
        self.context_docs = context_docs
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.stop_stream = stop_stream
        self.order_mapping = doc_id_to_rank_map.order_mapping
        self.display_doc_order_dict = (
            display_doc_order_dict  # original order of docs as displayed to the user
        )
        self.llm_out = ""
        self.max_citation_num = len(context_docs)
        self.citation_order: list[int] = []
        self.curr_segment = ""
        self.cited_inds: set[int] = set()
        self.hold = ""
        self.current_citations: list[int] = []
        self.past_cite_count = 0

    def process_token(
        self, token: str | None
    ) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]:
        # None -> end of stream
        if token is None:
            yield OnyxAnswerPiece(answer_piece=self.curr_segment)
            return

        if self.stop_stream:
            next_hold = self.hold + token
            if self.stop_stream in next_hold:
                return
            if next_hold == self.stop_stream[: len(next_hold)]:
                self.hold = next_hold
                return
            token = next_hold
            self.hold = ""

        self.curr_segment += token
        self.llm_out += token

        # Handle code blocks without language tags
        if "`" in self.curr_segment:
            if self.curr_segment.endswith("`"):
                return
            elif "```" in self.curr_segment:
                piece_that_comes_after = self.curr_segment.split("```")[1][0]
                if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
                    self.curr_segment = self.curr_segment.replace("```", "```plaintext")

        citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]"  # [1], [[1]], etc.
        citations_found = list(re.finditer(citation_pattern, self.curr_segment))
        possible_citation_pattern = r"(\[+\d*$)"  # [1, [, [[, [[2, etc.
        possible_citation_found = re.search(
            possible_citation_pattern, self.curr_segment
        )

        if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
            self.current_citations = []

        result = ""
        if citations_found and not in_code_block(self.llm_out):
            last_citation_end = 0
            length_to_add = 0
            while len(citations_found) > 0:
                citation = citations_found.pop(0)
                numerical_value = int(
                    next(group for group in citation.groups() if group is not None)
                )

                if 1 <= numerical_value <= self.max_citation_num:
                    context_llm_doc = self.context_docs[numerical_value - 1]
                    real_citation_num = self.order_mapping[context_llm_doc.document_id]

                    if real_citation_num not in self.citation_order:
                        self.citation_order.append(real_citation_num)

                    target_citation_num = (
                        self.citation_order.index(real_citation_num) + 1
                    )

                    # Get the value that was displayed to the user; it should always
                    # be in display_doc_order_dict, but check anyway
                    if context_llm_doc.document_id in self.display_doc_order_dict:
                        displayed_citation_num = self.display_doc_order_dict[
                            context_llm_doc.document_id
                        ]
                    else:
                        displayed_citation_num = real_citation_num
                        logger.warning(
                            f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
                        )

                    # Skip consecutive citations of the same work
                    if target_citation_num in self.current_citations:
                        start, end = citation.span()
                        real_start = length_to_add + start
                        diff = end - start
                        self.curr_segment = (
                            self.curr_segment[: length_to_add + start]
                            + self.curr_segment[real_start + diff :]
                        )
                        length_to_add -= diff
                        continue

                    # Handle edge case where the LLM outputs the citation itself
                    if self.curr_segment.startswith("[["):
                        match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
                        if match:
                            try:
                                doc_id = int(match.group(1))
                                context_llm_doc = self.context_docs[doc_id - 1]
                                yield CitationInfo(
                                    # stay with the original for now (order of LLM cites)
                                    citation_num=target_citation_num,
                                    document_id=context_llm_doc.document_id,
                                )
                            except Exception as e:
                                logger.warning(
                                    f"Manual LLM citation didn't properly cite documents {e}"
                                )
                        else:
                            logger.warning(
                                "Manual LLM citation wasn't able to close brackets"
                            )
                        continue

                    link = context_llm_doc.link

                    self.past_cite_count = len(self.llm_out)
                    self.current_citations.append(target_citation_num)

                    if target_citation_num not in self.cited_inds:
                        self.cited_inds.add(target_citation_num)
                        yield CitationInfo(
                            # stay with the original for now (order of LLM cites)
                            citation_num=target_citation_num,
                            document_id=context_llm_doc.document_id,
                        )

                    start, end = citation.span()
                    if link:
                        prev_length = len(self.curr_segment)
                        self.curr_segment = (
                            self.curr_segment[: start + length_to_add]
                            + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to the user
                            # + f"[[{target_citation_num}]]({link})"
                            + self.curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(self.curr_segment) - prev_length
                    else:
                        prev_length = len(self.curr_segment)
                        self.curr_segment = (
                            self.curr_segment[: start + length_to_add]
                            + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to the user
                            # + f"[[{target_citation_num}]]()"
                            + self.curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(self.curr_segment) - prev_length

                    last_citation_end = end + length_to_add

            if last_citation_end > 0:
                result += self.curr_segment[:last_citation_end]
                self.curr_segment = self.curr_segment[last_citation_end:]

        if not possible_citation_found:
            result += self.curr_segment
            self.curr_segment = ""

        if result:
            yield OnyxAnswerPiece(answer_piece=result)
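To make the citation detection in process_token more concrete, here is a stand-alone snippet exercising the same two regexes: citation_pattern matches complete [n] or [[n]] citations, while possible_citation_pattern flags a segment that ends in what might be the start of a citation, so the processor holds that text back instead of emitting it. The sample strings are invented for illustration:

import re

citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]"  # complete citations: [1], [[2]], ...
possible_citation_pattern = r"(\[+\d*$)"  # possible partial citation at the end: [, [[, [1, [[2, ...

segment = "Onyx supports connectors [1] and permission syncing [[2]]."
print([m.group() for m in re.finditer(citation_pattern, segment)])
# ['[1]', '[[2]]']

partial = "See the setup guide [[3"
print(bool(re.search(possible_citation_pattern, partial)))  # True -> hold the segment back
print(bool(re.search(possible_citation_pattern, segment)))  # False -> safe to emit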
backend/onyx/chat/stream_processing/quotes_processing.py (new file, +315 lines)
@@ -0,0 +1,315 @@
# THIS IS NO LONGER IN USE
import math
import re
from collections.abc import Generator
from json import JSONDecodeError
from typing import Optional

import regex
from pydantic import BaseModel

from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswer
from onyx.chat.models import OnyxAnswerPiece
from onyx.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from onyx.context.search.models import InferenceChunk
from onyx.prompts.constants import ANSWER_PAT
from onyx.prompts.constants import QUOTE_PAT
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_model_quote
from onyx.utils.text_processing import clean_up_code_blocks
from onyx.utils.text_processing import extract_embedded_json
from onyx.utils.text_processing import shared_precompare_cleanup


logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)


class OnyxQuote(BaseModel):
    # This is during inference so everything is a string by this point
    quote: str
    document_id: str
    link: str | None
    source_type: str
    semantic_identifier: str
    blurb: str


class OnyxQuotes(BaseModel):
    quotes: list[OnyxQuote]


def _extract_answer_quotes_freeform(
    answer_raw: str,
) -> tuple[Optional[str], Optional[list[str]]]:
    """Splits the model output into an Answer and 0 or more Quote sections.
    Splits on the Quote pattern; if it is not present, assume the output is all
    answer and no quotes.
    """
    # If there is no answer section, don't care about the quotes
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes the model regenerates the Answer: pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    )  # Just in case the model is unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]


def _extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> tuple[Optional[str], Optional[list[str]]]:
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes


def _extract_answer_json(raw_model_output: str) -> dict:
    try:
        answer_json = extract_embedded_json(raw_model_output)
    except (ValueError, JSONDecodeError):
        # LLMs get confused when handling the list in the json. Sometimes the model doesn't attend
        # enough to the previous { token, so it just ends the list of quotes and stops there.
        # Here, we add logic to try to fix this LLM error.
        answer_json = extract_embedded_json(raw_model_output + "}")

    if "answer" not in answer_json:
        raise ValueError("Model did not output an answer as expected.")

    return answer_json
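_extract_answer_json relies on a small recovery trick: when the model truncates the JSON (typically dropping the final closing brace after the quotes list), appending a "}" and re-parsing often salvages it. A self-contained sketch of the same idea using plain json.loads in place of extract_embedded_json (an onyx utility not shown in this diff):

import json

truncated = '{"answer": "Onyx is a workplace search tool", "quotes": ["Onyx connects to your apps"]'

try:
    parsed = json.loads(truncated)
except json.JSONDecodeError:
    # Same repair as above: assume only the closing brace is missing and retry.
    parsed = json.loads(truncated + "}")

print(parsed["answer"])  # Onyx is a workplace search tool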
def match_quotes_to_docs(
    quotes: list[str],
    docs: list[LlmDoc] | list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> OnyxQuotes:
    onyx_quotes: list[OnyxQuote] = []
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)

        for doc in docs:
            if not doc.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(doc.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in doc.source_links.items():
                # Should always find one because offset is at least 0 and there
                # must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    break

            onyx_quotes.append(
                OnyxQuote(
                    quote=quote,
                    document_id=doc.document_id,
                    link=curr_link,
                    source_type=doc.source_type,
                    semantic_identifier=doc.semantic_identifier,
                    blurb=doc.blurb,
                )
            )
            break

    return OnyxQuotes(quotes=onyx_quotes)
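The fuzzy branch of match_quotes_to_docs uses the third-party regex package's fuzzy-matching syntax, where {e<=N} allows up to N single-character errors (insertions, deletions, or substitutions) in the preceding group. A stand-alone example of that mechanism; the 5% error budget and the sample strings are made up for illustration (QUOTE_ALLOWED_ERROR_PERCENT is configured elsewhere):

import math
import re

import regex

quote = "the quick brown fox"
chunk = "Text from a document: the quick brwn fox jumps over the lazy dog."

max_error_percent = 0.05
max_edits = math.ceil(len(quote) * max_error_percent)  # 1 edit allowed for this quote

pattern = r"(" + re.escape(quote) + r"){e<=" + str(max_edits) + r"}"
found = regex.search(pattern, chunk)
print(found.span() if found else None)  # a (start, end) span despite the missing 'o' in "brwn"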
def separate_answer_quotes(
    answer_raw: str, is_json_prompt: bool = False
) -> tuple[Optional[str], Optional[list[str]]]:
    """Takes in a raw model output and pulls out the answer and the quotes sections."""
    if is_json_prompt:
        model_raw_json = _extract_answer_json(answer_raw)
        return _extract_answer_quotes_json(model_raw_json)

    return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))


def _process_answer(
    answer_raw: str,
    docs: list[LlmDoc],
    is_json_prompt: bool = True,
) -> tuple[OnyxAnswer, OnyxQuotes]:
    """Used (1) in the non-streaming case to process the model output
    into an Answer and Quotes AND (2) after the complete streaming response
    has been received to process the model output into an Answer and Quotes."""
    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
    if not answer:
        logger.debug("No answer extracted from raw output")
        return OnyxAnswer(answer=None), OnyxQuotes(quotes=[])

    logger.notice(f"Answer: {answer}")
    if not quote_strings:
        logger.debug("No quotes extracted from raw output")
        return OnyxAnswer(answer=answer), OnyxQuotes(quotes=[])
    logger.debug(f"All quotes (including unmatched): {quote_strings}")
    quotes = match_quotes_to_docs(quote_strings, docs)
    logger.debug(f"Final quotes: {quotes}")

    return OnyxAnswer(answer=answer), quotes


def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token.
    # This does not work if it's an escaped escape sign before the ", but this is rare and not worth handling.
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False


def _extract_quotes_from_completed_token_stream(
    model_output: str, context_docs: list[LlmDoc], is_json_prompt: bool = True
) -> OnyxQuotes:
    answer, quotes = _process_answer(model_output, context_docs, is_json_prompt)
    if answer:
        logger.notice(answer)
    elif model_output:
        logger.warning("Answer extraction from model output failed.")

    return quotes


class QuotesProcessor:
    def __init__(
        self,
        context_docs: list[LlmDoc],
        is_json_prompt: bool = True,
    ):
        self.context_docs = context_docs
        self.is_json_prompt = is_json_prompt

        self.found_answer_start = False if is_json_prompt else True
        self.found_answer_end = False
        self.hold_quote = ""
        self.model_output = ""
        self.hold = ""

    def process_token(
        self, token: str | None
    ) -> Generator[OnyxAnswerPiece | OnyxQuotes, None, None]:
        # None -> end of stream
        if token is None:
            if self.model_output:
                yield _extract_quotes_from_completed_token_stream(
                    model_output=self.model_output,
                    context_docs=self.context_docs,
                    is_json_prompt=self.is_json_prompt,
                )
            return

        model_previous = self.model_output
        self.model_output += token
        if not self.found_answer_start:
            m = answer_pattern.search(self.model_output)
            if m:
                self.found_answer_start = True

                # Prevent heavy cases of hallucinations
                if self.is_json_prompt and len(self.model_output) > 400:
                    self.found_answer_end = True
                    logger.warning("LLM did not produce json as prompted")
                    logger.debug(f"Model output thus far: {self.model_output}")
                    return

                remaining = self.model_output[m.end() :]

                # Look for an unescaped quote, which means the answer is entirely contained
                # in this token e.g. if the token is `{"answer": "blah", "qu`
                quote_indices = [i for i, char in enumerate(remaining) if char == '"']
                for quote_idx in quote_indices:
                    # Check if the quote is escaped by counting backslashes before it
                    num_backslashes = 0
                    pos = quote_idx - 1
                    while pos >= 0 and remaining[pos] == "\\":
                        num_backslashes += 1
                        pos -= 1
                    # If there is an even number of backslashes, the quote is not escaped
                    if num_backslashes % 2 == 0:
                        yield OnyxAnswerPiece(answer_piece=remaining[:quote_idx])
                        return

                # If no unescaped quote is found, yield the remaining string
                if len(remaining) > 0:
                    yield OnyxAnswerPiece(answer_piece=remaining)
                return

        if self.found_answer_start and not self.found_answer_end:
            if self.is_json_prompt and _stream_json_answer_end(model_previous, token):
                self.found_answer_end = True

                if token:
                    try:
                        answer_token_section = token.index('"')
                        yield OnyxAnswerPiece(
                            answer_piece=self.hold_quote + token[:answer_token_section]
                        )
                    except ValueError:
                        logger.error("Quotation mark not found in token")
                        yield OnyxAnswerPiece(answer_piece=self.hold_quote + token)
                yield OnyxAnswerPiece(answer_piece=None)
                return

            elif not self.is_json_prompt:
                quote_pat = f"\n{QUOTE_PAT}"
                quote_loose = f"\n{quote_pat[:-1]}\n"
                quote_pat_full = f"\n{quote_pat}"

                if (
                    quote_pat in self.hold_quote + token
                    or quote_loose in self.hold_quote + token
                ):
                    self.found_answer_end = True
                    yield OnyxAnswerPiece(answer_piece=None)
                    return
                if self.hold_quote + token in quote_pat_full:
                    self.hold_quote += token
                    return

            yield OnyxAnswerPiece(answer_piece=self.hold_quote + token)
            self.hold_quote = ""
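The streaming path in QuotesProcessor hinges on two small pieces: answer_pattern locates where the JSON "answer" value begins, and the backslash-counting loop decides whether a double quote actually terminates that value or is escaped inside it. A stand-alone sketch of both checks on an invented partial model output:

import re

answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)

partial_output = '{"answer": "Use the \\"Connectors\\" page", "quotes": ['

m = answer_pattern.search(partial_output)
remaining = partial_output[m.end() :]


def first_unescaped_quote(text: str) -> int | None:
    for i, char in enumerate(text):
        if char != '"':
            continue
        # Count the backslashes immediately before this quote.
        backslashes = 0
        pos = i - 1
        while pos >= 0 and text[pos] == "\\":
            backslashes += 1
            pos -= 1
        if backslashes % 2 == 0:  # even -> the quote is not escaped
            return i
    return None


end = first_unescaped_quote(remaining)
print(remaining[:end])  # Use the \"Connectors\" page  (the answer text, with its escapes intact)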
backend/onyx/chat/stream_processing/utils.py (new file, +23 lines)
@@ -0,0 +1,23 @@
from collections.abc import Sequence

from pydantic import BaseModel

from onyx.chat.models import LlmDoc
from onyx.context.search.models import InferenceChunk


class DocumentIdOrderMapping(BaseModel):
    order_mapping: dict[str, int]


def map_document_id_order(
    chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> DocumentIdOrderMapping:
    order_mapping = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
        if chunk.document_id not in order_mapping:
            order_mapping[chunk.document_id] = current
            current += 1

    return DocumentIdOrderMapping(order_mapping=order_mapping)
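map_document_id_order reads only .document_id from each chunk, assigning every distinct document its first-seen position (1-indexed by default). A quick sketch with a duck-typed stand-in object, assuming the function above is in scope; a real call would pass InferenceChunk or LlmDoc instances:

from dataclasses import dataclass


@dataclass
class FakeChunk:
    # Stand-in for InferenceChunk/LlmDoc; only document_id is read by the function.
    document_id: str


chunks = [FakeChunk("doc_a"), FakeChunk("doc_b"), FakeChunk("doc_a"), FakeChunk("doc_c")]
mapping = map_document_id_order(chunks)  # type: ignore[arg-type]
print(mapping.order_mapping)  # {'doc_a': 1, 'doc_b': 2, 'doc_c': 3}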