Add metadata for simple doc (#2212)

This commit is contained in:
Chris Weaver 2024-08-22 12:30:28 -07:00 committed by GitHub
parent 197b62aed1
commit 99db27d989
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 150 additions and 36 deletions

View File

@ -54,6 +54,7 @@ def translate_doc_response_to_simple_doc(
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]

View File

@ -64,6 +64,7 @@ class SimpleDoc(BaseModel):
blurb: str
match_highlights: list[str]
source_type: DocumentSource
metadata: dict | None
class ChatBasicResponse(BaseModel):

View File

@ -6,7 +6,7 @@ from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from danswer.db.enums import ConnectorCredentialPairStatus
from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common_utils.constants import API_SERVER_URL
class ConnectorCreationDetails(BaseModel):

View File

@ -4,7 +4,7 @@ import requests
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common_utils.constants import API_SERVER_URL
class DocumentSetClient:

View File

@ -0,0 +1,62 @@
import os
from typing import cast
import requests
from pydantic import BaseModel
from pydantic import PrivateAttr
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
class LLMProvider(BaseModel):
provider: str
api_key: str
default_model_name: str
api_base: str | None = None
api_version: str | None = None
is_default: bool = True
# only populated after creation
_provider_id: int | None = PrivateAttr()
def create(self) -> int:
llm_provider = LLMProviderUpsertRequest(
name=self.provider,
provider=self.provider,
default_model_name=self.default_model_name,
api_key=self.api_key,
api_base=self.api_base,
api_version=self.api_version,
custom_config=None,
fast_default_model_name=None,
is_public=True,
groups=None,
display_model_names=None,
model_names=None,
)
response = requests.put(
f"{API_SERVER_URL}/admin/llm/provider",
json=llm_provider.dict(),
)
response.raise_for_status()
self._provider_id = cast(int, response.json()["id"])
return self._provider_id
def delete(self) -> None:
response = requests.delete(
f"{API_SERVER_URL}/admin/llm/provider/{self._provider_id}"
)
response.raise_for_status()
def seed_default_openai_provider() -> LLMProvider:
llm = LLMProvider(
provider="openai",
default_model_name="gpt-4o-mini",
api_key=os.environ["OPENAI_API_KEY"],
)
llm.create()
return llm

View File

@ -19,6 +19,7 @@ from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from tests.integration.common_utils.llm import seed_default_openai_provider
def _run_migrations(
@ -165,4 +166,6 @@ def reset_all() -> None:
reset_postgres()
print("Resetting Vespa...")
reset_vespa()
print("Seeding LLM Providers...")
seed_default_openai_provider()
print("Finished resetting all.")

View File

@ -4,13 +4,18 @@ import requests
from pydantic import BaseModel
from danswer.configs.constants import DocumentSource
from tests.integration.common.connectors import ConnectorClient
from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common_utils.connectors import ConnectorClient
from tests.integration.common_utils.constants import API_SERVER_URL
class SimpleTestDocument(BaseModel):
id: str
content: str
class SeedDocumentResponse(BaseModel):
cc_pair_id: int
document_ids: list[str]
documents: list[SimpleTestDocument]
class TestDocumentClient:
@ -23,11 +28,9 @@ class TestDocumentClient:
cc_pair_id = connector_details.cc_pair_id
# Create and ingest some documents
document_ids: list[str] = []
documents: list[dict] = []
for _ in range(num_docs):
document_id = f"test-doc-{uuid.uuid4()}"
document_ids.append(document_id)
document = {
"document": {
"id": document_id,
@ -38,12 +41,14 @@ class TestDocumentClient:
}
],
"source": DocumentSource.NOT_APPLICABLE,
"metadata": {},
# just for testing metadata
"metadata": {"document_id": document_id},
"semantic_identifier": f"Test Document {document_id}",
"from_ingestion_api": True,
},
"cc_pair_id": cc_pair_id,
}
documents.append(document)
response = requests.post(
f"{API_SERVER_URL}/danswer-api/ingestion",
json=document,
@ -53,7 +58,13 @@ class TestDocumentClient:
print("Seeding completed successfully.")
return SeedDocumentResponse(
cc_pair_id=cc_pair_id,
document_ids=document_ids,
documents=[
SimpleTestDocument(
id=document["document"]["id"],
content=document["document"]["sections"][0]["text"],
)
for document in documents
],
)

View File

@ -4,7 +4,7 @@ import requests
from ee.danswer.server.user_group.models import UserGroup
from ee.danswer.server.user_group.models import UserGroupCreate
from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common_utils.constants import API_SERVER_URL
class UserGroupClient:

View File

@ -5,8 +5,8 @@ from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from tests.integration.common.reset import reset_all
from tests.integration.common.vespa import TestVespaClient
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.vespa import TestVespaClient
@pytest.fixture

View File

@ -2,13 +2,13 @@ import time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.connectors import ConnectorClient
from tests.integration.common.constants import MAX_DELAY
from tests.integration.common.document_sets import DocumentSetClient
from tests.integration.common.seed_documents import TestDocumentClient
from tests.integration.common.user_groups import UserGroupClient
from tests.integration.common.user_groups import UserGroupCreate
from tests.integration.common.vespa import TestVespaClient
from tests.integration.common_utils.connectors import ConnectorClient
from tests.integration.common_utils.constants import MAX_DELAY
from tests.integration.common_utils.document_sets import DocumentSetClient
from tests.integration.common_utils.seed_documents import TestDocumentClient
from tests.integration.common_utils.user_groups import UserGroupClient
from tests.integration.common_utils.user_groups import UserGroupCreate
from tests.integration.common_utils.vespa import TestVespaClient
def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
@ -129,12 +129,12 @@ def test_connector_deletion(reset: None, vespa_client: TestVespaClient) -> None:
print("Connector 1 deleted")
# validate vespa documents
c1_vespa_docs = vespa_client.get_documents_by_id(c1_seed_res.document_ids)[
"documents"
]
c2_vespa_docs = vespa_client.get_documents_by_id(c2_seed_res.document_ids)[
"documents"
]
c1_vespa_docs = vespa_client.get_documents_by_id(
[doc.id for doc in c1_seed_res.documents]
)["documents"]
c2_vespa_docs = vespa_client.get_documents_by_id(
[doc.id for doc in c2_seed_res.documents]
)["documents"]
assert len(c1_vespa_docs) == 0
assert len(c2_vespa_docs) == 5

View File

@ -0,0 +1,36 @@
import requests
from tests.integration.common_utils.connectors import ConnectorClient
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.seed_documents import TestDocumentClient
def test_send_message_simple_with_history(reset: None) -> None:
# create connectors
c1_details = ConnectorClient.create_connector(name_prefix="tc1")
c1_seed_res = TestDocumentClient.seed_documents(
num_docs=5, cc_pair_id=c1_details.cc_pair_id
)
response = requests.post(
f"{API_SERVER_URL}/chat/send-message-simple-with-history",
json={
"messages": [{"message": c1_seed_res.documents[0].content, "role": "user"}],
"persona_id": 0,
"prompt_id": 0,
},
)
assert response.status_code == 200
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == c1_seed_res.documents[0].id
# assert that the metadata is correct
for doc in c1_seed_res.documents:
found_doc = next(
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id

View File

@ -1,10 +1,10 @@
import time
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.seed_documents import TestDocumentClient
from tests.integration.common.vespa import TestVespaClient
from tests.integration.document_set.utils import create_document_set
from tests.integration.document_set.utils import fetch_document_sets
from tests.integration.common_utils.seed_documents import TestDocumentClient
from tests.integration.common_utils.vespa import TestVespaClient
from tests.integration.tests.document_set.utils import create_document_set
from tests.integration.tests.document_set.utils import fetch_document_sets
def test_multiple_document_sets_syncing_same_connnector(
@ -68,12 +68,12 @@ def test_multiple_document_sets_syncing_same_connnector(
doc_set_names = {doc_set.name for doc_set in doc_sets}
# make sure documents are as expected
result = vespa_client.get_documents_by_id(seed_result.document_ids)
seeded_document_ids = [doc.id for doc in seed_result.documents]
result = vespa_client.get_documents_by_id([doc.id for doc in seed_result.documents])
documents = result["documents"]
assert len(documents) == len(seed_result.document_ids)
assert all(
doc["fields"]["document_id"] in seed_result.document_ids for doc in documents
)
assert len(documents) == len(seed_result.documents)
assert all(doc["fields"]["document_id"] in seeded_document_ids for doc in documents)
assert all(
set(doc["fields"]["document_sets"].keys()) == doc_set_names for doc in documents
)

View File

@ -4,7 +4,7 @@ import requests
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from tests.integration.common.constants import API_SERVER_URL
from tests.integration.common_utils.constants import API_SERVER_URL
def create_document_set(doc_set_creation_request: DocumentSetCreationRequest) -> int: