Mirror of https://github.com/danswer-ai/danswer.git
Basic Slack Bot Support (#128)
@@ -1,6 +1,5 @@
 import json
 import os
-import time
 from collections.abc import Callable
 from pathlib import Path
 from typing import Any
@@ -15,15 +14,14 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.connectors.slack.utils import get_message_link
+from danswer.connectors.slack.utils import make_slack_api_call_paginated
+from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.utils.logging import setup_logger
 from slack_sdk import WebClient
-from slack_sdk.errors import SlackApiError
 from slack_sdk.web import SlackResponse
 
 logger = setup_logger()
 
-SLACK_LIMIT = 900
-
 
 ChannelType = dict[str, Any]
 MessageType = dict[str, Any]
@@ -31,61 +29,10 @@ MessageType = dict[str, Any]
 ThreadType = list[MessageType]
 
 
-def _make_slack_api_call_paginated(
-    call: Callable[..., SlackResponse],
-) -> Callable[..., list[dict[str, Any]]]:
-    """Wraps calls to slack API so that they automatically handle pagination"""
-
-    def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
-        results: list[dict[str, Any]] = []
-        cursor: str | None = None
-        has_more = True
-        while has_more:
-            for result in call(cursor=cursor, limit=SLACK_LIMIT, **kwargs):
-                has_more = result.get("has_more", False)
-                cursor = result.get("response_metadata", {}).get("next_cursor", "")
-                results.append(cast(dict[str, Any], result))
-        return results
-
-    return paginated_call
-
-
-def _make_slack_api_rate_limited(
-    call: Callable[..., SlackResponse], max_retries: int = 3
-) -> Callable[..., SlackResponse]:
-    """Wraps calls to slack API so that they automatically handle rate limiting"""
-
-    def rate_limited_call(**kwargs: Any) -> SlackResponse:
-        for _ in range(max_retries):
-            try:
-                # Make the API call
-                response = call(**kwargs)
-
-                # Check for errors in the response
-                if response.get("ok"):
-                    return response
-                else:
-                    raise SlackApiError("", response)
-
-            except SlackApiError as e:
-                if e.response["error"] == "ratelimited":
-                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
-                    retry_after = int(e.response.headers.get("Retry-After", 1))
-                    time.sleep(retry_after)
-                else:
-                    # Raise the error for non-transient errors
-                    raise
-
-        # If the code reaches this point, all retries have been exhausted
-        raise Exception(f"Max retries ({max_retries}) exceeded")
-
-    return rate_limited_call
-
-
 def _make_slack_api_call(
     call: Callable[..., SlackResponse], **kwargs: Any
 ) -> list[dict[str, Any]]:
-    return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)
+    return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)
 
 
 def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:
@@ -1,6 +1,14 @@
+import time
+from collections.abc import Callable
 from typing import Any
 from typing import cast
 
+from slack_sdk.errors import SlackApiError
+from slack_sdk.web import SlackResponse
+
+# number of messages we request per page when fetching paginated slack messages
+_SLACK_LIMIT = 900
+
 
 def get_message_link(
     event: dict[str, Any], workspace: str, channel_id: str | None = None
@@ -13,3 +21,54 @@ def get_message_link
     return (
         f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
     )
+
+
+def make_slack_api_call_paginated(
+    call: Callable[..., SlackResponse],
+) -> Callable[..., list[dict[str, Any]]]:
+    """Wraps calls to slack API so that they automatically handle pagination"""
+
+    def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
+        results: list[dict[str, Any]] = []
+        cursor: str | None = None
+        has_more = True
+        while has_more:
+            for result in call(cursor=cursor, limit=_SLACK_LIMIT, **kwargs):
+                has_more = result.get("has_more", False)
+                cursor = result.get("response_metadata", {}).get("next_cursor", "")
+                results.append(cast(dict[str, Any], result))
+        return results
+
+    return paginated_call
+
+
+def make_slack_api_rate_limited(
+    call: Callable[..., SlackResponse], max_retries: int = 3
+) -> Callable[..., SlackResponse]:
+    """Wraps calls to slack API so that they automatically handle rate limiting"""

+    def rate_limited_call(**kwargs: Any) -> SlackResponse:
+        for _ in range(max_retries):
+            try:
+                # Make the API call
+                response = call(**kwargs)
+
+                # Check for errors in the response
+                if response.get("ok"):
+                    return response
+                else:
+                    raise SlackApiError("", response)
+
+            except SlackApiError as e:
+                if e.response["error"] == "ratelimited":
+                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
+                    retry_after = int(e.response.headers.get("Retry-After", 1))
+                    time.sleep(retry_after)
+                else:
+                    # Raise the error for non-transient errors
+                    raise
+
+        # If the code reaches this point, all retries have been exhausted
+        raise Exception(f"Max retries ({max_retries}) exceeded")
+
+    return rate_limited_call

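The two new helpers are meant to compose: the rate-limit wrapper retries a single Slack call when it gets rate limited, and the pagination wrapper then drives the cursor loop until has_more is false. A minimal sketch of wrapping a WebClient method this way (the token and channel ID are placeholders; the composition mirrors _make_slack_api_call in the connector hunk above):

from slack_sdk import WebClient

from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited

client = WebClient(token="xoxb-...")  # placeholder bot token

# rate limiting wraps the raw call; pagination then walks the cursor until exhausted
fetch_all_pages = make_slack_api_call_paginated(
    make_slack_api_rate_limited(client.conversations_history)
)
pages = fetch_all_pages(channel="C0123456789")  # placeholder channel ID
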
@@ -2,8 +2,8 @@ from typing import Any
 
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
-from danswer.direct_qa.question_answer import OpenAICompletionQA
+from danswer.direct_qa.llm import OpenAIChatCompletionQA
+from danswer.direct_qa.llm import OpenAICompletionQA
 
 
 def get_default_backend_qa_model(
backend/danswer/direct_qa/answer_question.py (new file, 82 lines)
@@ -0,0 +1,82 @@
+import time
+
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.datastores.qdrant.store import QdrantIndex
+from danswer.datastores.typesense.store import TypesenseIndex
+from danswer.db.models import User
+from danswer.direct_qa import get_default_backend_qa_model
+from danswer.search.danswer_helper import query_intent
+from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.models import SearchType
+from danswer.search.semantic_search import chunks_to_search_docs
+from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.server.models import QAResponse
+from danswer.server.models import QuestionRequest
+from danswer.utils.logging import setup_logger
+
+logger = setup_logger()
+
+
+def answer_question(question: QuestionRequest, user: User | None) -> QAResponse:
+    start_time = time.time()
+
+    query = question.query
+    collection = question.collection
+    filters = question.filters
+    use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0
+    logger.info(f"Received QA query: {query}")
+
+    predicted_search, predicted_flow = query_intent(query)
+    if use_keyword is None:
+        use_keyword = predicted_search == SearchType.KEYWORD
+
+    user_id = None if user is None else user.id
+    if use_keyword:
+        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
+            query, user_id, filters, TypesenseIndex(collection)
+        )
+        unranked_chunks: list[InferenceChunk] | None = []
+    else:
+        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
+            query, user_id, filters, QdrantIndex(collection)
+        )
+    if not ranked_chunks:
+        return QAResponse(
+            answer=None,
+            quotes=None,
+            top_ranked_docs=None,
+            lower_ranked_docs=None,
+            predicted_flow=predicted_flow,
+            predicted_search=predicted_search,
+        )
+
+    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+    if chunk_offset >= len(ranked_chunks):
+        raise ValueError("Chunks offset too large, should not retry this many times")
+
+    error_msg = None
+    try:
+        answer, quotes = qa_model.answer_question(
+            query,
+            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
+        )
+    except Exception as e:
+        # exception is logged in the answer_question method, no need to re-log
+        answer, quotes = None, None
+        error_msg = f"Error occurred in call to LLM - {e}"
+
+    logger.info(f"Total QA took {time.time() - start_time} seconds")
+
+    return QAResponse(
+        answer=answer,
+        quotes=quotes,
+        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+        predicted_flow=predicted_flow,
+        predicted_search=predicted_search,
+        error_msg=error_msg,
+    )
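The new answer_question entry point is what both the direct_qa API handler and the Slack listener below call into. A minimal sketch of invoking it directly (the query and collection name are placeholders; the Slack listener passes QDRANT_DEFAULT_COLLECTION as the collection):

from danswer.direct_qa.answer_question import answer_question
from danswer.server.models import QuestionRequest

request = QuestionRequest(
    query="How do I set up the Slack bot?",  # placeholder question
    collection="danswer_index",  # placeholder collection name
    use_keyword=False,
    filters=None,
    offset=None,
)
response = answer_question(question=request, user=None)
print(response.answer, response.error_msg)
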
@@ -1,5 +1,5 @@
 from danswer.direct_qa import get_default_backend_qa_model
-from danswer.direct_qa.question_answer import OpenAIQAModel
+from danswer.direct_qa.llm import OpenAIQAModel
 from openai.error import AuthenticationError
 from openai.error import Timeout
 
backend/danswer/listeners/slack_listener.py (new file, 206 lines)
@@ -0,0 +1,206 @@
+import os
+
+from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
+from danswer.connectors.slack.utils import make_slack_api_rate_limited
+from danswer.direct_qa.answer_question import answer_question
+from danswer.server.models import QAResponse
+from danswer.server.models import QuestionRequest
+from danswer.server.models import SearchDoc
+from danswer.utils.logging import setup_logger
+from slack_sdk import WebClient
+from slack_sdk.socket_mode import SocketModeClient
+from slack_sdk.socket_mode.request import SocketModeRequest
+from slack_sdk.socket_mode.response import SocketModeResponse
+
+logger = setup_logger()
+
+_NUM_RETRIES = 3
+_NUM_DOCS_TO_DISPLAY = 5
+
+
+def _get_socket_client() -> SocketModeClient:
+    # For more info on how to set this up, checkout the docs:
+    # https://docs.danswer.dev/slack_bot_setup
+    app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
+    if not app_token:
+        raise RuntimeError("DANSWER_BOT_SLACK_APP_TOKEN is not set")
+    bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
+    if not bot_token:
+        raise RuntimeError("DANSWER_BOT_SLACK_BOT_TOKEN is not set")
+    return SocketModeClient(
+        # This app-level token will be used only for establishing a connection
+        app_token=app_token,
+        web_client=WebClient(token=bot_token),
+    )
+
+
+def _process_quotes(
+    quotes: dict[str, dict[str, str | int | None]] | None
+) -> tuple[str | None, list[str]]:
+    if not quotes:
+        return None, []
+
+    quote_lines: list[str] = []
+    doc_identifiers: list[str] = []
+    for quote_dict in quotes.values():
+        doc_link = quote_dict.get("document_id")
+        doc_name = quote_dict.get("semantic_identifier")
+        if doc_link and doc_name and doc_name not in doc_identifiers:
+            doc_identifiers.append(str(doc_name))
+            quote_lines.append(f"- <{doc_link}|{doc_name}>")
+
+    if not quote_lines:
+        return None, []
+
+    return "\n".join(quote_lines), doc_identifiers
+
+
+def _process_documents(
+    documents: list[SearchDoc] | None, already_displayed_doc_identifiers: list[str]
+) -> str | None:
+    if not documents:
+        return None
+
+    top_documents = [
+        d
+        for d in documents
+        if d.semantic_identifier not in already_displayed_doc_identifiers
+    ][:_NUM_DOCS_TO_DISPLAY]
+    top_documents_str = "\n".join(
+        [f"- <{d.link}|{d.semantic_identifier}>" for d in top_documents]
+    )
+    return "*Other potentially relevant documents:*\n" + top_documents_str
+
+
+def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
+    if req.type == "events_api":
+        # Acknowledge the request anyway
+        response = SocketModeResponse(envelope_id=req.envelope_id)
+        client.send_socket_mode_response(response)
+
+        # Ensure that the message is a new message + of expected type
+        event_type = req.payload.get("event", {}).get("type")
+        if event_type != "message":
+            logger.info(f"Ignoring non-message event of type '{event_type}'")
+
+        message_subtype = req.payload.get("event", {}).get("subtype")
+        if req.payload.get("event", {}).get("subtype") is not None:
+            # this covers things like channel_join, channel_leave, etc.
+            logger.info(
+                f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
+            )
+            return
+
+        if req.payload.get("event", {}).get("bot_profile"):
+            logger.info("Ignoring message from bot")
+            return
+
+        msg = req.payload.get("event", {}).get("text")
+        thread_ts = req.payload.get("event", {}).get("ts")
+        if not msg:
+            logger.error("Unable to process empty message")
+            return
+
+        # TODO: message should be enqueued and processed elsewhere,
+        # but doing it here for now for simplicity
+
+        def _get_answer(question: QuestionRequest) -> QAResponse | None:
+            try:
+                answer = answer_question(question=question, user=None)
+                if not answer.error_msg:
+                    return answer
+                else:
+                    raise RuntimeError(answer.error_msg)
+            except Exception as e:
+                logger.error(f"Unable to process message: {e}")
+                return None
+
+        answer = None
+        for _ in range(_NUM_RETRIES):
+            answer = _get_answer(
+                QuestionRequest(
+                    query=req.payload.get("event", {}).get("text"),
+                    collection=QDRANT_DEFAULT_COLLECTION,
+                    use_keyword=False,
+                    filters=None,
+                    offset=None,
+                )
+            )
+            if answer:
+                break
+
+        if not answer:
+            logger.error(
+                f"Unable to process message - did not successfully answer in {_NUM_RETRIES} attempts"
+            )
+            return
+
+        if not answer.answer:
+            logger.error(f"Unable to process message - no answer found")
+            return
+
+        # convert raw response into "nicely" formatted Slack message
+        quote_str, doc_identifiers = _process_quotes(answer.quotes)
+        top_documents_str = _process_documents(answer.top_ranked_docs, doc_identifiers)
+        if quote_str:
+            text = f"{answer.answer}\n\n*Sources:*\n{quote_str}\n\n{top_documents_str}"
+        else:
+            text = f"{answer.answer}\n\n*Warning*: no sources were quoted for this answer, so it may be unreliable 😔\n\n{top_documents_str}"
+
+        def _respond_in_thread(
+            channel: str,
+            text: str,
+            thread_ts: str,
+        ) -> str | None:
+            slack_call = make_slack_api_rate_limited(client.web_client.chat_postMessage)
+            response = slack_call(
+                channel=channel,
+                text=text,
+                thread_ts=thread_ts,
+            )
+            if not response.get("ok"):
+                return f"Unable to post message: {response}"
+            return None
+
+        successfully_answered = False
+        for _ in range(_NUM_RETRIES):
+            error_msg = _respond_in_thread(
+                channel=req.payload.get("event", {}).get("channel"),
+                text=text,
+                thread_ts=thread_ts,
+            )
+            if error_msg:
+                logger.error(error_msg)
+            else:
+                successfully_answered = True
+                break
+
+        if not successfully_answered:
+            logger.error(
+                f"Unable to process message - could not respond in slack in {_NUM_RETRIES} attempts"
+            )
+            return
+
+        logger.info(f"Successfully processed message with ts: '{thread_ts}'")
+
+
+# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
+# the slack bot in your workspace, and then add the bot to any channels you want to
+# try and answer questions for. Running this file will setup Danswer to listen to all
+# messages in those channels and attempt to answer them. As of now, it will only respond
+# to messages sent directly in the channel - it will not respond to messages sent within a
+# thread.
+#
+# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
+# without issue.
+if __name__ == "__main__":
+    socket_client = _get_socket_client()
+    socket_client.socket_mode_request_listeners.append(process_slack_event)  # type: ignore
+    # Establish a WebSocket connection to the Socket Mode servers
+    logger.info("Listening for messages from Slack...")
+    socket_client.connect()
+
+    # Just not to stop this process
+    from threading import Event
+
+    Event().wait()
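For a sense of what the formatting helpers emit, a small sketch of _process_quotes on a hypothetical quotes payload (the quoted key, link, and document name are made up for illustration):

from danswer.listeners.slack_listener import _process_quotes

sample_quotes = {
    "a quoted sentence from the doc": {
        "document_id": "https://example.com/doc-1",  # hypothetical link
        "semantic_identifier": "Onboarding Guide",  # hypothetical doc name
    },
}
quote_str, doc_identifiers = _process_quotes(sample_quotes)
# quote_str -> "- <https://example.com/doc-1|Onboarding Guide>"
# doc_identifiers -> ["Onboarding Guide"]
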
@@ -38,7 +38,7 @@ from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.models import User
 from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
-from danswer.direct_qa.question_answer import get_openai_api_key
+from danswer.direct_qa.llm import get_openai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
@@ -105,6 +105,7 @@ class QAResponse(SearchResponse):
     quotes: dict[str, dict[str, str | int | None]] | None
     predicted_flow: QueryFlow
     predicted_search: SearchType
+    error_msg: str | None = None
 
 
 class UserByEmail(BaseModel):
@@ -9,7 +9,8 @@ from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
-from danswer.direct_qa.question_answer import get_json_line
+from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.llm import get_json_line
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -85,62 +86,7 @@ def keyword_search(
 def direct_qa(
     question: QuestionRequest, user: User = Depends(current_user)
 ) -> QAResponse:
-    start_time = time.time()
-
-    query = question.query
-    collection = question.collection
-    filters = question.filters
-    use_keyword = question.use_keyword
-    offset_count = question.offset if question.offset is not None else 0
-    logger.info(f"Received QA query: {query}")
-
-    predicted_search, predicted_flow = query_intent(query)
-    if use_keyword is None:
-        use_keyword = predicted_search == SearchType.KEYWORD
-
-    user_id = None if user is None else user.id
-    if use_keyword:
-        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
-            query, user_id, filters, TypesenseIndex(collection)
-        )
-        unranked_chunks: list[InferenceChunk] | None = []
-    else:
-        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, QdrantIndex(collection)
-        )
-    if not ranked_chunks:
-        return QAResponse(
-            answer=None,
-            quotes=None,
-            top_ranked_docs=None,
-            lower_ranked_docs=None,
-            predicted_flow=predicted_flow,
-            predicted_search=predicted_search,
-        )
-
-    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
-    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-    if chunk_offset >= len(ranked_chunks):
-        raise ValueError("Chunks offset too large, should not retry this many times")
-    try:
-        answer, quotes = qa_model.answer_question(
-            query,
-            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
-        )
-    except Exception:
-        # exception is logged in the answer_question method, no need to re-log
-        answer, quotes = None, None
-
-    logger.info(f"Total QA took {time.time() - start_time} seconds")
-
-    return QAResponse(
-        answer=answer,
-        quotes=quotes,
-        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
-        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search,
-    )
+    return answer_question(question=question, user=user)
 
 
 @router.post("/stream-direct-qa")
@@ -16,3 +16,15 @@ stdout_logfile=/var/log/file_deletion.log
 redirect_stderr=true
 stdout_logfile_maxbytes=52428800
 autorestart=true
+
+# Listens for slack messages and responds with answers
+# for all channels that the DanswerBot has been added to.
+# If not setup, this will just fail 5 times and then stop.
+# More details on setup here: https://docs.danswer.dev/slack_bot_setup
+[program:slack_bot_listener]
+command=python danswer/listeners/slack_listener.py
+stdout_logfile=/var/log/slack_bot_listener.log
+redirect_stderr=true
+stdout_logfile_maxbytes=52428800
+autorestart=true
+startretries=5
@@ -2,8 +2,8 @@ import textwrap
 import unittest
 
 from danswer.chunking.models import InferenceChunk
-from danswer.direct_qa.question_answer import match_quotes_to_docs
-from danswer.direct_qa.question_answer import separate_answer_quotes
+from danswer.direct_qa.llm import match_quotes_to_docs
+from danswer.direct_qa.llm import separate_answer_quotes
 
 
 class TestQAPostprocessing(unittest.TestCase):
@@ -27,6 +27,12 @@ ENABLE_OAUTH=True
 GOOGLE_OAUTH_CLIENT_ID=
 GOOGLE_OAUTH_CLIENT_SECRET=
 
+# If you want to setup a slack bot to answer questions automatically in Slack
+# channels it is added to, you must specify the below.
+# More information in the guide here: https://docs.danswer.dev/slack_bot_setup
+DANSWER_BOT_SLACK_APP_TOKEN=
+DANSWER_BOT_SLACK_BOT_TOKEN=
+
 # Used to generate values for security verification, use a random string
 SECRET=
 