From f404c4b4482ae8a46d484dbd045e3ad49e66ba5f Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Wed, 18 Sep 2024 23:00:58 -0700 Subject: [PATCH] Move code block default language creation to citation processing (#2501) * move code block default language creation to citaiton processing * add test cases * update copy --- backend/danswer/chat/process_message.py | 1 + .../stream_processing/citation_processing.py | 9 ++ .../test_citation_processing.py | 86 +++++++++++++++++++ 3 files changed, 96 insertions(+) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 26c59ebf6..93ad3bdd3 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -709,6 +709,7 @@ def stream_chat_message_objects( yield FinalUsedContextDocsResponse( final_context_docs=packet.response ) + elif packet.id == IMAGE_GENERATION_RESPONSE_ID: img_generation_response = cast( list[ImageGenerationResponse], packet.response diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index a72fc70a8..f1e548955 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -85,6 +85,15 @@ def extract_citations_from_stream( curr_segment += token llm_out += token + # Handle code blocks without language tags + if "`" in curr_segment: + if curr_segment.endswith("`"): + continue + elif "```" in curr_segment: + piece_that_comes_after = curr_segment.split("```")[1][0] + if piece_that_comes_after == "\n" and in_code_block(llm_out): + curr_segment = curr_segment.replace("```", "```plaintext") + citation_pattern = r"\[(\d+)\]" citations_found = list(re.finditer(citation_pattern, curr_segment)) diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 473ccf245..12e3254d6 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -286,6 +286,92 @@ def process_text( "[[1]](https://0.com) Citation at the beginning. ", ["doc_0"], ), + ( + "Code block without language specification", + [ + "Here's", + " a code block", + ":\n```\nd", + "ef example():\n pass\n", + "```\n", + "End of code.", + ], + "Here's a code block:\n```plaintext\ndef example():\n pass\n```\nEnd of code.", + [], + ), + ( + "Code block with language specification", + [ + "Here's a Python code block:\n", + "```", + "python", + "\n", + "def greet", + "(name):", + "\n ", + "print", + "(f'Hello, ", + "{name}!')", + "\n", + "greet('World')", + "\n```\n", + "This function ", + "greets the user.", + ], + "Here's a Python code block:\n```python\ndef greet(name):\n " + "print(f'Hello, {name}!')\ngreet('World')\n```\nThis function greets the user.", + [], + ), + ( + "Multiple code blocks with different languages", + [ + "JavaScript example:\n", + "```", + "javascript", + "\n", + "console", + ".", + "log", + "('Hello, World!');", + "\n```\n", + "Python example", + ":\n", + "```", + "python", + "\n", + "print", + "('Hello, World!')", + "\n```\n", + "Both print greetings", + ".", + ], + "JavaScript example:\n```javascript\nconsole.log('Hello, World!');\n" + "```\nPython example:\n```python\nprint('Hello, World!')\n" + "```\nBoth print greetings.", + [], + ), + ( + "Code block with text block", + [ + "Here's a code block with a text block:\n", + "```\n", + "# This is a comment", + "\n", + "x = 10 # This assigns 10 to x\n", + "print", + "(x) # This prints x", + "\n```\n", + "The code demonstrates variable assignment.", + ], + "Here's a code block with a text block:\n" + "```plaintext\n" + "# This is a comment\n" + "x = 10 # This assigns 10 to x\n" + "print(x) # This prints x\n" + "```\n" + "The code demonstrates variable assignment.", + [], + ), ], ) def test_citation_extraction(