Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-03-26 17:51:54 +01:00)

Prevent too many tokens to GPT (#245)

commit 2a339ec34b
parent d53ce3bda1

@@ -4,8 +4,9 @@ from collections.abc import Callable
 from danswer.chunking.models import IndexChunk
 from danswer.configs.app_configs import BLURB_LENGTH
-from danswer.configs.app_configs import CHUNK_OVERLAP
+from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP
 from danswer.configs.app_configs import CHUNK_SIZE
+from danswer.configs.app_configs import CHUNK_WORD_OVERLAP
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.text_processing import shared_precompare_cleanup
 
@@ -47,45 +48,83 @@ def chunk_large_section(
     document: Document,
     start_chunk_id: int,
     chunk_size: int = CHUNK_SIZE,
-    word_overlap: int = CHUNK_OVERLAP,
+    word_overlap: int = CHUNK_WORD_OVERLAP,
     blurb_len: int = BLURB_LENGTH,
+    chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP,
 ) -> list[IndexChunk]:
+    """Split large sections into multiple chunks with the final chunk having as much previous overlap as possible.
+    Backtracks word_overlap words, delimited by whitespace, backtrack up to chunk_overflow_max characters max
+    When chunk is finished in forward direction, attempt to finish the word, but only up to chunk_overflow_max
+
+    Some details:
+        - Backtracking (overlap) => finish current word by backtracking + an additional (word_overlap - 1) words
+        - Continuation chunks start with a space generally unless overflow limit is hit
+        - Chunks end with a space generally unless overflow limit is hit
+    """
     section_text = section.text
     blurb = extract_blurb(section_text, blurb_len)
     char_count = len(section_text)
     chunk_strs: list[str] = []
 
+    # start_pos is the actual start of the chunk not including the backtracking overlap
+    # segment_start_pos counts backwards to include overlap from previous chunk
     start_pos = segment_start_pos = 0
     while start_pos < char_count:
+        back_overflow_chars = 0
+        forward_overflow_chars = 0
         back_count_words = 0
         end_pos = segment_end_pos = min(start_pos + chunk_size, char_count)
-        while not section_text[segment_end_pos - 1].isspace():
-            if segment_end_pos >= char_count:
+
+        # Forward overlap to attempt to finish the current word
+        while forward_overflow_chars < chunk_overflow_max:
+            if (
+                segment_end_pos >= char_count
+                or section_text[segment_end_pos - 1].isspace()
+            ):
                 break
             segment_end_pos += 1
-        while back_count_words <= word_overlap:
+            forward_overflow_chars += 1
+
+        # Backwards overlap counting up to word_overlap words (whitespace delineated) or chunk_overflow_max chars
+        # Counts back by finishing current word by backtracking + an additional (word_overlap - 1) words
+        # If starts on a space, it considers finishing the current word as done
+        while back_overflow_chars < chunk_overflow_max:
             if segment_start_pos == 0:
                 break
+            # no -1 offset here because we want to include prepended space to be clear it's a continuation
             if section_text[segment_start_pos].isspace():
-                back_count_words += 1
+                if back_count_words > word_overlap:
+                    break
+                back_count_words += 1
             segment_start_pos -= 1
-        if segment_start_pos != 0:
-            segment_start_pos += 2
+            back_overflow_chars += 1
 
+        # Extract chunk from section text based on the pointers from above
         chunk_str = section_text[segment_start_pos:segment_end_pos]
         if chunk_str[-1].isspace():
             chunk_str = chunk_str[:-1]
         chunk_strs.append(chunk_str)
 
+        # Move pointers to next section, not counting overlaps forward or backward
         start_pos = segment_start_pos = end_pos
 
+    # Last chunk should be as long as possible, overlap favored over tiny chunk with no context
     if len(chunk_strs) > 1:
         chunk_strs.pop()
         back_count_words = 0
+        back_overflow_chars = 0
+        # Backcount chunk size number of characters then
+        # add in the backcounting overlap like with every other previous chunk
         start_pos = char_count - chunk_size
-        while back_count_words <= word_overlap:
+        while back_overflow_chars < chunk_overflow_max:
            if start_pos == 0:
                 break
             if section_text[start_pos].isspace():
+                if back_count_words > word_overlap:
+                    break
                 back_count_words += 1
             start_pos -= 1
-        chunk_strs.append(section_text[start_pos + 2 :])
+            back_overflow_chars += 1
+        chunk_strs.append(section_text[start_pos:])
 
     chunks = []
     for chunk_ind, chunk_str in enumerate(chunk_strs):
 
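The backward-overlap rule above is the heart of the change: back up at most word_overlap whitespace-delimited words, but never more than chunk_overflow_max characters. A standalone sketch of that rule, illustrative only and not part of the commit (overlap_start is a hypothetical helper name):

    # Hypothetical helper mirroring the committed backward-overlap loop.
    def overlap_start(
        text: str, pos: int, word_overlap: int = 5, chunk_overflow_max: int = 50
    ) -> int:
        """Back up from `pos` by up to `word_overlap` whitespace-delimited
        words, but never more than `chunk_overflow_max` characters."""
        back_count_words = 0
        back_overflow_chars = 0
        while back_overflow_chars < chunk_overflow_max:
            if pos == 0:
                break
            if text[pos].isspace():
                if back_count_words > word_overlap:
                    break
                back_count_words += 1
            pos -= 1
            back_overflow_chars += 1
        return pos

    text = "the quick brown fox jumps over the lazy dog " * 10
    # Prints the overlap prepended to a continuation chunk starting at 200:
    # a handful of trailing words, capped at 50 characters.
    print(text[overlap_start(text, 200):200])
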
@@ -105,7 +144,7 @@ def chunk_large_section(
 def chunk_document(
     document: Document,
     chunk_size: int = CHUNK_SIZE,
-    subsection_overlap: int = CHUNK_OVERLAP,
+    subsection_overlap: int = CHUNK_WORD_OVERLAP,
     blurb_len: int = BLURB_LENGTH,
 ) -> list[IndexChunk]:
     chunks: list[IndexChunk] = []
 
@@ -115,9 +115,11 @@ ENABLE_MINI_CHUNK = False
 # Mini chunks for fine-grained embedding, calculated as 128 tokens for 4 additional vectors for 512 chunk size above
 # Not rounded down to not lose any context in full chunk.
 MINI_CHUNK_SIZE = 512
-# Each chunk includes an additional 5 words from previous chunk
-# in extreme cases, may cause some words at the end to be truncated by embedding model
-CHUNK_OVERLAP = 5
+# Each chunk includes an additional CHUNK_WORD_OVERLAP words from previous chunk
+CHUNK_WORD_OVERLAP = 5
+# When trying to finish the last word in the chunk or counting back CHUNK_WORD_OVERLAP words,
+# this is the max number of characters allowed in either direction
+CHUNK_MAX_CHAR_OVERLAP = 50
 
 
 #####
 
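Taken together with the base chunk size (512 characters, per the "512 chunk size above" comment), the new caps bound how much any chunk can grow. A back-of-the-envelope check, illustrative only and not part of the commit:

    # Worst-case chunk length under the new limits: up to
    # CHUNK_MAX_CHAR_OVERLAP extra characters forward (finishing the last
    # word) plus the same backward (the word overlap).
    CHUNK_SIZE = 512            # assumed from the mini-chunk comment above
    CHUNK_MAX_CHAR_OVERLAP = 50

    worst_case_chars = CHUNK_SIZE + 2 * CHUNK_MAX_CHAR_OVERLAP
    assert worst_case_chars == 612
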
@@ -80,7 +80,7 @@ class WebConnector(LoadConnector):
                 continue
             visited_links.add(current_url)
 
-            logger.info(f"Indexing {current_url}")
+            logger.info(f"Visiting {current_url}")
 
             try:
                 current_visit_time = datetime.now().strftime("%B %d, %Y, %H:%M:%S")
 
@@ -2,12 +2,14 @@ import json
 from abc import ABC
 from collections.abc import Callable
 from collections.abc import Generator
+from copy import copy
 from functools import wraps
 from typing import Any
 from typing import cast
 from typing import TypeVar
 
 import openai
+import tiktoken
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import INCLUDE_METADATA
 from danswer.configs.app_configs import OPENAI_API_KEY
 
@@ -89,6 +91,24 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
     return cast(F, wrapped_call)
 
 
+def _tiktoken_trim_chunks(
+    chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512
+) -> list[InferenceChunk]:
+    """Edit chunks that have too high token count. Generally due to parsing issues or
+    characters from another language that are 1 char = 1 token
+    Trimming by tokens leads to information loss but currently no better way of handling
+    """
+    encoder = tiktoken.encoding_for_model(model_version)
+    new_chunks = copy(chunks)
+    for ind, chunk in enumerate(new_chunks):
+        tokens = encoder.encode(chunk.content)
+        if len(tokens) > max_chunk_toks:
+            new_chunk = copy(chunk)
+            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
+            new_chunks[ind] = new_chunk
+    return new_chunks
+
+
 # used to check if the QAModel is an OpenAI model
 class OpenAIQAModel(QAModel, ABC):
     pass
 
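The trim step relies only on tiktoken's encoding_for_model, encode, and decode. A minimal standalone sketch of the same cap-and-decode pattern, not part of the commit (the model name is just an example):

    import tiktoken

    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")
    content = "word " * 2000  # stand-in for an oversized chunk
    tokens = encoder.encode(content)
    if len(tokens) > 512:
        # Lossy, but bounds how many tokens any one chunk contributes
        # to the prompt.
        content = encoder.decode(tokens[:512])
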
@@ -123,6 +143,8 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
 
@@ -151,6 +173,8 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
 
@@ -215,6 +239,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         query: str,
         context_docs: list[InferenceChunk],
     ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         messages = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
 
@@ -251,6 +277,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         messages = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
 
@@ -58,7 +58,7 @@ def semantic_reranking(
 
     logger.debug(f"Reranked similarity scores: {ranked_sim_scores}")
 
-    return ranked_chunks
+    return list(ranked_chunks)
 
 
 @log_function_time()
 
@@ -37,6 +37,7 @@ sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
 tensorflow==2.12.0
+tiktoken==0.4.0
 transformers==4.30.1
 typesense==0.15.1
 uvicorn==0.21.1
 
@@ -19,6 +19,9 @@ WAR_AND_PEACE = (
 class TestDocumentChunking(unittest.TestCase):
     def setUp(self) -> None:
         self.large_section = Section(text=WAR_AND_PEACE, link="https://www.test.com/")
+        self.large_unbroken_section = Section(
+            text="0123456789" * 40, link="https://www.test.com/"
+        )
         self.document = Document(
             id="test_document",
             sections=[
 
@@ -52,18 +55,37 @@ class TestDocumentChunking(unittest.TestCase):
             chunk_size=100,
             word_overlap=3,
         )
-        self.assertEqual(len(chunks), 5)
-        self.assertEqual(chunks[0].content, WAR_AND_PEACE[:99])
+        contents = [chunk.content for chunk in chunks]
+
+        self.assertEqual(len(contents), 5)
+        self.assertEqual(contents[0], WAR_AND_PEACE[:100])
         self.assertEqual(
-            chunks[-2].content, WAR_AND_PEACE[-176:-63]
+            contents[-2], WAR_AND_PEACE[-172:-62]
         )  # slightly longer than 100 due to overlap
         self.assertEqual(
-            chunks[-1].content, WAR_AND_PEACE[-121:]
+            contents[-1], WAR_AND_PEACE[-125:]
         )  # large overlap with second to last segment
         self.assertFalse(chunks[0].section_continuation)
         self.assertTrue(chunks[1].section_continuation)
         self.assertTrue(chunks[-1].section_continuation)
 
+    def test_chunk_max_overflow(self) -> None:
+        chunks = chunk_large_section(
+            section=self.large_unbroken_section,
+            document=self.document,
+            start_chunk_id=5,
+            chunk_size=100,
+            word_overlap=3,
+        )
+        contents = [chunk.content for chunk in chunks]
+
+        self.assertEqual(len(contents), 4)
+        self.assertEqual(contents[0], self.large_unbroken_section.text[:150])
+        self.assertEqual(contents[1], self.large_unbroken_section.text[50:250])
+        self.assertEqual(contents[2], self.large_unbroken_section.text[150:350])
+        # Last chunk counts back from the end, full chunk size (100) + 50 overlap => 400 - 150 = 250
+        self.assertEqual(contents[3], self.large_unbroken_section.text[250:])
+
     def test_chunk_document(self) -> None:
         chunks = chunk_document(self.document, chunk_size=100, subsection_overlap=3)
         self.assertEqual(len(chunks), 8)
 
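Why test_chunk_max_overflow expects those exact boundaries: the 400-character section has no whitespace, so every forward and backward overlap runs to the 50-character cap. A worked check, not part of the commit:

    # "0123456789" * 40 -> 400 chars, chunk_size 100, cap 50, no whitespace.
    chunk_size, cap, text_len = 100, 50, 400

    # chunk 0: [0, 100) + 50 forward          -> text[0:150]
    # chunk 1: [100, 200) + 50 back + 50 fwd  -> text[50:250]
    # chunk 2: [200, 300) + 50 back + 50 fwd  -> text[150:350]
    bounds = [(0, 150), (50, 250), (150, 350)]
    for i, (start, end) in enumerate(bounds):
        expected = (
            max(0, i * chunk_size - cap),
            min(text_len, (i + 1) * chunk_size + cap),
        )
        assert expected == (start, end)

    # The last chunk is rebuilt from the end: chunk_size + cap = 150 chars
    # back, i.e. text[250:], matching contents[3] in the test.
    assert text_len - (chunk_size + cap) == 250
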