Prevent sending too many tokens to GPT (#245)

Yuhong Sun 2023-07-28 16:00:26 -07:00 committed by GitHub
parent d53ce3bda1
commit 2a339ec34b
7 changed files with 113 additions and 21 deletions

View File

@@ -4,8 +4,9 @@ from collections.abc import Callable
 from danswer.chunking.models import IndexChunk
 from danswer.configs.app_configs import BLURB_LENGTH
-from danswer.configs.app_configs import CHUNK_OVERLAP
+from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP
 from danswer.configs.app_configs import CHUNK_SIZE
+from danswer.configs.app_configs import CHUNK_WORD_OVERLAP
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.utils.text_processing import shared_precompare_cleanup
@@ -47,45 +48,83 @@ def chunk_large_section(
     document: Document,
     start_chunk_id: int,
     chunk_size: int = CHUNK_SIZE,
-    word_overlap: int = CHUNK_OVERLAP,
+    word_overlap: int = CHUNK_WORD_OVERLAP,
     blurb_len: int = BLURB_LENGTH,
+    chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP,
 ) -> list[IndexChunk]:
+    """Split large sections into multiple chunks with the final chunk having as much previous overlap as possible.
+    Backtracks word_overlap words, delimited by whitespace, backtrack up to chunk_overflow_max characters max
+    When chunk is finished in forward direction, attempt to finish the word, but only up to chunk_overflow_max
+
+    Some details:
+        - Backtracking (overlap) => finish current word by backtracking + an additional (word_overlap - 1) words
+        - Continuation chunks start with a space generally unless overflow limit is hit
+        - Chunks end with a space generally unless overflow limit is hit
+    """
     section_text = section.text
     blurb = extract_blurb(section_text, blurb_len)
     char_count = len(section_text)
     chunk_strs: list[str] = []
+
+    # start_pos is the actual start of the chunk not including the backtracking overlap
+    # segment_start_pos counts backwards to include overlap from previous chunk
     start_pos = segment_start_pos = 0
     while start_pos < char_count:
+        back_overflow_chars = 0
+        forward_overflow_chars = 0
         back_count_words = 0
         end_pos = segment_end_pos = min(start_pos + chunk_size, char_count)
-        while not section_text[segment_end_pos - 1].isspace():
-            if segment_end_pos >= char_count:
+        # Forward overlap to attempt to finish the current word
+        while forward_overflow_chars < chunk_overflow_max:
+            if (
+                segment_end_pos >= char_count
+                or section_text[segment_end_pos - 1].isspace()
+            ):
                 break
             segment_end_pos += 1
-        while back_count_words <= word_overlap:
+            forward_overflow_chars += 1
+        # Backwards overlap counting up to word_overlap words (whitespace delineated) or chunk_overflow_max chars
+        # Counts back by finishing current word by backtracking + an additional (word_overlap - 1) words
+        # If starts on a space, it considers finishing the current word as done
+        while back_overflow_chars < chunk_overflow_max:
             if segment_start_pos == 0:
                 break
+            # no -1 offset here because we want to include prepended space to be clear it's a continuation
             if section_text[segment_start_pos].isspace():
-                back_count_words += 1
+                if back_count_words > word_overlap:
+                    break
+                back_count_words += 1
             segment_start_pos -= 1
-        if segment_start_pos != 0:
-            segment_start_pos += 2
+            back_overflow_chars += 1
+
+        # Extract chunk from section text based on the pointers from above
         chunk_str = section_text[segment_start_pos:segment_end_pos]
+        if chunk_str[-1].isspace():
+            chunk_str = chunk_str[:-1]
         chunk_strs.append(chunk_str)
+
+        # Move pointers to next section, not counting overlaps forward or backward
         start_pos = segment_start_pos = end_pos
+
     # Last chunk should be as long as possible, overlap favored over tiny chunk with no context
     if len(chunk_strs) > 1:
         chunk_strs.pop()
         back_count_words = 0
+        back_overflow_chars = 0
+        # Backcount chunk size number of characters then
+        # add in the backcounting overlap like with every other previous chunk
         start_pos = char_count - chunk_size
-        while back_count_words <= word_overlap:
+        while back_overflow_chars < chunk_overflow_max:
            if start_pos == 0:
                break
            if section_text[start_pos].isspace():
+                if back_count_words > word_overlap:
+                    break
                back_count_words += 1
            start_pos -= 1
-        chunk_strs.append(section_text[start_pos + 2 :])
+            back_overflow_chars += 1
+        chunk_strs.append(section_text[start_pos:])

     chunks = []
     for chunk_ind, chunk_str in enumerate(chunk_strs):
@@ -105,7 +144,7 @@ def chunk_large_section(
 def chunk_document(
     document: Document,
     chunk_size: int = CHUNK_SIZE,
-    subsection_overlap: int = CHUNK_OVERLAP,
+    subsection_overlap: int = CHUNK_WORD_OVERLAP,
     blurb_len: int = BLURB_LENGTH,
 ) -> list[IndexChunk]:
     chunks: list[IndexChunk] = []
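
For orientation, a minimal usage sketch of the reworked function, mirroring the test_chunk_max_overflow unit test further down (the Document fixture is abbreviated; any other required fields are omitted):

    from danswer.chunking.chunk import chunk_large_section
    from danswer.connectors.models import Document, Section

    # A 400-char section with no whitespace, so both the forward word-finishing
    # overflow and the backtracking overlap hit the 50-char cap every time
    section = Section(text="0123456789" * 40, link="https://www.test.com/")
    document = Document(id="example", sections=[section])  # abbreviated fixture

    chunks = chunk_large_section(
        section=section,
        document=document,
        start_chunk_id=0,
        chunk_size=100,
        word_overlap=3,
    )
    # Expected chunk contents: text[0:150], text[50:250], text[150:350], text[250:400]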

View File

@@ -115,9 +115,11 @@ ENABLE_MINI_CHUNK = False
 # Mini chunks for fine-grained embedding, calculated as 128 tokens for 4 additional vectors for 512 chunk size above
 # Not rounded down to not lose any context in full chunk.
 MINI_CHUNK_SIZE = 512
-# Each chunk includes an additional 5 words from previous chunk
-# in extreme cases, may cause some words at the end to be truncated by embedding model
-CHUNK_OVERLAP = 5
+# Each chunk includes an additional CHUNK_WORD_OVERLAP words from previous chunk
+CHUNK_WORD_OVERLAP = 5
+# When trying to finish the last word in the chunk or counting back CHUNK_WORD_OVERLAP backwards,
+# This is the max number of characters allowed in either direction
+CHUNK_MAX_CHAR_OVERLAP = 50

 #####
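
A quick worked example of how the two knobs bound chunk length: a chunk may run up to CHUNK_MAX_CHAR_OVERLAP characters past chunk_size to finish its last word, and a continuation chunk backtracks up to CHUNK_MAX_CHAR_OVERLAP characters (or CHUNK_WORD_OVERLAP words, whichever limit is hit first) into the previous chunk. With whitespace-free text both caps are hit every time, so:

    # Illustrative arithmetic only; chunk_size=100 is the unit-test value, not the CHUNK_SIZE default
    CHUNK_MAX_CHAR_OVERLAP = 50
    chunk_size = 100
    worst_case_chunk_chars = chunk_size + 2 * CHUNK_MAX_CHAR_OVERLAP
    print(worst_case_chunk_chars)  # 200, matching the 200-char middle chunks in test_chunk_max_overflow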

View File

@@ -80,7 +80,7 @@ class WebConnector(LoadConnector):
                 continue
             visited_links.add(current_url)

-            logger.info(f"Indexing {current_url}")
+            logger.info(f"Visiting {current_url}")
             try:
                 current_visit_time = datetime.now().strftime("%B %d, %Y, %H:%M:%S")

View File

@@ -2,12 +2,14 @@ import json
 from abc import ABC
 from collections.abc import Callable
 from collections.abc import Generator
+from copy import copy
 from functools import wraps
 from typing import Any
 from typing import cast
 from typing import TypeVar

 import openai
+import tiktoken

 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import INCLUDE_METADATA
 from danswer.configs.app_configs import OPENAI_API_KEY
@@ -89,6 +91,24 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
     return cast(F, wrapped_call)


+def _tiktoken_trim_chunks(
+    chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512
+) -> list[InferenceChunk]:
+    """Edit chunks that have too high token count. Generally due to parsing issues or
+    characters from another language that are 1 char = 1 token
+    Trimming by tokens leads to information loss but currently no better way of handling
+    """
+    encoder = tiktoken.encoding_for_model(model_version)
+    new_chunks = copy(chunks)
+    for ind, chunk in enumerate(new_chunks):
+        tokens = encoder.encode(chunk.content)
+        if len(tokens) > max_chunk_toks:
+            new_chunk = copy(chunk)
+            new_chunk.content = encoder.decode(tokens[:max_chunk_toks])
+            new_chunks[ind] = new_chunk
+    return new_chunks
+
+
 # used to check if the QAModel is an OpenAI model
 class OpenAIQAModel(QAModel, ABC):
     pass
@@ -123,6 +143,8 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -151,6 +173,8 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -215,6 +239,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         query: str,
         context_docs: list[InferenceChunk],
     ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         messages = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -251,6 +277,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
+        context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
+
         messages = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
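
The trimming itself is ordinary tiktoken usage; a standalone sketch of the same idea (the model name here is only an example):

    import tiktoken

    encoder = tiktoken.encoding_for_model("gpt-3.5-turbo")  # example model name
    content = "word " * 2000  # stand-in for a chunk that parsed far too large
    tokens = encoder.encode(content)
    if len(tokens) > 512:  # the max_chunk_toks default above
        content = encoder.decode(tokens[:512])
    # content now fits the per-chunk token budget before the prompt is filled

Note the shallow copies in _tiktoken_trim_chunks: the list and each trimmed chunk are copied before mutation, so the caller's InferenceChunk objects are left untouched.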

View File

@@ -58,7 +58,7 @@ def semantic_reranking(
     logger.debug(f"Reranked similarity scores: {ranked_sim_scores}")

-    return ranked_chunks
+    return list(ranked_chunks)


 @log_function_time()
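
Presumably ranked_chunks comes out of a sort-and-unzip as a tuple rather than a list, so the list() call makes the returned value match the declared return type. A self-contained illustration of that pattern (the data here is made up):

    scores = [0.2, 0.9, 0.5]
    chunks = ["a", "b", "c"]
    ranked_sim_scores, ranked_chunks = zip(*sorted(zip(scores, chunks), reverse=True))
    # zip(*...) yields tuples, so ranked_chunks is ("b", "c", "a"), not a list
    assert list(ranked_chunks) == ["b", "c", "a"]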

View File

@@ -37,6 +37,7 @@ sentence-transformers==2.2.2
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.12
 tensorflow==2.12.0
+tiktoken==0.4.0
 transformers==4.30.1
 typesense==0.15.1
 uvicorn==0.21.1

View File

@@ -19,6 +19,9 @@ WAR_AND_PEACE = (
 class TestDocumentChunking(unittest.TestCase):
     def setUp(self) -> None:
         self.large_section = Section(text=WAR_AND_PEACE, link="https://www.test.com/")
+        self.large_unbroken_section = Section(
+            text="0123456789" * 40, link="https://www.test.com/"
+        )
         self.document = Document(
             id="test_document",
             sections=[
@@ -52,18 +55,37 @@ class TestDocumentChunking(unittest.TestCase):
             chunk_size=100,
             word_overlap=3,
         )
-        self.assertEqual(len(chunks), 5)
-        self.assertEqual(chunks[0].content, WAR_AND_PEACE[:99])
+        contents = [chunk.content for chunk in chunks]
+        self.assertEqual(len(contents), 5)
+        self.assertEqual(contents[0], WAR_AND_PEACE[:100])
         self.assertEqual(
-            chunks[-2].content, WAR_AND_PEACE[-176:-63]
+            contents[-2], WAR_AND_PEACE[-172:-62]
         )  # slightly longer than 100 due to overlap
         self.assertEqual(
-            chunks[-1].content, WAR_AND_PEACE[-121:]
+            contents[-1], WAR_AND_PEACE[-125:]
         )  # large overlap with second to last segment
         self.assertFalse(chunks[0].section_continuation)
         self.assertTrue(chunks[1].section_continuation)
         self.assertTrue(chunks[-1].section_continuation)

+    def test_chunk_max_overflow(self) -> None:
+        chunks = chunk_large_section(
+            section=self.large_unbroken_section,
+            document=self.document,
+            start_chunk_id=5,
+            chunk_size=100,
+            word_overlap=3,
+        )
+        contents = [chunk.content for chunk in chunks]
+        self.assertEqual(len(contents), 4)
+        self.assertEqual(contents[0], self.large_unbroken_section.text[:150])
+        self.assertEqual(contents[1], self.large_unbroken_section.text[50:250])
+        self.assertEqual(contents[2], self.large_unbroken_section.text[150:350])
+        # Last chunk counts back from the end, full chunk size (100) + 50 overlap => 400 - 150 = 250
+        self.assertEqual(contents[3], self.large_unbroken_section.text[250:])
+
     def test_chunk_document(self) -> None:
         chunks = chunk_document(self.document, chunk_size=100, subsection_overlap=3)
         self.assertEqual(len(chunks), 8)
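
The slice boundaries asserted in test_chunk_max_overflow can be sanity-checked by hand. A small hypothetical helper (expected_bounds is not part of the codebase) reproduces them for whitespace-free text, where every overlap hits the character cap:

    def expected_bounds(length: int, size: int, overflow: int) -> list[tuple[int, int]]:
        """Chunk boundaries for text with no whitespace: the forward
        word-finishing overflow and the backtracking overlap both hit
        the character cap every time."""
        bounds = []
        start = 0
        while start < length:
            end = min(start + size, length)
            fwd = min(overflow, length - end)     # finish-the-word overflow
            back = 0 if start == 0 else overflow  # backtracking overlap
            bounds.append((start - back, end + fwd))
            start = end
        if len(bounds) > 1:
            # the last chunk is recomputed to be as long as possible
            bounds[-1] = (length - size - overflow, length)
        return bounds

    assert expected_bounds(400, 100, 50) == [(0, 150), (50, 250), (150, 350), (250, 400)]

That is four chunks of 150, 200, 200, and 150 characters for the 400-char fixture.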