Basic Slack Bot Support (#128)

Chris Weaver
2023-07-03 14:26:33 -07:00
committed by GitHub
parent 381b3719c9
commit 2f54795631
13 changed files with 378 additions and 119 deletions

View File

@@ -1,6 +1,5 @@
 import json
 import os
-import time
 from collections.abc import Callable
 from pathlib import Path
 from typing import Any
@@ -15,15 +14,14 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.connectors.slack.utils import get_message_link
+from danswer.connectors.slack.utils import make_slack_api_call_paginated
+from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.utils.logging import setup_logger
 from slack_sdk import WebClient
-from slack_sdk.errors import SlackApiError
 from slack_sdk.web import SlackResponse

 logger = setup_logger()

-SLACK_LIMIT = 900
-
 ChannelType = dict[str, Any]
 MessageType = dict[str, Any]
@@ -31,61 +29,10 @@ MessageType = dict[str, Any]
 ThreadType = list[MessageType]

-def _make_slack_api_call_paginated(
-    call: Callable[..., SlackResponse],
-) -> Callable[..., list[dict[str, Any]]]:
-    """Wraps calls to slack API so that they automatically handle pagination"""
-
-    def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
-        results: list[dict[str, Any]] = []
-        cursor: str | None = None
-        has_more = True
-        while has_more:
-            for result in call(cursor=cursor, limit=SLACK_LIMIT, **kwargs):
-                has_more = result.get("has_more", False)
-                cursor = result.get("response_metadata", {}).get("next_cursor", "")
-                results.append(cast(dict[str, Any], result))
-        return results
-
-    return paginated_call
-
-
-def _make_slack_api_rate_limited(
-    call: Callable[..., SlackResponse], max_retries: int = 3
-) -> Callable[..., SlackResponse]:
-    """Wraps calls to slack API so that they automatically handle rate limiting"""
-
-    def rate_limited_call(**kwargs: Any) -> SlackResponse:
-        for _ in range(max_retries):
-            try:
-                # Make the API call
-                response = call(**kwargs)
-
-                # Check for errors in the response
-                if response.get("ok"):
-                    return response
-                else:
-                    raise SlackApiError("", response)
-            except SlackApiError as e:
-                if e.response["error"] == "ratelimited":
-                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
-                    retry_after = int(e.response.headers.get("Retry-After", 1))
-                    time.sleep(retry_after)
-                else:
-                    # Raise the error for non-transient errors
-                    raise
-
-        # If the code reaches this point, all retries have been exhausted
-        raise Exception(f"Max retries ({max_retries}) exceeded")
-
-    return rate_limited_call
-
-
 def _make_slack_api_call(
     call: Callable[..., SlackResponse], **kwargs: Any
 ) -> list[dict[str, Any]]:
-    return _make_slack_api_call_paginated(_make_slack_api_rate_limited(call))(**kwargs)
+    return make_slack_api_call_paginated(make_slack_api_rate_limited(call))(**kwargs)


 def get_channel_info(client: WebClient, channel_id: str) -> ChannelType:

View File

@@ -1,6 +1,14 @@
+import time
+from collections.abc import Callable
 from typing import Any
 from typing import cast

+from slack_sdk.errors import SlackApiError
+from slack_sdk.web import SlackResponse
+
+# number of messages we request per page when fetching paginated slack messages
+_SLACK_LIMIT = 900
+

 def get_message_link(
     event: dict[str, Any], workspace: str, channel_id: str | None = None
@@ -13,3 +21,54 @@ def get_message_link(
     return (
         f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
     )
+
+
+def make_slack_api_call_paginated(
+    call: Callable[..., SlackResponse],
+) -> Callable[..., list[dict[str, Any]]]:
+    """Wraps calls to slack API so that they automatically handle pagination"""
+
+    def paginated_call(**kwargs: Any) -> list[dict[str, Any]]:
+        results: list[dict[str, Any]] = []
+        cursor: str | None = None
+        has_more = True
+        while has_more:
+            for result in call(cursor=cursor, limit=_SLACK_LIMIT, **kwargs):
+                has_more = result.get("has_more", False)
+                cursor = result.get("response_metadata", {}).get("next_cursor", "")
+                results.append(cast(dict[str, Any], result))
+        return results
+
+    return paginated_call
+
+
+def make_slack_api_rate_limited(
+    call: Callable[..., SlackResponse], max_retries: int = 3
+) -> Callable[..., SlackResponse]:
+    """Wraps calls to slack API so that they automatically handle rate limiting"""
+
+    def rate_limited_call(**kwargs: Any) -> SlackResponse:
+        for _ in range(max_retries):
+            try:
+                # Make the API call
+                response = call(**kwargs)
+
+                # Check for errors in the response
+                if response.get("ok"):
+                    return response
+                else:
+                    raise SlackApiError("", response)
+            except SlackApiError as e:
+                if e.response["error"] == "ratelimited":
+                    # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
+                    retry_after = int(e.response.headers.get("Retry-After", 1))
+                    time.sleep(retry_after)
+                else:
+                    # Raise the error for non-transient errors
+                    raise
+
+        # If the code reaches this point, all retries have been exhausted
+        raise Exception(f"Max retries ({max_retries}) exceeded")
+
+    return rate_limited_call
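
Taken together, the two wrappers above compose into one resilient fetcher: the rate-limit wrapper retries any call that Slack throttles, and the pagination wrapper then follows the response cursor until `has_more` is false. A minimal usage sketch, assuming a configured `WebClient`; the token and channel ID below are placeholders, not values from this commit:

```python
from slack_sdk import WebClient

from danswer.connectors.slack.utils import make_slack_api_call_paginated
from danswer.connectors.slack.utils import make_slack_api_rate_limited

client = WebClient(token="xoxb-...")  # placeholder bot token

# Rate limiting is applied on every underlying call; pagination then
# walks the cursor until Slack reports there are no more pages.
fetch_all_pages = make_slack_api_call_paginated(
    make_slack_api_rate_limited(client.conversations_history)
)

pages = fetch_all_pages(channel="C0123456789")  # placeholder channel ID
messages = [msg for page in pages for msg in page.get("messages", [])]
```

This is the same composition the Slack connector now performs internally via `_make_slack_api_call`.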

View File

@@ -2,8 +2,8 @@ from typing import Any
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.question_answer import OpenAIChatCompletionQA
-from danswer.direct_qa.question_answer import OpenAICompletionQA
+from danswer.direct_qa.llm import OpenAIChatCompletionQA
+from danswer.direct_qa.llm import OpenAICompletionQA


 def get_default_backend_qa_model(

View File

@@ -0,0 +1,82 @@
+import time
+
+from danswer.chunking.models import InferenceChunk
+from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
+from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.datastores.qdrant.store import QdrantIndex
+from danswer.datastores.typesense.store import TypesenseIndex
+from danswer.db.models import User
+from danswer.direct_qa import get_default_backend_qa_model
+from danswer.search.danswer_helper import query_intent
+from danswer.search.keyword_search import retrieve_keyword_documents
+from danswer.search.models import SearchType
+from danswer.search.semantic_search import chunks_to_search_docs
+from danswer.search.semantic_search import retrieve_ranked_documents
+from danswer.server.models import QAResponse
+from danswer.server.models import QuestionRequest
+from danswer.utils.logging import setup_logger
+
+logger = setup_logger()
+
+
+def answer_question(question: QuestionRequest, user: User | None) -> QAResponse:
+    start_time = time.time()
+
+    query = question.query
+    collection = question.collection
+    filters = question.filters
+    use_keyword = question.use_keyword
+    offset_count = question.offset if question.offset is not None else 0
+    logger.info(f"Received QA query: {query}")
+
+    predicted_search, predicted_flow = query_intent(query)
+    if use_keyword is None:
+        use_keyword = predicted_search == SearchType.KEYWORD
+
+    user_id = None if user is None else user.id
+    if use_keyword:
+        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
+            query, user_id, filters, TypesenseIndex(collection)
+        )
+        unranked_chunks: list[InferenceChunk] | None = []
+    else:
+        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
+            query, user_id, filters, QdrantIndex(collection)
+        )
+    if not ranked_chunks:
+        return QAResponse(
+            answer=None,
+            quotes=None,
+            top_ranked_docs=None,
+            lower_ranked_docs=None,
+            predicted_flow=predicted_flow,
+            predicted_search=predicted_search,
+        )
+
+    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
+    if chunk_offset >= len(ranked_chunks):
+        raise ValueError("Chunks offset too large, should not retry this many times")
+
+    error_msg = None
+    try:
+        answer, quotes = qa_model.answer_question(
+            query,
+            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
+        )
+    except Exception as e:
+        # exception is logged in the answer_question method, no need to re-log
+        answer, quotes = None, None
+        error_msg = f"Error occurred in call to LLM - {e}"
+
+    logger.info(f"Total QA took {time.time() - start_time} seconds")
+    return QAResponse(
+        answer=answer,
+        quotes=quotes,
+        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
+        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
+        predicted_flow=predicted_flow,
+        predicted_search=predicted_search,
+        error_msg=error_msg,
+    )
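
This new `answer_question` entrypoint can also be called directly, which is exactly what both the `/direct-qa` endpoint and the Slack listener below now do. A short sketch of a standalone call; the question text is made up, and a running Qdrant/Typesense-backed index is assumed:

```python
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.direct_qa.answer_question import answer_question
from danswer.server.models import QuestionRequest

# Mirror the request the Slack listener builds; filters and offset are optional.
request = QuestionRequest(
    query="How do I set up the Slack bot?",  # made-up example question
    collection=QDRANT_DEFAULT_COLLECTION,
    use_keyword=False,
    filters=None,
    offset=None,
)

response = answer_question(question=request, user=None)
if response.error_msg is None and response.answer:
    print(response.answer)
```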

View File

@@ -1,5 +1,5 @@
 from danswer.direct_qa import get_default_backend_qa_model
-from danswer.direct_qa.question_answer import OpenAIQAModel
+from danswer.direct_qa.llm import OpenAIQAModel
 from openai.error import AuthenticationError
 from openai.error import Timeout

View File

@@ -0,0 +1,206 @@
+import os
+
+from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
+from danswer.connectors.slack.utils import make_slack_api_rate_limited
+from danswer.direct_qa.answer_question import answer_question
+from danswer.server.models import QAResponse
+from danswer.server.models import QuestionRequest
+from danswer.server.models import SearchDoc
+from danswer.utils.logging import setup_logger
+from slack_sdk import WebClient
+from slack_sdk.socket_mode import SocketModeClient
+from slack_sdk.socket_mode.request import SocketModeRequest
+from slack_sdk.socket_mode.response import SocketModeResponse
+
+logger = setup_logger()
+
+_NUM_RETRIES = 3
+_NUM_DOCS_TO_DISPLAY = 5
+
+
+def _get_socket_client() -> SocketModeClient:
+    # For more info on how to set this up, check out the docs:
+    # https://docs.danswer.dev/slack_bot_setup
+    app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
+    if not app_token:
+        raise RuntimeError("DANSWER_BOT_SLACK_APP_TOKEN is not set")
+
+    bot_token = os.environ.get("DANSWER_BOT_SLACK_BOT_TOKEN")
+    if not bot_token:
+        raise RuntimeError("DANSWER_BOT_SLACK_BOT_TOKEN is not set")
+
+    return SocketModeClient(
+        # This app-level token will be used only for establishing a connection
+        app_token=app_token,
+        web_client=WebClient(token=bot_token),
+    )
+
+
+def _process_quotes(
+    quotes: dict[str, dict[str, str | int | None]] | None
+) -> tuple[str | None, list[str]]:
+    if not quotes:
+        return None, []
+
+    quote_lines: list[str] = []
+    doc_identifiers: list[str] = []
+    for quote_dict in quotes.values():
+        doc_link = quote_dict.get("document_id")
+        doc_name = quote_dict.get("semantic_identifier")
+        if doc_link and doc_name and doc_name not in doc_identifiers:
+            doc_identifiers.append(str(doc_name))
+            quote_lines.append(f"- <{doc_link}|{doc_name}>")
+
+    if not quote_lines:
+        return None, []
+
+    return "\n".join(quote_lines), doc_identifiers
+
+
+def _process_documents(
+    documents: list[SearchDoc] | None, already_displayed_doc_identifiers: list[str]
+) -> str | None:
+    if not documents:
+        return None
+
+    top_documents = [
+        d
+        for d in documents
+        if d.semantic_identifier not in already_displayed_doc_identifiers
+    ][:_NUM_DOCS_TO_DISPLAY]
+    top_documents_str = "\n".join(
+        [f"- <{d.link}|{d.semantic_identifier}>" for d in top_documents]
+    )
+    return "*Other potentially relevant documents:*\n" + top_documents_str
+
+
+def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
+    if req.type == "events_api":
+        # Acknowledge the request anyway
+        response = SocketModeResponse(envelope_id=req.envelope_id)
+        client.send_socket_mode_response(response)
+
+        # Ensure that the message is a new message of the expected type
+        event_type = req.payload.get("event", {}).get("type")
+        if event_type != "message":
+            logger.info(f"Ignoring non-message event of type '{event_type}'")
+            return
+
+        message_subtype = req.payload.get("event", {}).get("subtype")
+        if message_subtype is not None:
+            # this covers things like channel_join, channel_leave, etc.
+            logger.info(
+                f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
+            )
+            return
+
+        if req.payload.get("event", {}).get("bot_profile"):
+            logger.info("Ignoring message from bot")
+            return
+
+        msg = req.payload.get("event", {}).get("text")
+        thread_ts = req.payload.get("event", {}).get("ts")
+        if not msg:
+            logger.error("Unable to process empty message")
+            return
+
+        # TODO: message should be enqueued and processed elsewhere,
+        # but doing it here for now for simplicity
+        def _get_answer(question: QuestionRequest) -> QAResponse | None:
+            try:
+                answer = answer_question(question=question, user=None)
+                if not answer.error_msg:
+                    return answer
+                else:
+                    raise RuntimeError(answer.error_msg)
+            except Exception as e:
+                logger.error(f"Unable to process message: {e}")
+                return None
+
+        answer = None
+        for _ in range(_NUM_RETRIES):
+            answer = _get_answer(
+                QuestionRequest(
+                    query=msg,
+                    collection=QDRANT_DEFAULT_COLLECTION,
+                    use_keyword=False,
+                    filters=None,
+                    offset=None,
+                )
+            )
+            if answer:
+                break
+
+        if not answer:
+            logger.error(
+                f"Unable to process message - did not successfully answer in {_NUM_RETRIES} attempts"
+            )
+            return
+
+        if not answer.answer:
+            logger.error("Unable to process message - no answer found")
+            return
+
+        # convert raw response into a "nicely" formatted Slack message
+        quote_str, doc_identifiers = _process_quotes(answer.quotes)
+        top_documents_str = _process_documents(answer.top_ranked_docs, doc_identifiers)
+
+        if quote_str:
+            text = f"{answer.answer}\n\n*Sources:*\n{quote_str}\n\n{top_documents_str}"
+        else:
+            text = f"{answer.answer}\n\n*Warning*: no sources were quoted for this answer, so it may be unreliable 😔\n\n{top_documents_str}"
+
+        def _respond_in_thread(
+            channel: str,
+            text: str,
+            thread_ts: str,
+        ) -> str | None:
+            slack_call = make_slack_api_rate_limited(client.web_client.chat_postMessage)
+            response = slack_call(
+                channel=channel,
+                text=text,
+                thread_ts=thread_ts,
+            )
+            if not response.get("ok"):
+                return f"Unable to post message: {response}"
+            return None
+
+        successfully_answered = False
+        for _ in range(_NUM_RETRIES):
+            error_msg = _respond_in_thread(
+                channel=req.payload.get("event", {}).get("channel"),
+                text=text,
+                thread_ts=thread_ts,
+            )
+            if error_msg:
+                logger.error(error_msg)
+            else:
+                successfully_answered = True
+                break
+
+        if not successfully_answered:
+            logger.error(
+                f"Unable to process message - could not respond in slack in {_NUM_RETRIES} attempts"
+            )
+            return
+
+        logger.info(f"Successfully processed message with ts: '{thread_ts}'")
+
+
+# Follow the guide (https://docs.danswer.dev/slack_bot_setup) to set up
+# the slack bot in your workspace, and then add the bot to any channels you want it
+# to answer questions for. Running this file will set up Danswer to listen to all
+# messages in those channels and attempt to answer them. As of now, it will only respond
+# to messages sent directly in the channel - it will not respond to messages sent within a
+# thread.
+#
+# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
+# without issue.
+if __name__ == "__main__":
+    socket_client = _get_socket_client()
+    socket_client.socket_mode_request_listeners.append(process_slack_event)  # type: ignore
+    # Establish a WebSocket connection to the Socket Mode servers
+    logger.info("Listening for messages from Slack...")
+    socket_client.connect()
+
+    # Keep this process alive indefinitely
+    from threading import Event
+
+    Event().wait()
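
To make the formatting helpers concrete, here is a small sketch of the quotes payload `_process_quotes` expects (the shape of `QAResponse.quotes`) and what it renders. The link and document name are invented for illustration, and the import path assumes the module lives at `danswer/listeners/slack_listener.py` as the supervisord entry below suggests:

```python
from danswer.listeners.slack_listener import _process_quotes

# Keys are quote strings; values carry the source document's metadata.
quotes = {
    "To install the bot, first create a Slack app...": {
        "document_id": "https://example.com/setup-guide",  # invented link
        "semantic_identifier": "Slack Bot Setup Guide",  # invented name
    }
}

quote_str, doc_identifiers = _process_quotes(quotes)
# quote_str == "- <https://example.com/setup-guide|Slack Bot Setup Guide>"
# doc_identifiers == ["Slack Bot Setup Guide"]
```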

View File

@@ -38,7 +38,7 @@ from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.models import User
 from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
-from danswer.direct_qa.question_answer import get_openai_api_key
+from danswer.direct_qa.llm import get_openai_api_key
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey

View File

@@ -105,6 +105,7 @@ class QAResponse(SearchResponse):
     quotes: dict[str, dict[str, str | int | None]] | None
     predicted_flow: QueryFlow
     predicted_search: SearchType
+    error_msg: str | None = None


 class UserByEmail(BaseModel):

View File

@@ -9,7 +9,8 @@ from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
-from danswer.direct_qa.question_answer import get_json_line
+from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.llm import get_json_line
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -85,62 +86,7 @@ def keyword_search(
 def direct_qa(
     question: QuestionRequest, user: User = Depends(current_user)
 ) -> QAResponse:
-    start_time = time.time()
-
-    query = question.query
-    collection = question.collection
-    filters = question.filters
-    use_keyword = question.use_keyword
-    offset_count = question.offset if question.offset is not None else 0
-    logger.info(f"Received QA query: {query}")
-
-    predicted_search, predicted_flow = query_intent(query)
-    if use_keyword is None:
-        use_keyword = predicted_search == SearchType.KEYWORD
-
-    user_id = None if user is None else user.id
-    if use_keyword:
-        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
-            query, user_id, filters, TypesenseIndex(collection)
-        )
-        unranked_chunks: list[InferenceChunk] | None = []
-    else:
-        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, QdrantIndex(collection)
-        )
-    if not ranked_chunks:
-        return QAResponse(
-            answer=None,
-            quotes=None,
-            top_ranked_docs=None,
-            lower_ranked_docs=None,
-            predicted_flow=predicted_flow,
-            predicted_search=predicted_search,
-        )
-
-    qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
-    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
-    if chunk_offset >= len(ranked_chunks):
-        raise ValueError("Chunks offset too large, should not retry this many times")
-
-    try:
-        answer, quotes = qa_model.answer_question(
-            query,
-            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
-        )
-    except Exception:
-        # exception is logged in the answer_question method, no need to re-log
-        answer, quotes = None, None
-
-    logger.info(f"Total QA took {time.time() - start_time} seconds")
-    return QAResponse(
-        answer=answer,
-        quotes=quotes,
-        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
-        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
-        predicted_flow=predicted_flow,
-        predicted_search=predicted_search,
-    )
+    return answer_question(question=question, user=user)


 @router.post("/stream-direct-qa")

View File

@@ -16,3 +16,15 @@ stdout_logfile=/var/log/file_deletion.log
 redirect_stderr=true
 stdout_logfile_maxbytes=52428800
 autorestart=true
+
+# Listens for slack messages and responds with answers
+# for all channels that the DanswerBot has been added to.
+# If not set up, this will just fail 5 times and then stop.
+# More details on setup here: https://docs.danswer.dev/slack_bot_setup
+[program:slack_bot_listener]
+command=python danswer/listeners/slack_listener.py
+stdout_logfile=/var/log/slack_bot_listener.log
+redirect_stderr=true
+stdout_logfile_maxbytes=52428800
+autorestart=true
+startretries=5

View File

@@ -2,8 +2,8 @@ import textwrap
 import unittest

 from danswer.chunking.models import InferenceChunk
-from danswer.direct_qa.question_answer import match_quotes_to_docs
-from danswer.direct_qa.question_answer import separate_answer_quotes
+from danswer.direct_qa.llm import match_quotes_to_docs
+from danswer.direct_qa.llm import separate_answer_quotes


 class TestQAPostprocessing(unittest.TestCase):

View File

@@ -27,6 +27,12 @@ ENABLE_OAUTH=True
 GOOGLE_OAUTH_CLIENT_ID=
 GOOGLE_OAUTH_CLIENT_SECRET=
+
+# If you want to set up a slack bot to answer questions automatically in Slack
+# channels it is added to, you must specify the below.
+# More information in the guide here: https://docs.danswer.dev/slack_bot_setup
+DANSWER_BOT_SLACK_APP_TOKEN=
+DANSWER_BOT_SLACK_BOT_TOKEN=
+
 # Used to generate values for security verification, use a random string
 SECRET=