mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-05 12:12:29 +02:00
Mypy fixes for default configs (#1442)
This commit is contained in:
@ -1,92 +0,0 @@
|
||||
# This file is purely for development use, not included in any builds
|
||||
# Use this to test the chat feature
|
||||
# This script does not allow for branching logic that is supported by the backend APIs
|
||||
# This script also does not allow for editing/regeneration of user/model messages
|
||||
# Have Danswer API server running to use this.
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
|
||||
LOCAL_CHAT_ENDPOINT = f"http://127.0.0.1:{APP_PORT}/chat/"
|
||||
|
||||
|
||||
def create_new_session() -> int:
|
||||
data = {"persona_id": 0} # Global default Persona
|
||||
response = requests.post(LOCAL_CHAT_ENDPOINT + "create-chat-session", json=data)
|
||||
response.raise_for_status()
|
||||
new_session_id = response.json()["chat_session_id"]
|
||||
return new_session_id
|
||||
|
||||
|
||||
def send_chat_message(
|
||||
message: str,
|
||||
chat_session_id: int,
|
||||
parent_message: int | None,
|
||||
) -> int:
|
||||
data = {
|
||||
"message": message,
|
||||
"chat_session_id": chat_session_id,
|
||||
"parent_message_id": parent_message,
|
||||
"prompt_id": 0, # Global default Prompt
|
||||
"retrieval_options": {
|
||||
"run_search": "always",
|
||||
"real_time": True,
|
||||
"filters": {"tags": []},
|
||||
},
|
||||
}
|
||||
|
||||
docs: list[dict] | None = None
|
||||
message_id: int | None = None
|
||||
with requests.post(
|
||||
LOCAL_CHAT_ENDPOINT + "send-message", json=data, stream=True
|
||||
) as r:
|
||||
for json_response in r.iter_lines():
|
||||
response_text = json.loads(json_response.decode())
|
||||
new_token = response_text.get("answer_piece")
|
||||
if docs is None:
|
||||
docs = response_text.get("top_documents")
|
||||
if message_id is None:
|
||||
message_id = response_text.get("message_id")
|
||||
if new_token:
|
||||
print(new_token, end="", flush=True)
|
||||
print()
|
||||
|
||||
if docs:
|
||||
docs.sort(key=lambda x: x["score"], reverse=True) # type: ignore
|
||||
print("\nReference Docs:")
|
||||
for ind, doc in enumerate(docs, start=1):
|
||||
print(f"\t - Doc {ind}: {doc.get('semantic_identifier')}")
|
||||
|
||||
if message_id is None:
|
||||
raise ValueError("Couldn't get latest message id")
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
def run_chat() -> None:
|
||||
try:
|
||||
new_session_id = create_new_session()
|
||||
print(f"Chat Session ID: {new_session_id}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
print(
|
||||
"Looks like you haven't started the Danswer Backend server, please run the FastAPI server"
|
||||
)
|
||||
exit()
|
||||
return
|
||||
|
||||
parent_message = None
|
||||
while True:
|
||||
new_message = input(
|
||||
"\n\n----------------------------------\n"
|
||||
"Please provide a new chat message:\n> "
|
||||
)
|
||||
|
||||
parent_message = send_chat_message(
|
||||
new_message, new_session_id, parent_message=parent_message
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_chat()
|
@ -1,95 +0,0 @@
|
||||
# This file is purely for development use, not included in any builds
|
||||
import argparse
|
||||
import json
|
||||
from pprint import pprint
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
previous_query = None
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--type",
|
||||
type=str,
|
||||
default="hybrid",
|
||||
help='"hybrid" "semantic" or "keyword", defaults to "hybrid"',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--stream",
|
||||
action="store_true",
|
||||
help="Enable streaming response",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--filters",
|
||||
type=str,
|
||||
help="Comma separated list of source types to filter by (no spaces)",
|
||||
)
|
||||
|
||||
parser.add_argument("query", nargs="*", help="The query to process")
|
||||
|
||||
previous_input = None
|
||||
while True:
|
||||
try:
|
||||
user_input = input(
|
||||
"\n\nAsk any question:\n"
|
||||
" - Use -t (hybrid/semantic/keyword) flag to choose search flow.\n"
|
||||
" - prefix with -s to stream answer, --filters web,slack etc. for filters.\n"
|
||||
" - input an empty string to rerun last query.\n\t"
|
||||
)
|
||||
|
||||
if user_input:
|
||||
previous_input = user_input
|
||||
else:
|
||||
if not previous_input:
|
||||
print("No previous input")
|
||||
continue
|
||||
print(f"Re-executing previous question:\n\t{previous_input}")
|
||||
user_input = previous_input
|
||||
|
||||
args = parser.parse_args(user_input.split())
|
||||
|
||||
search_type = str(args.type).lower()
|
||||
stream = args.stream
|
||||
source_types = args.filters.split(",") if args.filters else None
|
||||
|
||||
query = " ".join(args.query)
|
||||
|
||||
if search_type not in ["hybrid", "semantic", "keyword"]:
|
||||
raise ValueError("Invalid Search")
|
||||
|
||||
elif stream:
|
||||
path = "stream-direct-qa"
|
||||
else:
|
||||
path = "direct-qa"
|
||||
|
||||
endpoint = f"http://127.0.0.1:{APP_PORT}/{path}"
|
||||
|
||||
query_json = {
|
||||
"query": query,
|
||||
"collection": DOCUMENT_INDEX_NAME,
|
||||
"filters": {SOURCE_TYPE: source_types},
|
||||
"enable_auto_detect_filters": True,
|
||||
"search_type": search_type,
|
||||
}
|
||||
|
||||
if args.stream:
|
||||
with requests.post(endpoint, json=query_json, stream=True) as r:
|
||||
for json_response in r.iter_lines():
|
||||
pprint(json.loads(json_response.decode()))
|
||||
else:
|
||||
response = requests.post(endpoint, json=query_json)
|
||||
contents = json.loads(response.content)
|
||||
pprint(contents)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed due to {e}, retrying")
|
Reference in New Issue
Block a user