From 2a339ec34b73925edc727db0e8b07fa14aa0450e Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 28 Jul 2023 16:00:26 -0700 Subject: [PATCH] Prevent too many tokens to GPT (#245) --- backend/danswer/chunking/chunk.py | 63 +++++++++++++++---- backend/danswer/configs/app_configs.py | 8 ++- backend/danswer/connectors/web/connector.py | 2 +- backend/danswer/direct_qa/open_ai.py | 28 +++++++++ backend/danswer/search/semantic_search.py | 2 +- backend/requirements/default.txt | 1 + .../unit/qa_service/chunking/test_chunk.py | 30 +++++++-- 7 files changed, 113 insertions(+), 21 deletions(-) diff --git a/backend/danswer/chunking/chunk.py b/backend/danswer/chunking/chunk.py index 1dda4b9da..faa3196b5 100644 --- a/backend/danswer/chunking/chunk.py +++ b/backend/danswer/chunking/chunk.py @@ -4,8 +4,9 @@ from collections.abc import Callable from danswer.chunking.models import IndexChunk from danswer.configs.app_configs import BLURB_LENGTH -from danswer.configs.app_configs import CHUNK_OVERLAP +from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP from danswer.configs.app_configs import CHUNK_SIZE +from danswer.configs.app_configs import CHUNK_WORD_OVERLAP from danswer.connectors.models import Document from danswer.connectors.models import Section from danswer.utils.text_processing import shared_precompare_cleanup @@ -47,45 +48,83 @@ def chunk_large_section( document: Document, start_chunk_id: int, chunk_size: int = CHUNK_SIZE, - word_overlap: int = CHUNK_OVERLAP, + word_overlap: int = CHUNK_WORD_OVERLAP, blurb_len: int = BLURB_LENGTH, + chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP, ) -> list[IndexChunk]: + """Split large sections into multiple chunks with the final chunk having as much previous overlap as possible. + Backtracks word_overlap words, delimited by whitespace, backtrack up to chunk_overflow_max characters max + When chunk is finished in forward direction, attempt to finish the word, but only up to chunk_overflow_max + + Some details: + - Backtracking (overlap) => finish current word by backtracking + an additional (word_overlap - 1) words + - Continuation chunks start with a space generally unless overflow limit is hit + - Chunks end with a space generally unless overflow limit is hit + """ section_text = section.text blurb = extract_blurb(section_text, blurb_len) char_count = len(section_text) chunk_strs: list[str] = [] + + # start_pos is the actual start of the chunk not including the backtracking overlap + # segment_start_pos counts backwards to include overlap from previous chunk start_pos = segment_start_pos = 0 while start_pos < char_count: + back_overflow_chars = 0 + forward_overflow_chars = 0 back_count_words = 0 end_pos = segment_end_pos = min(start_pos + chunk_size, char_count) - while not section_text[segment_end_pos - 1].isspace(): - if segment_end_pos >= char_count: + + # Forward overlap to attempt to finish the current word + while forward_overflow_chars < chunk_overflow_max: + if ( + segment_end_pos >= char_count + or section_text[segment_end_pos - 1].isspace() + ): break segment_end_pos += 1 - while back_count_words <= word_overlap: + forward_overflow_chars += 1 + + # Backwards overlap counting up to word_overlap words (whitespace delineated) or chunk_overflow_max chars + # Counts back by finishing current word by backtracking + an additional (word_overlap - 1) words + # If starts on a space, it considers finishing the current word as done + while back_overflow_chars < chunk_overflow_max: if segment_start_pos == 0: break + # no -1 offset here because we want to include prepended space to be clear it's a continuation if section_text[segment_start_pos].isspace(): back_count_words += 1 + if back_count_words > word_overlap: + break + back_count_words += 1 segment_start_pos -= 1 - if segment_start_pos != 0: - segment_start_pos += 2 + back_overflow_chars += 1 + + # Extract chunk from section text based on the pointers from above chunk_str = section_text[segment_start_pos:segment_end_pos] - if chunk_str[-1].isspace(): - chunk_str = chunk_str[:-1] chunk_strs.append(chunk_str) + + # Move pointers to next section, not counting overlaps forward or backward start_pos = segment_start_pos = end_pos # Last chunk should be as long as possible, overlap favored over tiny chunk with no context if len(chunk_strs) > 1: chunk_strs.pop() back_count_words = 0 + back_overflow_chars = 0 + # Backcount chunk size number of characters then + # add in the backcounting overlap like with every other previous chunk start_pos = char_count - chunk_size - while back_count_words <= word_overlap: + while back_overflow_chars < chunk_overflow_max: + if start_pos == 0: + break if section_text[start_pos].isspace(): + if back_count_words > word_overlap: + break back_count_words += 1 start_pos -= 1 - chunk_strs.append(section_text[start_pos + 2 :]) + back_overflow_chars += 1 + chunk_strs.append(section_text[start_pos:]) chunks = [] for chunk_ind, chunk_str in enumerate(chunk_strs): @@ -105,7 +144,7 @@ def chunk_large_section( def chunk_document( document: Document, chunk_size: int = CHUNK_SIZE, - subsection_overlap: int = CHUNK_OVERLAP, + subsection_overlap: int = CHUNK_WORD_OVERLAP, blurb_len: int = BLURB_LENGTH, ) -> list[IndexChunk]: chunks: list[IndexChunk] = [] diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index a77e9209e..8ee836b12 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -115,9 +115,11 @@ ENABLE_MINI_CHUNK = False # Mini chunks for fine-grained embedding, calculated as 128 tokens for 4 additional vectors for 512 chunk size above # Not rounded down to not lose any context in full chunk. MINI_CHUNK_SIZE = 512 -# Each chunk includes an additional 5 words from previous chunk -# in extreme cases, may cause some words at the end to be truncated by embedding model -CHUNK_OVERLAP = 5 +# Each chunk includes an additional CHUNK_WORD_OVERLAP words from previous chunk +CHUNK_WORD_OVERLAP = 5 +# When trying to finish the last word in the chunk or counting back CHUNK_WORD_OVERLAP backwards, +# This is the max number of characters allowed in either direction +CHUNK_MAX_CHAR_OVERLAP = 50 ##### diff --git a/backend/danswer/connectors/web/connector.py b/backend/danswer/connectors/web/connector.py index 44bd53284..38941fb1a 100644 --- a/backend/danswer/connectors/web/connector.py +++ b/backend/danswer/connectors/web/connector.py @@ -80,7 +80,7 @@ class WebConnector(LoadConnector): continue visited_links.add(current_url) - logger.info(f"Indexing {current_url}") + logger.info(f"Visiting {current_url}") try: current_visit_time = datetime.now().strftime("%B %d, %Y, %H:%M:%S") diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py index 87aac726c..47ed4cbeb 100644 --- a/backend/danswer/direct_qa/open_ai.py +++ b/backend/danswer/direct_qa/open_ai.py @@ -2,12 +2,14 @@ import json from abc import ABC from collections.abc import Callable from collections.abc import Generator +from copy import copy from functools import wraps from typing import Any from typing import cast from typing import TypeVar import openai +import tiktoken from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import INCLUDE_METADATA from danswer.configs.app_configs import OPENAI_API_KEY @@ -89,6 +91,24 @@ def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F: return cast(F, wrapped_call) +def _tiktoken_trim_chunks( + chunks: list[InferenceChunk], model_version: str, max_chunk_toks: int = 512 +) -> list[InferenceChunk]: + """Edit chunks that have too high token count. Generally due to parsing issues or + characters from another language that are 1 char = 1 token + Trimming by tokens leads to information loss but currently no better way of handling + """ + encoder = tiktoken.encoding_for_model(model_version) + new_chunks = copy(chunks) + for ind, chunk in enumerate(new_chunks): + tokens = encoder.encode(chunk.content) + if len(tokens) > max_chunk_toks: + new_chunk = copy(chunk) + new_chunk.content = encoder.decode(tokens[:max_chunk_toks]) + new_chunks[ind] = new_chunk + return new_chunks + + # used to check if the QAModel is an OpenAI model class OpenAIQAModel(QAModel, ABC): pass @@ -123,6 +143,8 @@ class OpenAICompletionQA(OpenAIQAModel): def answer_question( self, query: str, context_docs: list[InferenceChunk] ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) + filled_prompt = self.prompt_processor.fill_prompt( query, context_docs, self.include_metadata ) @@ -151,6 +173,8 @@ class OpenAICompletionQA(OpenAIQAModel): def answer_question_stream( self, query: str, context_docs: list[InferenceChunk] ) -> Generator[dict[str, Any] | None, None, None]: + context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) + filled_prompt = self.prompt_processor.fill_prompt( query, context_docs, self.include_metadata ) @@ -215,6 +239,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel): query: str, context_docs: list[InferenceChunk], ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) + messages = self.prompt_processor.fill_prompt( query, context_docs, self.include_metadata ) @@ -251,6 +277,8 @@ class OpenAIChatCompletionQA(OpenAIQAModel): def answer_question_stream( self, query: str, context_docs: list[InferenceChunk] ) -> Generator[dict[str, Any] | None, None, None]: + context_docs = _tiktoken_trim_chunks(context_docs, self.model_version) + messages = self.prompt_processor.fill_prompt( query, context_docs, self.include_metadata ) diff --git a/backend/danswer/search/semantic_search.py b/backend/danswer/search/semantic_search.py index 78eae870f..cc09737b2 100644 --- a/backend/danswer/search/semantic_search.py +++ b/backend/danswer/search/semantic_search.py @@ -58,7 +58,7 @@ def semantic_reranking( logger.debug(f"Reranked similarity scores: {ranked_sim_scores}") - return ranked_chunks + return list(ranked_chunks) @log_function_time() diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 32abb8de8..5e9add21e 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -37,6 +37,7 @@ sentence-transformers==2.2.2 slack-sdk==3.20.2 SQLAlchemy[mypy]==2.0.12 tensorflow==2.12.0 +tiktoken==0.4.0 transformers==4.30.1 typesense==0.15.1 uvicorn==0.21.1 diff --git a/backend/tests/unit/qa_service/chunking/test_chunk.py b/backend/tests/unit/qa_service/chunking/test_chunk.py index 142f7e240..738ba0c1f 100644 --- a/backend/tests/unit/qa_service/chunking/test_chunk.py +++ b/backend/tests/unit/qa_service/chunking/test_chunk.py @@ -19,6 +19,9 @@ WAR_AND_PEACE = ( class TestDocumentChunking(unittest.TestCase): def setUp(self) -> None: self.large_section = Section(text=WAR_AND_PEACE, link="https://www.test.com/") + self.large_unbroken_section = Section( + text="0123456789" * 40, link="https://www.test.com/" + ) self.document = Document( id="test_document", sections=[ @@ -52,18 +55,37 @@ class TestDocumentChunking(unittest.TestCase): chunk_size=100, word_overlap=3, ) - self.assertEqual(len(chunks), 5) - self.assertEqual(chunks[0].content, WAR_AND_PEACE[:99]) + contents = [chunk.content for chunk in chunks] + + self.assertEqual(len(contents), 5) + self.assertEqual(contents[0], WAR_AND_PEACE[:100]) self.assertEqual( - chunks[-2].content, WAR_AND_PEACE[-176:-63] + contents[-2], WAR_AND_PEACE[-172:-62] ) # slightly longer than 100 due to overlap self.assertEqual( - chunks[-1].content, WAR_AND_PEACE[-121:] + contents[-1], WAR_AND_PEACE[-125:] ) # large overlap with second to last segment self.assertFalse(chunks[0].section_continuation) self.assertTrue(chunks[1].section_continuation) self.assertTrue(chunks[-1].section_continuation) + def test_chunk_max_overflow(self) -> None: + chunks = chunk_large_section( + section=self.large_unbroken_section, + document=self.document, + start_chunk_id=5, + chunk_size=100, + word_overlap=3, + ) + contents = [chunk.content for chunk in chunks] + + self.assertEqual(len(contents), 4) + self.assertEqual(contents[0], self.large_unbroken_section.text[:150]) + self.assertEqual(contents[1], self.large_unbroken_section.text[50:250]) + self.assertEqual(contents[2], self.large_unbroken_section.text[150:350]) + # Last chunk counts back from the end, full chunk size (100) + 50 overlap => 400 - 150 = 250 + self.assertEqual(contents[3], self.large_unbroken_section.text[250:]) + def test_chunk_document(self) -> None: chunks = chunk_document(self.document, chunk_size=100, subsection_overlap=3) self.assertEqual(len(chunks), 8)