Add metadata to GPT ()

Yuhong Sun 2023-07-14 16:54:42 -07:00 committed by GitHub
parent 8928d61492
commit e4820045f9
12 changed files with 185 additions and 65 deletions

@@ -4,6 +4,8 @@ from dataclasses import dataclass
from typing import Any
from typing import cast
from danswer.configs.constants import METADATA
from danswer.configs.constants import SOURCE_LINKS
from danswer.connectors.models import Document
@@ -35,20 +37,23 @@ class InferenceChunk(BaseChunk):
document_id: str
source_type: str
semantic_identifier: str
metadata: dict[str, Any]
@classmethod
def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
init_kwargs = {
k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
}
if "source_links" in init_kwargs:
source_links = init_kwargs["source_links"]
if SOURCE_LINKS in init_kwargs:
source_links = init_kwargs[SOURCE_LINKS]
source_links_dict = (
json.loads(source_links)
if isinstance(source_links, str)
else source_links
)
init_kwargs["source_links"] = {
init_kwargs[SOURCE_LINKS] = {
int(k): v for k, v in cast(dict[str, str], source_links_dict).items()
}
if METADATA in init_kwargs:
init_kwargs[METADATA] = json.loads(init_kwargs[METADATA])
return cls(**init_kwargs)
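
A minimal sketch of the new deserialization step in isolation, with invented values: SOURCE_LINKS and METADATA come back from the datastores as JSON strings and are decoded into dicts before the chunk is constructed.

import json

# Invented example values; the keys mirror the SOURCE_LINKS / METADATA constants.
init_kwargs = {
    "source_links": '{"0": "https://example.com/page#section"}',
    "metadata": '{"Wiki Space Name": "Engineering"}',
}
init_kwargs["source_links"] = {
    int(k): v for k, v in json.loads(init_kwargs["source_links"]).items()
}
init_kwargs["metadata"] = json.loads(init_kwargs["metadata"])
# init_kwargs["metadata"] is now {'Wiki Space Name': 'Engineering'}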

@@ -96,6 +96,8 @@ NUM_GENERATIVE_AI_INPUT_DOCS = 5
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = 10 # 10 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
#####

@@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"

@@ -136,7 +136,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
sections=[Section(link=page_url, text=page_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=page["title"],
metadata={},
metadata={
"Wiki Space Name": self.space,
"Updated At": page["version"]["friendlyWhen"],
},
)
)
return doc_batch, len(batch)
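
For reference, the Document built above now carries a small human-readable metadata dict, along the lines of (values invented):

# Hypothetical metadata attached to a Confluence page Document:
metadata = {
    "Wiki Space Name": "Engineering",      # from self.space
    "Updated At": "yesterday at 3:42 PM",  # from page["version"]["friendlyWhen"]
}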

@@ -27,7 +27,7 @@ class ConnectorMissingException(Exception):
def identify_connector_class(
source: DocumentSource,
input_type: InputType,
input_type: InputType | None = None,
) -> Type[BaseConnector]:
connector_map = {
DocumentSource.WEB: WebConnector,
@@ -46,7 +46,11 @@ def identify_connector_class(
connector_by_source = connector_map.get(source, {})
if isinstance(connector_by_source, dict):
connector = connector_by_source.get(input_type)
if input_type is None:
# If not specified, default to most exhaustive update
connector = connector_by_source.get(InputType.LOAD_STATE)
else:
connector = connector_by_source.get(input_type)
else:
connector = connector_by_source
if connector is None:
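
With input_type now optional, callers that only know the source can resolve a connector directly; the LOAD_STATE fallback is the behavior added above:

from danswer.configs.constants import DocumentSource
from danswer.connectors.factory import identify_connector_class

# Falls back to InputType.LOAD_STATE, the most exhaustive update mode.
connector_cls = identify_connector_class(DocumentSource.WEB)
# Equivalent to the explicit form:
# connector_cls = identify_connector_class(DocumentSource.WEB, InputType.LOAD_STATE)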

@@ -15,6 +15,24 @@ class BaseConnector(abc.ABC):
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError
@staticmethod
def parse_metadata(metadata: dict[str, Any]) -> list[str]:
"""Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
custom_parser_req_msg = (
"Specific metadata parsing required, connector has not implemented it."
)
metadata_lines = []
for metadata_key, metadata_value in metadata.items():
if isinstance(metadata_value, str):
metadata_lines.append(f"{metadata_key}: {metadata_value}")
elif isinstance(metadata_value, list):
if not all([isinstance(val, str) for val in metadata_value]):
raise RuntimeError(custom_parser_req_msg)
metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
else:
raise RuntimeError(custom_parser_req_msg)
return metadata_lines
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
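
A quick illustration of the default parser with invented metadata: flat strings and lists of strings are formatted into lines, while anything richer forces the connector to implement its own parsing.

# Default handling covers str and list[str] values (example values invented):
BaseConnector.parse_metadata({"Author": "jdoe", "Labels": ["runbook", "oncall"]})
# -> ["Author: jdoe", "Labels: runbook, oncall"]

# A nested value raises RuntimeError, signalling the connector must override this:
# BaseConnector.parse_metadata({"Version": {"major": 1}})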

@@ -1,4 +1,5 @@
import io
from datetime import datetime
from typing import Any
from typing import cast
from urllib.parse import urljoin
@@ -82,6 +83,8 @@ class WebConnector(LoadConnector):
logger.info(f"Indexing {current_url}")
try:
current_visit_time = datetime.now().strftime("%B %d, %Y, %H:%M:%S")
if restart_playwright:
playwright = sync_playwright().start()
browser = playwright.chromium.launch(headless=True)
@@ -102,7 +105,7 @@
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=current_url.split(".")[-1],
metadata={},
metadata={"Time Visited": current_visit_time},
)
)
continue
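
For context, the strftime pattern above yields human-readable timestamps such as:

from datetime import datetime

# The same format string the connector stores under "Time Visited":
datetime(2023, 7, 14, 16, 54, 42).strftime("%B %d, %Y, %H:%M:%S")
# -> 'July 14, 2023, 16:54:42'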

@@ -1,3 +1,4 @@
import json
from functools import partial
from uuid import UUID
@@ -8,6 +9,7 @@ from danswer.configs.constants import BLURB
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import METADATA
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
@@ -23,7 +25,6 @@ from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.models.models import UpdateResult
from qdrant_client.http.models.models import UpdateStatus
from qdrant_client.models import CollectionsResponse
from qdrant_client.models import Distance
from qdrant_client.models import PointStruct
@@ -71,7 +72,7 @@ def get_qdrant_document_whitelists(
def delete_qdrant_doc_chunks(
document_id: str, collection_name: str, q_client: QdrantClient
) -> bool:
res = q_client.delete(
q_client.delete(
collection_name=collection_name,
points_selector=models.FilterSelector(
filter=models.Filter(
@@ -136,6 +137,7 @@ def index_qdrant_chunks(
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
METADATA: json.dumps(document.metadata),
},
vector=embedding,
)
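
Because the Qdrant payload is kept flat, the metadata dict is serialized to a JSON string at index time; a sketch of the resulting payload (values invented, other fields elided):

import json

# Hypothetical chunk payload as written above; METADATA is a JSON string that
# InferenceChunk.from_dict decodes again at query time.
payload = {
    "document_id": "doc-1",
    "metadata": json.dumps({"Wiki Space Name": "Engineering"}),
}
json.loads(payload["metadata"])  # -> {'Wiki Space Name': 'Engineering'}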

@@ -14,6 +14,7 @@ from danswer.configs.constants import BLURB
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import METADATA
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
@@ -62,6 +63,7 @@ def create_typesense_collection(
{"name": SECTION_CONTINUATION, "type": "bool"},
{"name": ALLOWED_USERS, "type": "string[]"},
{"name": ALLOWED_GROUPS, "type": "string[]"},
{"name": METADATA, "type": "string"},
],
}
ts_client.collections.create(collection_schema)
@@ -139,6 +141,7 @@ def index_typesense_chunks(
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
METADATA: json.dumps(document.metadata),
}
)
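
Typesense, unlike Qdrant, validates documents against a declared schema, so the new field must be registered up front; the "string" type matches the json.dumps serialization at write time. A hedged sketch of reading it back (hit shape per the Typesense search API, values invented):

import json

# A search hit's document carries the JSON-encoded metadata string:
hit = {"document": {"document_id": "doc-1", "metadata": '{"Updated At": "yesterday"}'}}
metadata = json.loads(hit["document"]["metadata"])  # -> {'Updated At': 'yesterday'}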

@@ -16,6 +16,7 @@ from typing import Union
import openai
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
@@ -32,7 +33,6 @@ from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import json_chat_processor
from danswer.direct_qa.qa_prompts import json_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
@@ -250,24 +250,29 @@ class OpenAIQAModel(QAModel):
class OpenAICompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[[str, list[str]], str] = json_processor,
prompt_processor: Callable[
[str, list[InferenceChunk], bool], str
] = json_processor,
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
api_key: str | None = None,
timeout: int | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.api_key = api_key or get_openai_api_key()
self.timeout = timeout
self.include_metadata = include_metadata
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
filled_prompt = self.prompt_processor(query, top_contents)
filled_prompt = self.prompt_processor(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
@@ -293,8 +298,9 @@ class OpenAICompletionQA(OpenAIQAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
filled_prompt = self.prompt_processor(query, top_contents)
filled_prompt = self.prompt_processor(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
@@ -353,13 +359,14 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[
[str, list[str]], list[dict[str, str]]
[str, list[InferenceChunk], bool], list[dict[str, str]]
] = json_chat_processor,
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
reflexion_try_count: int = 0,
api_key: str | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
@@ -367,13 +374,15 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
self.reflexion_try_count = reflexion_try_count
self.api_key = api_key or get_openai_api_key()
self.timeout = timeout
self.include_metadata = include_metadata
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
messages = self.prompt_processor(query, top_contents)
messages = self.prompt_processor(query, context_docs, self.include_metadata)
logger.debug(messages)
model_output = ""
for _ in range(self.reflexion_try_count + 1):
@@ -407,8 +416,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
messages = self.prompt_processor(query, top_contents)
messages = self.prompt_processor(query, context_docs, self.include_metadata)
logger.debug(messages)
openai_call = _handle_openai_exceptions_wrapper(
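
The upshot of these signature changes: prompt processors now receive the full InferenceChunk objects instead of pre-extracted content strings, and the metadata toggle threads through from the constructor. A minimal usage sketch, assuming query and retrieved context_docs already exist:

# Hedged sketch; include_metadata defaults to INCLUDE_METADATA (False).
qa_model = OpenAIChatCompletionQA(include_metadata=True)
answer, quotes = qa_model.answer_question(query, context_docs)
# Internally this now calls: self.prompt_processor(query, context_docs, True)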

@@ -1,6 +1,13 @@
import json
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import DocumentSource
from danswer.connectors.factory import identify_connector_class
GENERAL_SEP_PAT = "---\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
ANSWER_PAT = "Answer:"
UNCERTAINTY_PAT = "?"
@@ -9,7 +16,7 @@ QUOTE_PAT = "Quote:"
BASE_PROMPT = (
f"Answer the query based on provided documents and quote relevant sections. "
f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
f"The quotes must be EXACT substrings from the documents.\n"
f"The quotes must be EXACT substrings from the documents."
)
@@ -24,20 +31,122 @@ SAMPLE_JSON_RESPONSE = {
}
def json_processor(question: str, documents: list[str]) -> str:
def add_metadata_section(
prompt_current: str,
chunk: InferenceChunk,
prepend_tab: bool = False,
include_sep: bool = False,
) -> str:
"""
Inserts a metadata section at the start of a document, providing additional context for the document that follows.
Parameters:
prompt_current (str): The existing content of the prompt so far.
chunk (InferenceChunk): An object that contains the document's source type and metadata information to be added.
prepend_tab (bool, optional): If set to True, a tab character is added at the start of each line in the metadata
section, keeping indentation consistent for the LLM.
include_sep (bool, optional): If set to True, includes the default section separator pattern at the end of the
metadata section.
Returns:
str: The prompt with the newly added metadata section.
"""
def _prepend(s: str, ppt: bool) -> str:
return "\t" + s if ppt else s
prompt_current += _prepend(f"DOCUMENT SOURCE: {chunk.source_type}\n", prepend_tab)
if chunk.metadata:
prompt_current += _prepend("METADATA:\n", prepend_tab)
connector_class = identify_connector_class(DocumentSource(chunk.source_type))
for metadata_line in connector_class.parse_metadata(chunk.metadata):
prompt_current += _prepend(f"\t{metadata_line}\n", prepend_tab)
prompt_current += _prepend(DOC_CONTENT_START_PAT, prepend_tab)
if include_sep:
prompt_current += GENERAL_SEP_PAT
return prompt_current
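
Rendered for a Confluence chunk, the section built above looks roughly like this (metadata values invented; with include_sep=True the trailing separator line is appended):

DOCUMENT SOURCE: confluence
METADATA:
	Wiki Space Name: Engineering
	Updated At: yesterday at 3:42 PM
DOCUMENT CONTENTS:
---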
def json_processor(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = False,
include_sep: bool = True,
) -> str:
prompt = (
BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for document in documents:
prompt += f"\n{DOC_SEP_PAT}\n{document}"
for chunk in chunks:
prompt += f"\n\n{DOC_SEP_PAT}\n"
if include_metadata:
prompt = add_metadata_section(
prompt, chunk, prepend_tab=False, include_sep=include_sep
)
prompt += chunk.content
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
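
End to end, a one-chunk completion prompt with include_metadata=True takes roughly this shape (document text and query invented, instructions abridged):

Answer the query based on provided documents and quote relevant sections. [...] Sample response:
{"answer": "...", "quotes": ["...", "..."]}

Each context document below is prefixed with "---NEW DOCUMENT---".


---NEW DOCUMENT---
DOCUMENT SOURCE: web
METADATA:
	Time Visited: July 14, 2023, 16:54:42
DOCUMENT CONTENTS:
---
<chunk content>

---

Query:
<user question>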
def json_chat_processor(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = False,
include_sep: bool = False,
) -> list[dict[str, str]]:
metadata_prompt_section = "with metadata and contents " if include_metadata else ""
intro_msg = (
f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n"
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the next user query based on documents above and quote relevant sections.\n"
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative and concise.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
full_context = add_metadata_section(
full_context, chunk, prepend_tab=False, include_sep=include_sep
)
full_context += chunk.content
messages.extend(
[
{
"role": "user",
"content": full_context,
},
{"role": "assistant", "content": "Acknowledged"},
]
)
messages.append({"role": "system", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
return messages
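
The resulting chat transcript interleaves each document as a user turn followed by a scripted "Acknowledged" assistant turn, deferring the task instructions until after all documents; approximately:

# Approximate shape of the returned messages (contents abridged):
[
    {"role": "system", "content": "You are a Question Answering assistant ..."},
    {"role": "user", "content": "DOCUMENT SOURCE: web\nMETADATA:\n\t...\nDOCUMENT CONTENTS:\n<chunk content>"},
    {"role": "assistant", "content": "Acknowledged"},
    # ... one user/assistant pair per chunk ...
    {"role": "system", "content": "Now answer the next user query ..."},
    {"role": "user", "content": "Query:\n<question>\n"},
]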
# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future
# Chain of Thought approach works however has higher token cost (more expensive) and is slower.
# Should use this one if users ask questions that require logical reasoning.
def json_cot_variant_processor(question: str, documents: list[str]) -> str:
@@ -100,46 +209,6 @@ def freeform_processor(question: str, documents: list[str]) -> str:
return prompt
def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
intro_msg = (
"You are a Question Answering assistant that answers queries based on provided documents.\n"
'Start by reading the following documents and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the next user query based on documents above and quote relevant sections.\n"
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative and concise.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating whole documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for document in documents:
messages.extend(
[
{
"role": "user",
"content": document,
},
{"role": "assistant", "content": "Acknowledged"},
]
)
messages.append({"role": "system", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
return messages
def freeform_chat_processor(
question: str, documents: list[str]
) -> list[dict[str, str]]:

@@ -112,6 +112,7 @@ class TestQAPostprocessing(unittest.TestCase):
blurb="anything",
semantic_identifier="anything",
section_continuation=False,
metadata={},
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
@@ -122,6 +123,7 @@ class TestQAPostprocessing(unittest.TestCase):
blurb="whatever",
semantic_identifier="whatever",
section_continuation=False,
metadata={},
)
test_quotes = [