Add metadata to GPT (#140)
@@ -4,6 +4,8 @@ from dataclasses import dataclass
 from typing import Any
 from typing import cast
 
+from danswer.configs.constants import METADATA
+from danswer.configs.constants import SOURCE_LINKS
 from danswer.connectors.models import Document
 
 
@@ -35,20 +37,23 @@ class InferenceChunk(BaseChunk):
     document_id: str
     source_type: str
     semantic_identifier: str
+    metadata: dict[str, Any]
 
     @classmethod
     def from_dict(cls, init_dict: dict[str, Any]) -> "InferenceChunk":
         init_kwargs = {
             k: v for k, v in init_dict.items() if k in inspect.signature(cls).parameters
         }
-        if "source_links" in init_kwargs:
-            source_links = init_kwargs["source_links"]
+        if SOURCE_LINKS in init_kwargs:
+            source_links = init_kwargs[SOURCE_LINKS]
             source_links_dict = (
                 json.loads(source_links)
                 if isinstance(source_links, str)
                 else source_links
             )
-            init_kwargs["source_links"] = {
+            init_kwargs[SOURCE_LINKS] = {
                 int(k): v for k, v in cast(dict[str, str], source_links_dict).items()
             }
+        if METADATA in init_kwargs:
+            init_kwargs[METADATA] = json.loads(init_kwargs[METADATA])
         return cls(**init_kwargs)
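
Note: the new METADATA field follows the same round trip as SOURCE_LINKS. The metadata dict is serialized to a JSON string at index time (see the qdrant and typesense hunks below) and decoded back into a dict when a chunk is rehydrated via from_dict. A minimal sketch of that round trip, with an illustrative payload:

    import json

    # At index time the metadata dict is flattened into a JSON string...
    payload = {"metadata": json.dumps({"Wiki Space Name": "Engineering"})}
    # ...and from_dict() decodes it back into a dict on the way out.
    assert json.loads(payload["metadata"]) == {"Wiki Space Name": "Engineering"}
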
@@ -96,6 +96,8 @@ NUM_GENERATIVE_AI_INPUT_DOCS = 5
 # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = 10  # 10 seconds
+# Include additional document/chunk metadata in prompt to GenerativeAI
+INCLUDE_METADATA = False
 
 
 #####
@@ -11,6 +11,7 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
 SECTION_CONTINUATION = "section_continuation"
 ALLOWED_USERS = "allowed_users"
 ALLOWED_GROUPS = "allowed_groups"
+METADATA = "metadata"
 OPENAI_API_KEY_STORAGE_KEY = "openai_api_key"
 HTML_SEPARATOR = "\n"
 PUBLIC_DOC_PAT = "PUBLIC"
@@ -136,7 +136,10 @@ class ConfluenceConnector(LoadConnector, PollConnector):
                     sections=[Section(link=page_url, text=page_text)],
                     source=DocumentSource.CONFLUENCE,
                     semantic_identifier=page["title"],
-                    metadata={},
+                    metadata={
+                        "Wiki Space Name": self.space,
+                        "Updated At": page["version"]["friendlyWhen"],
+                    },
                 )
             )
         return doc_batch, len(batch)
@@ -27,7 +27,7 @@ class ConnectorMissingException(Exception):
 
 def identify_connector_class(
     source: DocumentSource,
-    input_type: InputType,
+    input_type: InputType | None = None,
 ) -> Type[BaseConnector]:
     connector_map = {
         DocumentSource.WEB: WebConnector,
@@ -46,6 +46,10 @@ def identify_connector_class(
     connector_by_source = connector_map.get(source, {})
 
     if isinstance(connector_by_source, dict):
-        connector = connector_by_source.get(input_type)
+        if input_type is None:
+            # If not specified, default to most exhaustive update
+            connector = connector_by_source.get(InputType.LOAD_STATE)
+        else:
+            connector = connector_by_source.get(input_type)
     else:
         connector = connector_by_source
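
For callers, input_type becomes optional; a sketch of both call styles, under the assumption that InputType is importable from danswer.connectors.models (the exact module is not shown in this diff):

    from danswer.configs.constants import DocumentSource
    from danswer.connectors.factory import identify_connector_class
    from danswer.connectors.models import InputType  # assumed import path

    # Explicit input type, as before:
    explicit = identify_connector_class(DocumentSource.WEB, InputType.LOAD_STATE)
    # New: omitting input_type falls back to the most exhaustive update, LOAD_STATE.
    defaulted = identify_connector_class(DocumentSource.CONFLUENCE)
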
@@ -15,6 +15,24 @@ class BaseConnector(abc.ABC):
     def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
         raise NotImplementedError
 
+    @staticmethod
+    def parse_metadata(metadata: dict[str, Any]) -> list[str]:
+        """Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context"""
+        custom_parser_req_msg = (
+            "Specific metadata parsing required, connector has not implemented it."
+        )
+        metadata_lines = []
+        for metadata_key, metadata_value in metadata.items():
+            if isinstance(metadata_value, str):
+                metadata_lines.append(f"{metadata_key}: {metadata_value}")
+            elif isinstance(metadata_value, list):
+                if not all([isinstance(val, str) for val in metadata_value]):
+                    raise RuntimeError(custom_parser_req_msg)
+                metadata_lines.append(f'{metadata_key}: {", ".join(metadata_value)}')
+            else:
+                raise RuntimeError(custom_parser_req_msg)
+        return metadata_lines
+
 
 # Large set update or reindex, generally pulling a complete state or from a savestate file
 class LoadConnector(BaseConnector):
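
A quick illustration of the default parse_metadata behavior (module path assumed; the input mirrors the Confluence metadata added above, with an illustrative list-valued key):

    from danswer.connectors.interfaces import BaseConnector  # assumed module path

    lines = BaseConnector.parse_metadata(
        {"Wiki Space Name": "Engineering", "Labels": ["runbook", "oncall"]}
    )
    print(lines)  # ['Wiki Space Name: Engineering', 'Labels: runbook, oncall']
    # Any value that is not a str or list[str] raises RuntimeError, signaling
    # that the connector must override parse_metadata with its own formatting.
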
@@ -1,4 +1,5 @@
 import io
+from datetime import datetime
 from typing import Any
 from typing import cast
 from urllib.parse import urljoin
@@ -82,6 +83,8 @@ class WebConnector(LoadConnector):
             logger.info(f"Indexing {current_url}")
 
             try:
+                current_visit_time = datetime.now().strftime("%B %d, %Y, %H:%M:%S")
+
                 if restart_playwright:
                     playwright = sync_playwright().start()
                     browser = playwright.chromium.launch(headless=True)
@@ -102,7 +105,7 @@ class WebConnector(LoadConnector):
                         sections=[Section(link=current_url, text=page_text)],
                         source=DocumentSource.WEB,
                         semantic_identifier=current_url.split(".")[-1],
-                        metadata={},
+                        metadata={"Time Visited": current_visit_time},
                     )
                 )
                 continue
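
The visit timestamp is kept human-readable for the prompt rather than ISO 8601; the format string above renders, for example:

    from datetime import datetime

    datetime.now().strftime("%B %d, %Y, %H:%M:%S")  # e.g. 'July 04, 2023, 13:05:22'
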
@@ -1,3 +1,4 @@
+import json
 from functools import partial
 from uuid import UUID
 
@@ -8,6 +9,7 @@ from danswer.configs.constants import BLURB
 from danswer.configs.constants import CHUNK_ID
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import DOCUMENT_ID
+from danswer.configs.constants import METADATA
 from danswer.configs.constants import PUBLIC_DOC_PAT
 from danswer.configs.constants import SECTION_CONTINUATION
 from danswer.configs.constants import SEMANTIC_IDENTIFIER
@@ -23,7 +25,6 @@ from qdrant_client import QdrantClient
 from qdrant_client.http import models
 from qdrant_client.http.exceptions import ResponseHandlingException
 from qdrant_client.http.models.models import UpdateResult
-from qdrant_client.http.models.models import UpdateStatus
 from qdrant_client.models import CollectionsResponse
 from qdrant_client.models import Distance
 from qdrant_client.models import PointStruct
@@ -71,7 +72,7 @@ def get_qdrant_document_whitelists(
 def delete_qdrant_doc_chunks(
     document_id: str, collection_name: str, q_client: QdrantClient
 ) -> bool:
-    res = q_client.delete(
+    q_client.delete(
         collection_name=collection_name,
         points_selector=models.FilterSelector(
             filter=models.Filter(
@@ -136,6 +137,7 @@ def index_qdrant_chunks(
                     SECTION_CONTINUATION: chunk.section_continuation,
                     ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
                     ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
+                    METADATA: json.dumps(document.metadata),
                 },
                 vector=embedding,
             )
@@ -14,6 +14,7 @@ from danswer.configs.constants import BLURB
 from danswer.configs.constants import CHUNK_ID
 from danswer.configs.constants import CONTENT
 from danswer.configs.constants import DOCUMENT_ID
+from danswer.configs.constants import METADATA
 from danswer.configs.constants import PUBLIC_DOC_PAT
 from danswer.configs.constants import SECTION_CONTINUATION
 from danswer.configs.constants import SEMANTIC_IDENTIFIER
@@ -62,6 +63,7 @@ def create_typesense_collection(
             {"name": SECTION_CONTINUATION, "type": "bool"},
             {"name": ALLOWED_USERS, "type": "string[]"},
             {"name": ALLOWED_GROUPS, "type": "string[]"},
+            {"name": METADATA, "type": "string"},
         ],
     }
     ts_client.collections.create(collection_schema)
@@ -139,6 +141,7 @@ def index_typesense_chunks(
                 SECTION_CONTINUATION: chunk.section_continuation,
                 ALLOWED_USERS: doc_user_map[document.id][ALLOWED_USERS],
                 ALLOWED_GROUPS: doc_user_map[document.id][ALLOWED_GROUPS],
+                METADATA: json.dumps(document.metadata),
             }
         )
 
@@ -16,6 +16,7 @@ from typing import Union
 import openai
 import regex
 from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import INCLUDE_METADATA
 from danswer.configs.app_configs import OPENAI_API_KEY
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
 from danswer.configs.constants import BLURB
@@ -32,7 +33,6 @@ from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
 from danswer.direct_qa.qa_prompts import json_chat_processor
 from danswer.direct_qa.qa_prompts import json_processor
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
-from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.utils.logging import setup_logger
 from danswer.utils.text_processing import clean_model_quote
@@ -250,24 +250,29 @@ class OpenAIQAModel(QAModel):
 class OpenAICompletionQA(OpenAIQAModel):
     def __init__(
         self,
-        prompt_processor: Callable[[str, list[str]], str] = json_processor,
+        prompt_processor: Callable[
+            [str, list[InferenceChunk], bool], str
+        ] = json_processor,
         model_version: str = OPENAI_MODEL_VERSION,
         max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
         api_key: str | None = None,
         timeout: int | None = None,
+        include_metadata: bool = INCLUDE_METADATA,
     ) -> None:
         self.prompt_processor = prompt_processor
         self.model_version = model_version
         self.max_output_tokens = max_output_tokens
         self.api_key = api_key or get_openai_api_key()
         self.timeout = timeout
+        self.include_metadata = include_metadata
 
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
-        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
-        filled_prompt = self.prompt_processor(query, top_contents)
+        filled_prompt = self.prompt_processor(
+            query, context_docs, self.include_metadata
+        )
         logger.debug(filled_prompt)
 
         openai_call = _handle_openai_exceptions_wrapper(
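
For reference, a hedged sketch of how the new flag flows through at call time (variable names illustrative):

    qa_model = OpenAICompletionQA(include_metadata=True)  # default comes from INCLUDE_METADATA
    answer, quotes = qa_model.answer_question(query, ranked_chunks)  # ranked_chunks: list[InferenceChunk]
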
@@ -293,8 +298,9 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
-        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
-        filled_prompt = self.prompt_processor(query, top_contents)
+        filled_prompt = self.prompt_processor(
+            query, context_docs, self.include_metadata
+        )
         logger.debug(filled_prompt)
 
         openai_call = _handle_openai_exceptions_wrapper(
@@ -353,13 +359,14 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
     def __init__(
         self,
         prompt_processor: Callable[
-            [str, list[str]], list[dict[str, str]]
+            [str, list[InferenceChunk], bool], list[dict[str, str]]
         ] = json_chat_processor,
         model_version: str = OPENAI_MODEL_VERSION,
         max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
         timeout: int | None = None,
         reflexion_try_count: int = 0,
         api_key: str | None = None,
+        include_metadata: bool = INCLUDE_METADATA,
     ) -> None:
         self.prompt_processor = prompt_processor
         self.model_version = model_version
@@ -367,13 +374,15 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         self.reflexion_try_count = reflexion_try_count
         self.api_key = api_key or get_openai_api_key()
         self.timeout = timeout
+        self.include_metadata = include_metadata
 
     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
     ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
-        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
-        messages = self.prompt_processor(query, top_contents)
+        messages = self.prompt_processor(query, context_docs, self.include_metadata)
         logger.debug(messages)
         model_output = ""
         for _ in range(self.reflexion_try_count + 1):
@@ -407,8 +416,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
     ) -> Generator[dict[str, Any] | None, None, None]:
-        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
-        messages = self.prompt_processor(query, top_contents)
+        messages = self.prompt_processor(query, context_docs, self.include_metadata)
         logger.debug(messages)
 
         openai_call = _handle_openai_exceptions_wrapper(
@@ -1,6 +1,13 @@
 import json
 
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.constants import DocumentSource
+from danswer.connectors.factory import identify_connector_class
+
+
+GENERAL_SEP_PAT = "---\n"
 DOC_SEP_PAT = "---NEW DOCUMENT---"
+DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
 QUESTION_PAT = "Query:"
 ANSWER_PAT = "Answer:"
 UNCERTAINTY_PAT = "?"
@@ -9,7 +16,7 @@ QUOTE_PAT = "Quote:"
 BASE_PROMPT = (
     f"Answer the query based on provided documents and quote relevant sections. "
     f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
-    f"The quotes must be EXACT substrings from the documents.\n"
+    f"The quotes must be EXACT substrings from the documents."
 )
 
 
@@ -24,20 +31,122 @@ SAMPLE_JSON_RESPONSE = {
 }
 
 
-def json_processor(question: str, documents: list[str]) -> str:
+def add_metadata_section(
+    prompt_current: str,
+    chunk: InferenceChunk,
+    prepend_tab: bool = False,
+    include_sep: bool = False,
+) -> str:
+    """
+    Inserts a metadata section at the start of a document, providing additional context to the upcoming document.
+
+    Parameters:
+    prompt_current (str): The existing content of the prompt so far.
+    chunk (InferenceChunk): An object that contains the document's source type and metadata information to be added.
+    prepend_tab (bool, optional): If set to True, a tab character is added at the start of each line in the metadata
+        section for consistent spacing for the LLM.
+    include_sep (bool, optional): If set to True, includes the default section separator pattern at the end of the
+        metadata section.
+
+    Returns:
+    str: The prompt with the newly added metadata section.
+    """
+
+    def _prepend(s: str, ppt: bool) -> str:
+        return "\t" + s if ppt else s
+
+    prompt_current += _prepend(f"DOCUMENT SOURCE: {chunk.source_type}\n", prepend_tab)
+    if chunk.metadata:
+        prompt_current += _prepend("METADATA:\n", prepend_tab)
+        connector_class = identify_connector_class(DocumentSource(chunk.source_type))
+        for metadata_line in connector_class.parse_metadata(chunk.metadata):
+            prompt_current += _prepend(f"\t{metadata_line}\n", prepend_tab)
+    prompt_current += _prepend(DOC_CONTENT_START_PAT, prepend_tab)
+    if include_sep:
+        prompt_current += GENERAL_SEP_PAT
+    return prompt_current
+
+
+def json_processor(
+    question: str,
+    chunks: list[InferenceChunk],
+    include_metadata: bool = False,
+    include_sep: bool = True,
+) -> str:
     prompt = (
         BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
         f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
     )
 
-    for document in documents:
-        prompt += f"\n{DOC_SEP_PAT}\n{document}"
+    for chunk in chunks:
+        prompt += f"\n\n{DOC_SEP_PAT}\n"
+        if include_metadata:
+            prompt = add_metadata_section(
+                prompt, chunk, prepend_tab=False, include_sep=include_sep
+            )
+
+        prompt += chunk.content
 
     prompt += "\n\n---\n\n"
     prompt += f"{QUESTION_PAT}\n{question}\n"
     return prompt
+
+
+def json_chat_processor(
+    question: str,
+    chunks: list[InferenceChunk],
+    include_metadata: bool = False,
+    include_sep: bool = False,
+) -> list[dict[str, str]]:
+    metadata_prompt_section = "with metadata and contents " if include_metadata else ""
+    intro_msg = (
+        f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n"
+        f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
+    )
+
+    complete_answer_not_found_response = (
+        '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
+    )
+    task_msg = (
+        "Now answer the next user query based on documents above and quote relevant sections.\n"
+        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
+        "All quotes MUST be EXACT substrings from provided documents.\n"
+        "Your responses should be informative and concise.\n"
+        "You MUST prioritize information from provided documents over internal knowledge.\n"
+        "If the query cannot be answered based on the documents, respond with "
+        f"{complete_answer_not_found_response}\n"
+        "If the query requires aggregating the number of documents, respond with "
+        '{"answer": "Aggregations not supported", "quotes": []}\n'
+        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
+    )
+    messages = [{"role": "system", "content": intro_msg}]
+
+    for chunk in chunks:
+        full_context = ""
+        if include_metadata:
+            full_context = add_metadata_section(
+                full_context, chunk, prepend_tab=False, include_sep=include_sep
+            )
+        full_context += chunk.content
+        messages.extend(
+            [
+                {
+                    "role": "user",
+                    "content": full_context,
+                },
+                {"role": "assistant", "content": "Acknowledged"},
+            ]
+        )
+    messages.append({"role": "system", "content": task_msg})
+
+    messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
+
+    return messages
+
+
+# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future
+
+
 # Chain of Thought approach works however has higher token cost (more expensive) and is slower.
 # Should use this one if users ask questions that require logical reasoning.
 def json_cot_variant_processor(question: str, documents: list[str]) -> str:
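
To make the new prompt shape concrete: with include_metadata=True (and json_processor's default include_sep=True), each document section comes out roughly as below, using the Confluence metadata above as an illustrative example (actual source value, metadata, and whitespace will vary; note the separator lands between the DOCUMENT CONTENTS header and the chunk text):

    ---NEW DOCUMENT---
    DOCUMENT SOURCE: confluence
    METADATA:
    	Wiki Space Name: Engineering
    	Updated At: yesterday
    DOCUMENT CONTENTS:
    ---
    <chunk contents here>
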
@@ -100,46 +209,6 @@ def freeform_processor(question: str, documents: list[str]) -> str:
     return prompt
 
 
-def json_chat_processor(question: str, documents: list[str]) -> list[dict[str, str]]:
-    intro_msg = (
-        "You are a Question Answering assistant that answers queries based on provided documents.\n"
-        'Start by reading the following documents and responding with "Acknowledged".'
-    )
-
-    complete_answer_not_found_response = (
-        '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-    )
-    task_msg = (
-        "Now answer the next user query based on documents above and quote relevant sections.\n"
-        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-        "All quotes MUST be EXACT substrings from provided documents.\n"
-        "Your responses should be informative and concise.\n"
-        "You MUST prioritize information from provided documents over internal knowledge.\n"
-        "If the query cannot be answered based on the documents, respond with "
-        f"{complete_answer_not_found_response}\n"
-        "If the query requires aggregating whole documents, respond with "
-        '{"answer": "Aggregations not supported", "quotes": []}\n'
-        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
-    )
-    messages = [{"role": "system", "content": intro_msg}]
-
-    for document in documents:
-        messages.extend(
-            [
-                {
-                    "role": "user",
-                    "content": document,
-                },
-                {"role": "assistant", "content": "Acknowledged"},
-            ]
-        )
-    messages.append({"role": "system", "content": task_msg})
-
-    messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
-
-    return messages
-
-
 def freeform_chat_processor(
     question: str, documents: list[str]
 ) -> list[dict[str, str]]:
@@ -112,6 +112,7 @@ class TestQAPostprocessing(unittest.TestCase):
             blurb="anything",
             semantic_identifier="anything",
             section_continuation=False,
+            metadata={},
         )
         test_chunk_1 = InferenceChunk(
             document_id="test doc 1",
@@ -122,6 +123,7 @@ class TestQAPostprocessing(unittest.TestCase):
             blurb="whatever",
             semantic_identifier="whatever",
             section_continuation=False,
+            metadata={},
         )
 
         test_quotes = [